Improve server startup/shutdown
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 14 Mar 2022 15:34:23 +0000 (16:34 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 14 Mar 2022 18:38:17 +0000 (19:38 +0100)
Closes #9671

commands/server.go
commands/server_test.go

index 7d9462b36554f8891c74a1d708f01815f0140d99..89ece40ae9bd3a2b747f1ec4f6e7a8c4cf3d83dc 100644 (file)
@@ -15,6 +15,7 @@ package commands
 
 import (
        "bytes"
+       "context"
        "fmt"
        "io"
        "net"
@@ -32,6 +33,7 @@ import (
        "time"
 
        "github.com/gohugoio/hugo/common/paths"
+       "golang.org/x/sync/errgroup"
 
        "github.com/pkg/errors"
 
@@ -139,7 +141,7 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
 
        var serverCfgInit sync.Once
 
-       cfgInit := func(c *commandeer) error {
+       cfgInit := func(c *commandeer) (rerr error) {
                c.Set("renderToMemory", !sc.renderToDisk)
                if cmd.Flags().Changed("navigateToChanged") {
                        c.Set("navigateToChanged", sc.navigateToChanged)
@@ -162,15 +164,13 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                        return nil
                }
 
-               var err error
-
                // We can only do this once.
                serverCfgInit.Do(func() {
                        serverPorts = make([]int, 1)
 
                        if c.languages.IsMultihost() {
                                if !sc.serverAppend {
-                                       err = newSystemError("--appendPort=false not supported when in multihost mode")
+                                       rerr = newSystemError("--appendPort=false not supported when in multihost mode")
                                }
                                serverPorts = make([]int, len(c.languages))
                        }
@@ -185,12 +185,14 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                                } else {
                                        if i == 0 && sc.cmd.Flags().Changed("port") {
                                                // port set explicitly by user -- he/she probably meant it!
-                                               err = newSystemErrorF("Server startup failed: %s", err)
+                                               rerr = newSystemErrorF("Server startup failed: %s", err)
+                                               return
                                        }
                                        c.logger.Println("port", sc.serverPort, "already in use, attempting to use an available port")
                                        sp, err := helpers.FindAvailablePort()
                                        if err != nil {
-                                               err = newSystemError("Unable to find alternative port to use:", err)
+                                               rerr = newSystemError("Unable to find alternative port to use:", err)
+                                               return
                                        }
                                        serverPorts[i] = sp.Port
                                }
@@ -199,6 +201,10 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                        }
                })
 
+               if rerr != nil {
+                       return
+               }
+
                c.serverPorts = serverPorts
 
                c.Set("port", sc.serverPort)
@@ -229,7 +235,7 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                        }
                }
 
-               return err
+               return
        }
 
        if err := memStats(); err != nil {
@@ -506,9 +512,15 @@ func (c *commandeer) serve(s *serverCmd) error {
 
        sigs := make(chan os.Signal, 1)
        signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
+       var servers []*http.Server
 
        for i := range baseURLs {
                mu, serverURL, endpoint, err := srv.createEndpoint(i)
+               srv := &http.Server{
+                       Addr:    endpoint,
+                       Handler: mu,
+               }
+               servers = append(servers, srv)
 
                if doLiveReload {
                        u, err := url.Parse(helpers.SanitizeURL(baseURLs[i]))
@@ -521,8 +533,8 @@ func (c *commandeer) serve(s *serverCmd) error {
                }
                jww.FEEDBACK.Printf("Web Server is available at %s (bind address %s)\n", serverURL, s.serverInterface)
                go func() {
-                       err = http.ListenAndServe(endpoint, mu)
-                       if err != nil {
+                       err = srv.ListenAndServe()
+                       if err != nil && err != http.ErrServerClosed {
                                c.logger.Errorf("Error: %s\n", err.Error())
                                os.Exit(1)
                        }
@@ -542,7 +554,17 @@ func (c *commandeer) serve(s *serverCmd) error {
 
        c.hugo().Close()
 
-       return nil
+       ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+       defer cancel()
+       wg, ctx := errgroup.WithContext(ctx)
+       for _, srv := range servers {
+               srv := srv
+               wg.Go(func() error {
+                       return srv.Shutdown(ctx)
+               })
+       }
+
+       return wg.Wait()
 }
 
 // fixURL massages the baseURL into a form needed for serving
index 0806c57d033f70810f9e2bcf3d2abc389c3fa7cf..5b91ff9dbca949f79523964ed1c5028cca89f132 100644 (file)
@@ -25,6 +25,8 @@ import (
 
        "github.com/gohugoio/hugo/config"
        "github.com/gohugoio/hugo/helpers"
+       "golang.org/x/net/context"
+       "golang.org/x/sync/errgroup"
 
        qt "github.com/frankban/quicktest"
 )
@@ -107,14 +109,14 @@ func runServerTest(c *qt.C, config string, args ...string) (result serverTestRes
        defer clean()
        c.Assert(err, qt.IsNil)
 
-       // Let us hope that this port is available on all systems ...
-       port := 1331
+       sp, err := helpers.FindAvailablePort()
+       c.Assert(err, qt.IsNil)
+       port := sp.Port
 
        defer func() {
                os.RemoveAll(dir)
        }()
 
-       errors := make(chan error)
        stop := make(chan bool)
 
        b := newCommandsBuilder()
@@ -124,24 +126,26 @@ func runServerTest(c *qt.C, config string, args ...string) (result serverTestRes
        args = append([]string{"-s=" + dir, fmt.Sprintf("-p=%d", port)}, args...)
        cmd.SetArgs(args)
 
-       go func() {
+       ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+       defer cancel()
+       wg, ctx := errgroup.WithContext(ctx)
+
+       wg.Go(func() error {
                _, err := cmd.ExecuteC()
-               if err != nil {
-                       errors <- err
-               }
-       }()
+               return err
+       })
 
        select {
        // There is no way to know exactly when the server is ready for connections.
        // We could improve by something like https://golang.org/pkg/net/http/httptest/#Server
        // But for now, let us sleep and pray!
        case <-time.After(2 * time.Second):
-       case err := <-errors:
-               result.err = err
+       case <-ctx.Done():
+               result.err = wg.Wait()
                return
        }
 
-       resp, err := http.Get("http://localhost:1331/")
+       resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port))
        c.Assert(err, qt.IsNil)
        defer resp.Body.Close()
        homeContent := helpers.ReaderToString(resp.Body)
@@ -158,6 +162,7 @@ func runServerTest(c *qt.C, config string, args ...string) (result serverTestRes
                result.publicDirnames[f.Name()] = true
        }
 
+       result.err = wg.Wait()
        return
 
 }