commands: Fix handling of persistent CLI flags
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Fri, 13 Apr 2018 06:42:29 +0000 (08:42 +0200)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Fri, 13 Apr 2018 07:08:49 +0000 (09:08 +0200)
See #4607

commands/benchmark.go
commands/commands.go
commands/commands_test.go
commands/hugo.go
commands/server.go
commands/server_test.go

index 4cdac883f62e055baad700727b72896134930278..3938acf1befaf6862f93aea9871ea08e713bf68a 100644 (file)
@@ -31,7 +31,7 @@ type benchmarkCmd struct {
        *baseBuilderCmd
 }
 
-func newBenchmarkCmd() *benchmarkCmd {
+func (b *commandsBuilder) newBenchmarkCmd() *benchmarkCmd {
        cmd := &cobra.Command{
                Use:   "benchmark",
                Short: "Benchmark Hugo by building a site a number of times.",
@@ -39,7 +39,7 @@ func newBenchmarkCmd() *benchmarkCmd {
 creating a benchmark.`,
        }
 
-       c := &benchmarkCmd{baseBuilderCmd: newBuilderCmd(cmd)}
+       c := &benchmarkCmd{baseBuilderCmd: b.newBuilderCmd(cmd)}
 
        cmd.Flags().StringVar(&c.cpuProfileFile, "cpuprofile", "", "path/filename for the CPU profile file")
        cmd.Flags().StringVar(&c.memProfileFile, "memprofile", "", "path/filename for the memory profile file")
index 86486d2a4e8f8bd655f04d83858801c0866b6dac..d0cc97b85aa0dc98e846724aa81ac26838e7cc15 100644 (file)
@@ -21,29 +21,43 @@ import (
        "github.com/spf13/nitro"
 )
 
-// newHugoCompleteCmd builds the complete set of Hugo CLI commands.
-func newHugoCompleteCmd() *hugoCmd {
-       h := newHugoCmd()
-       addAllCommands(h.getCommand())
-       return h
+type commandsBuilder struct {
+       hugoBuilderCommon
+
+       commands []cmder
+}
+
+func newCommandsBuilder() *commandsBuilder {
+       return &commandsBuilder{}
+}
+
+func (b *commandsBuilder) addCommands(commands ...cmder) *commandsBuilder {
+       b.commands = append(b.commands, commands...)
+       return b
 }
 
-// addAllCommands adds child commands to the root command HugoCmd.
-func addAllCommands(root *cobra.Command) {
-       addCommands(
-               root,
-               newServerCmd(),
+func (b *commandsBuilder) addAll() *commandsBuilder {
+       b.addCommands(
+               b.newServerCmd(),
                newVersionCmd(),
                newEnvCmd(),
                newConfigCmd(),
                newCheckCmd(),
-               newBenchmarkCmd(),
+               b.newBenchmarkCmd(),
                newConvertCmd(),
                newNewCmd(),
                newListCmd(),
                newImportCmd(),
                newGenCmd(),
        )
+
+       return b
+}
+
+func (b *commandsBuilder) build() *hugoCmd {
+       h := b.newHugoCmd()
+       addCommands(h.getCommand(), b.commands...)
+       return h
 }
 
 func addCommands(root *cobra.Command, commands ...cmder) {
@@ -56,9 +70,19 @@ type baseCmd struct {
        cmd *cobra.Command
 }
 
+var _ commandsBuilderGetter = (*baseBuilderCmd)(nil)
+
+// Used in tests.
+type commandsBuilderGetter interface {
+       getCmmandsBuilder() *commandsBuilder
+}
 type baseBuilderCmd struct {
-       hugoBuilderCommon
        *baseCmd
+       *commandsBuilder
+}
+
+func (b *baseBuilderCmd) getCmmandsBuilder() *commandsBuilder {
+       return b.commandsBuilder
 }
 
 func (c *baseCmd) getCommand() *cobra.Command {
@@ -69,8 +93,8 @@ func newBaseCmd(cmd *cobra.Command) *baseCmd {
        return &baseCmd{cmd: cmd}
 }
 
-func newBuilderCmd(cmd *cobra.Command) *baseBuilderCmd {
-       bcmd := &baseBuilderCmd{baseCmd: &baseCmd{cmd: cmd}}
+func (b *commandsBuilder) newBuilderCmd(cmd *cobra.Command) *baseBuilderCmd {
+       bcmd := &baseBuilderCmd{commandsBuilder: b, baseCmd: &baseCmd{cmd: cmd}}
        bcmd.hugoBuilderCommon.handleFlags(cmd)
        return bcmd
 }
@@ -86,10 +110,10 @@ type hugoCmd struct {
        c *commandeer
 }
 
-func newHugoCmd() *hugoCmd {
+func (b *commandsBuilder) newHugoCmd() *hugoCmd {
        cc := &hugoCmd{}
 
-       cc.baseBuilderCmd = newBuilderCmd(&cobra.Command{
+       cc.baseBuilderCmd = b.newBuilderCmd(&cobra.Command{
                Use:   "hugo",
                Short: "hugo builds your site",
                Long: `hugo is the main command, used to build your Hugo site.
index ea9d3f74de33b466c49b4a593a3fe20f7c4d4b4e..6fabbbb0bcc8f86d93f9d1e8725690f58084d6dd 100644 (file)
@@ -20,6 +20,8 @@ import (
        "path/filepath"
        "testing"
 
+       "github.com/spf13/cobra"
+
        "github.com/stretchr/testify/require"
 )
 
@@ -41,7 +43,44 @@ func TestExecute(t *testing.T) {
        assert.True(len(result.Sites[0].RegularPages) == 1)
 }
 
-func TestCommands(t *testing.T) {
+func TestCommandsPersistentFlags(t *testing.T) {
+       assert := require.New(t)
+
+       noOpRunE := func(cmd *cobra.Command, args []string) error {
+               return nil
+       }
+
+       tests := []struct {
+               args  []string
+               check func(command []cmder)
+       }{{[]string{"server", "--config=myconfig.toml", "-b=https://example.com/b/", "--source=mysource"}, func(commands []cmder) {
+               for _, command := range commands {
+                       if b, ok := command.(commandsBuilderGetter); ok {
+                               v := b.getCmmandsBuilder().hugoBuilderCommon
+                               assert.Equal("myconfig.toml", v.cfgFile)
+                               assert.Equal("mysource", v.source)
+                               assert.Equal("https://example.com/b/", v.baseURL)
+                       }
+               }
+       }}}
+
+       for _, test := range tests {
+               b := newCommandsBuilder()
+               root := b.addAll().build()
+
+               for _, c := range b.commands {
+                       // We are only intereseted in the flag handling here.
+                       c.getCommand().RunE = noOpRunE
+               }
+               rootCmd := root.getCommand()
+               rootCmd.SetArgs(test.args)
+               assert.NoError(rootCmd.Execute())
+               test.check(b.commands)
+       }
+
+}
+
+func TestCommandsExecute(t *testing.T) {
 
        assert := require.New(t)
 
@@ -90,7 +129,7 @@ func TestCommands(t *testing.T) {
 
        for _, test := range tests {
 
-               hugoCmd := newHugoCompleteCmd().getCommand()
+               hugoCmd := newCommandsBuilder().addAll().build().getCommand()
                test.flags = append(test.flags, "--quiet")
                hugoCmd.SetArgs(append(test.commands, test.flags...))
 
index c1cf9833f1efbe9ad73874c81ec73105825e592a..84e265cf24b185a13d5ad9e4852949f8c8444fc6 100644 (file)
@@ -70,7 +70,7 @@ func (r Response) IsUserError() bool {
 // Execute adds all child commands to the root command HugoCmd and sets flags appropriately.
 // The args are usually filled with os.Args[1:].
 func Execute(args []string) Response {
-       hugoCmd := newHugoCompleteCmd()
+       hugoCmd := newCommandsBuilder().addAll().build()
        cmd := hugoCmd.getCommand()
        cmd.SetArgs(args)
 
index 8db6fa918e342f81a3433b92784bf010537f1900..c05180de984958c9199c81c7d23bc3e003dedf54 100644 (file)
@@ -57,14 +57,14 @@ type serverCmd struct {
        *baseBuilderCmd
 }
 
-func newServerCmd() *serverCmd {
-       return newServerCmdSignaled(nil)
+func (b *commandsBuilder) newServerCmd() *serverCmd {
+       return b.newServerCmdSignaled(nil)
 }
 
-func newServerCmdSignaled(stop <-chan bool) *serverCmd {
+func (b *commandsBuilder) newServerCmdSignaled(stop <-chan bool) *serverCmd {
        cc := &serverCmd{stop: stop}
 
-       cc.baseBuilderCmd = newBuilderCmd(&cobra.Command{
+       cc.baseBuilderCmd = b.newBuilderCmd(&cobra.Command{
                Use:     "server",
                Aliases: []string{"serve"},
                Short:   "A high performance webserver",
@@ -463,7 +463,8 @@ func (sc *serverCmd) fixURL(cfg config.Provider, s string, port int) (string, er
 }
 
 func memStats() error {
-       sc := newServerCmd().getCommand()
+       b := newCommandsBuilder()
+       sc := b.newServerCmd().getCommand()
        memstats := sc.Flags().Lookup("memstats").Value.String()
        if memstats != "" {
                interval, err := time.ParseDuration(sc.Flags().Lookup("meminterval").Value.String())
index 648664697683031cc0122d3fa6d0c35522c9def0..7b9cdcfaab8dac0b2b5e2265818605626b66fa75 100644 (file)
@@ -40,7 +40,8 @@ func TestServer(t *testing.T) {
 
        stop := make(chan bool)
 
-       scmd := newServerCmdSignaled(stop)
+       b := newCommandsBuilder()
+       scmd := b.newServerCmdSignaled(stop)
 
        cmd := scmd.getCommand()
        cmd.SetArgs([]string{"-s=" + dir, fmt.Sprintf("-p=%d", port)})
@@ -90,7 +91,8 @@ func TestFixURL(t *testing.T) {
        }
 
        for i, test := range tests {
-               s := newServerCmd()
+               b := newCommandsBuilder()
+               s := b.newServerCmd()
                v := viper.New()
                baseURL := test.CLIBaseURL
                v.Set("baseURL", test.CfgBaseURL)