commands: Improve server startup to make tests less flaky
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Fri, 18 Mar 2022 07:54:44 +0000 (08:54 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 21 Mar 2022 08:32:35 +0000 (09:32 +0100)
Do this by announcing/listen on the local address before we start the server.

commands/commandeer.go
commands/server.go
commands/server_test.go
helpers/general.go
magefile.go

index 507ec430dc0b143f466d7c3ad2d9675d2911b630..8a19258931ed3db818628aa6120ae393b04363fa 100644 (file)
@@ -17,6 +17,7 @@ import (
        "bytes"
        "errors"
        "io/ioutil"
+       "net"
        "os"
        "path/filepath"
        "regexp"
@@ -88,7 +89,8 @@ type commandeer struct {
        // Used in cases where we get flooded with events in server mode.
        debounce func(f func())
 
-       serverPorts         []int
+       serverPorts []serverPortListener
+
        languagesConfigured bool
        languages           langs.Languages
        doLiveReload        bool
@@ -105,6 +107,11 @@ type commandeer struct {
        buildErr error
 }
 
+type serverPortListener struct {
+       p  int
+       ln net.Listener
+}
+
 func newCommandeerHugoState() *commandeerHugoState {
        return &commandeerHugoState{
                created: make(chan struct{}),
@@ -420,6 +427,7 @@ func (c *commandeer) loadConfig() error {
                if h == nil || c.failOnInitErr {
                        err = createErr
                }
+
                c.hugoSites = h
                // TODO(bep) improve.
                if c.buildLock == nil && h != nil {
index 4dd9116c7297478f228548b8c1bb4d866fadb511..bb6a4e15d4581177489316f39146603e9e3438ae 100644 (file)
@@ -48,7 +48,7 @@ import (
 
 type serverCmd struct {
        // Can be used to stop the server. Useful in tests
-       stop <-chan bool
+       stop chan bool
 
        disableLiveReload bool
        navigateToChanged bool
@@ -70,7 +70,7 @@ func (b *commandsBuilder) newServerCmd() *serverCmd {
        return b.newServerCmdSignaled(nil)
 }
 
-func (b *commandsBuilder) newServerCmdSignaled(stop <-chan bool) *serverCmd {
+func (b *commandsBuilder) newServerCmdSignaled(stop chan bool) *serverCmd {
        cc := &serverCmd{stop: stop}
 
        cc.baseBuilderCmd = b.newBuilderCmd(&cobra.Command{
@@ -89,7 +89,13 @@ By default hugo will also watch your files for any changes you make and
 automatically rebuild the site. It will then live reload any open browser pages
 and push the latest content to them. As most Hugo sites are built in a fraction
 of a second, you will be able to save and see your changes nearly instantly.`,
-               RunE: cc.server,
+               RunE: func(cmd *cobra.Command, args []string) error {
+                       err := cc.server(cmd, args)
+                       if err != nil && cc.stop != nil {
+                               cc.stop <- true
+                       }
+                       return err
+               },
        })
 
        cc.cmd.Flags().IntVarP(&cc.serverPort, "port", "p", 1313, "port on which the server will listen")
@@ -130,8 +136,6 @@ func (f noDirFile) Readdir(count int) ([]os.FileInfo, error) {
        return nil, nil
 }
 
-var serverPorts []int
-
 func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
        // If a Destination is provided via flag write to disk
        destination, _ := cmd.Flags().GetString("destination")
@@ -166,22 +170,21 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
 
                // We can only do this once.
                serverCfgInit.Do(func() {
-                       serverPorts = make([]int, 1)
+                       c.serverPorts = make([]serverPortListener, 1)
 
                        if c.languages.IsMultihost() {
                                if !sc.serverAppend {
                                        rerr = newSystemError("--appendPort=false not supported when in multihost mode")
                                }
-                               serverPorts = make([]int, len(c.languages))
+                               c.serverPorts = make([]serverPortListener, len(c.languages))
                        }
 
                        currentServerPort := sc.serverPort
 
-                       for i := 0; i < len(serverPorts); i++ {
+                       for i := 0; i < len(c.serverPorts); i++ {
                                l, err := net.Listen("tcp", net.JoinHostPort(sc.serverInterface, strconv.Itoa(currentServerPort)))
                                if err == nil {
-                                       l.Close()
-                                       serverPorts[i] = currentServerPort
+                                       c.serverPorts[i] = serverPortListener{ln: l, p: currentServerPort}
                                } else {
                                        if i == 0 && sc.cmd.Flags().Changed("port") {
                                                // port set explicitly by user -- he/she probably meant it!
@@ -189,15 +192,15 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                                                return
                                        }
                                        c.logger.Println("port", sc.serverPort, "already in use, attempting to use an available port")
-                                       sp, err := helpers.FindAvailablePort()
+                                       l, sp, err := helpers.TCPListen()
                                        if err != nil {
                                                rerr = newSystemError("Unable to find alternative port to use:", err)
                                                return
                                        }
-                                       serverPorts[i] = sp.Port
+                                       c.serverPorts[i] = serverPortListener{ln: l, p: sp.Port}
                                }
 
-                               currentServerPort = serverPorts[i] + 1
+                               currentServerPort = c.serverPorts[i].p + 1
                        }
                })
 
@@ -205,22 +208,20 @@ func (sc *serverCmd) server(cmd *cobra.Command, args []string) error {
                        return
                }
 
-               c.serverPorts = serverPorts
-
                c.Set("port", sc.serverPort)
                if sc.liveReloadPort != -1 {
                        c.Set("liveReloadPort", sc.liveReloadPort)
                } else {
-                       c.Set("liveReloadPort", serverPorts[0])
+                       c.Set("liveReloadPort", c.serverPorts[0].p)
                }
 
                isMultiHost := c.languages.IsMultihost()
                for i, language := range c.languages {
                        var serverPort int
                        if isMultiHost {
-                               serverPort = serverPorts[i]
+                               serverPort = c.serverPorts[i].p
                        } else {
-                               serverPort = serverPorts[0]
+                               serverPort = c.serverPorts[0].p
                        }
 
                        baseURL, err := sc.fixURL(language, sc.baseURL, serverPort)
@@ -320,10 +321,11 @@ func (f *fileServer) rewriteRequest(r *http.Request, toPath string) *http.Reques
        return r2
 }
 
-func (f *fileServer) createEndpoint(i int) (*http.ServeMux, string, string, error) {
+func (f *fileServer) createEndpoint(i int) (*http.ServeMux, net.Listener, string, string, error) {
        baseURL := f.baseURLs[i]
        root := f.roots[i]
-       port := f.c.serverPorts[i]
+       port := f.c.serverPorts[i].p
+       listener := f.c.serverPorts[i].ln
 
        publishDir := f.c.Cfg.GetString("publishDir")
 
@@ -353,7 +355,7 @@ func (f *fileServer) createEndpoint(i int) (*http.ServeMux, string, string, erro
        // We're only interested in the path
        u, err := url.Parse(baseURL)
        if err != nil {
-               return nil, "", "", errors.Wrap(err, "Invalid baseURL")
+               return nil, nil, "", "", errors.Wrap(err, "Invalid baseURL")
        }
 
        decorate := func(h http.Handler) http.Handler {
@@ -459,7 +461,7 @@ func (f *fileServer) createEndpoint(i int) (*http.ServeMux, string, string, erro
 
        endpoint := net.JoinHostPort(f.s.serverInterface, strconv.Itoa(port))
 
-       return mu, u.String(), endpoint, nil
+       return mu, listener, u.String(), endpoint, nil
 }
 
 var logErrorRe = regexp.MustCompile(`(?s)ERROR \d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2} `)
@@ -514,8 +516,10 @@ func (c *commandeer) serve(s *serverCmd) error {
        signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
        var servers []*http.Server
 
+       wg1, ctx := errgroup.WithContext(context.Background())
+
        for i := range baseURLs {
-               mu, serverURL, endpoint, err := srv.createEndpoint(i)
+               mu, listener, serverURL, endpoint, err := srv.createEndpoint(i)
                srv := &http.Server{
                        Addr:    endpoint,
                        Handler: mu,
@@ -532,13 +536,13 @@ func (c *commandeer) serve(s *serverCmd) error {
                        mu.HandleFunc(u.Path+"/livereload", livereload.Handler)
                }
                jww.FEEDBACK.Printf("Web Server is available at %s (bind address %s)\n", serverURL, s.serverInterface)
-               go func() {
-                       err = srv.ListenAndServe()
+               wg1.Go(func() error {
+                       err = srv.Serve(listener)
                        if err != nil && err != http.ErrServerClosed {
-                               c.logger.Errorf("Error: %s\n", err.Error())
-                               os.Exit(1)
+                               return err
                        }
-               }()
+                       return nil
+               })
        }
 
        jww.FEEDBACK.Println("Press Ctrl+C to stop")
@@ -556,15 +560,19 @@ func (c *commandeer) serve(s *serverCmd) error {
 
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
-       wg, ctx := errgroup.WithContext(ctx)
+       wg2, ctx := errgroup.WithContext(ctx)
        for _, srv := range servers {
                srv := srv
-               wg.Go(func() error {
+               wg2.Go(func() error {
                        return srv.Shutdown(ctx)
                })
        }
 
-       return wg.Wait()
+       err1, err2 := wg1.Wait(), wg2.Wait()
+       if err1 != nil {
+               return err1
+       }
+       return err2
 }
 
 // fixURL massages the baseURL into a form needed for serving
index 5b91ff9dbca949f79523964ed1c5028cca89f132..6972bbe690a939453bcd688dd557dedeee6e7877 100644 (file)
@@ -34,7 +34,7 @@ import (
 func TestServer(t *testing.T) {
        c := qt.New(t)
 
-       r := runServerTest(c, "")
+       r := runServerTest(c, true, "")
 
        c.Assert(r.err, qt.IsNil)
        c.Assert(r.homeContent, qt.Contains, "List: Hugo Commands")
@@ -51,7 +51,7 @@ func TestServerPanicOnConfigError(t *testing.T) {
 linenos='table'
 `
 
-       r := runServerTest(c, config)
+       r := runServerTest(c, false, config)
 
        c.Assert(r.err, qt.IsNotNil)
        c.Assert(r.err.Error(), qt.Contains, "cannot parse 'Highlight.LineNos' as bool:")
@@ -88,7 +88,7 @@ baseURL="https://example.org"
                                args = strings.Split(test.flag, "=")
                        }
 
-                       r := runServerTest(c, config, args...)
+                       r := runServerTest(c, true, config, args...)
 
                        test.assert(c, r)
 
@@ -104,7 +104,7 @@ type serverTestResult struct {
        publicDirnames map[string]bool
 }
 
-func runServerTest(c *qt.C, config string, args ...string) (result serverTestResult) {
+func runServerTest(c *qt.C, getHome bool, config string, args ...string) (result serverTestResult) {
        dir, clean, err := createSimpleTestSite(c, testSiteConfig{configTOML: config})
        defer clean()
        c.Assert(err, qt.IsNil)
@@ -135,34 +135,32 @@ func runServerTest(c *qt.C, config string, args ...string) (result serverTestRes
                return err
        })
 
-       select {
-       // There is no way to know exactly when the server is ready for connections.
-       // We could improve by something like https://golang.org/pkg/net/http/httptest/#Server
-       // But for now, let us sleep and pray!
-       case <-time.After(2 * time.Second):
-       case <-ctx.Done():
-               result.err = wg.Wait()
-               return
+       if getHome {
+               // Esp. on slow CI machines, we need to wait a little before the web
+               // server is ready.
+               time.Sleep(567 * time.Millisecond)
+               resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port))
+               c.Check(err, qt.IsNil)
+               if err == nil {
+                       defer resp.Body.Close()
+                       result.homeContent = helpers.ReaderToString(resp.Body)
+               }
        }
 
-       resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port))
-       c.Assert(err, qt.IsNil)
-       defer resp.Body.Close()
-       homeContent := helpers.ReaderToString(resp.Body)
-
-       // Stop the server.
-       stop <- true
-
-       result.homeContent = homeContent
+       select {
+       case <-stop:
+       case stop <- true:
+       }
 
        pubFiles, err := os.ReadDir(filepath.Join(dir, "public"))
-       c.Assert(err, qt.IsNil)
+       c.Check(err, qt.IsNil)
        result.publicDirnames = make(map[string]bool)
        for _, f := range pubFiles {
                result.publicDirnames[f.Name()] = true
        }
 
        result.err = wg.Wait()
+
        return
 
 }
index b5f6d0dba5fc11b267f17b7094ef0d09b739011e..e31bbfc9daf9a96f4fb48a2c15b0c8766651e832 100644 (file)
@@ -62,6 +62,21 @@ func FindAvailablePort() (*net.TCPAddr, error) {
        return nil, err
 }
 
+// TCPListen starts listening on a valid TCP port.
+func TCPListen() (net.Listener, *net.TCPAddr, error) {
+       l, err := net.Listen("tcp", ":0")
+       if err != nil {
+               return nil, nil, err
+       }
+       addr := l.Addr()
+       if a, ok := addr.(*net.TCPAddr); ok {
+               return l, a, nil
+       }
+       l.Close()
+       return nil, nil, fmt.Errorf("unable to obtain a valid tcp port: %v", addr)
+
+}
+
 // InStringArray checks if a string is an element of a slice of strings
 // and returns a boolean value.
 func InStringArray(arr []string, el string) bool {
index b794e9a4bd832db60fa48724bda5683ae00f6e72..b63576c3c6908101309afa771bc41205f0d7c8ab 100644 (file)
@@ -169,7 +169,7 @@ func testGoFlags() string {
                return ""
        }
 
-       return "-test.short"
+       return "-timeout=1m"
 }
 
 // Run tests in 32-bit mode