tpl: Allow the partial template func to return any type
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Tue, 2 Apr 2019 08:30:24 +0000 (10:30 +0200)
committerGitHub <noreply@github.com>
Tue, 2 Apr 2019 08:30:24 +0000 (10:30 +0200)
This commit adds support for return values in partials.

This means that you can now do this and similar:

    {{ $v := add . 42 }}
    {{ return $v }}

Partials without a `return` statement will be rendered as before.

This works for both `partial` and `partialCached`.

Fixes #5783

compare/compare.go
hugolib/template_test.go
metrics/metrics.go
tpl/partials/init.go
tpl/partials/partials.go
tpl/template_info.go
tpl/tplimpl/ace.go
tpl/tplimpl/shortcodes.go
tpl/tplimpl/template.go
tpl/tplimpl/template_ast_transformers.go
tpl/tplimpl/template_ast_transformers_test.go

index 19a5deaa2f150b78ce62841b11038c9ff683b361..18c0de777ee8f1a074d10c2866e33e1ab06b2d3a 100644 (file)
@@ -20,6 +20,12 @@ type Eqer interface {
        Eq(other interface{}) bool
 }
 
+// ProbablyEq is an equal check that may return false positives, but never
+// a false negative.
+type ProbablyEqer interface {
+       ProbablyEq(other interface{}) bool
+}
+
 // Comparer can be used to compare two values.
 // This will be used when using the le, ge etc. operators in the templates.
 // Compare returns -1 if the given version is less than, 0 if equal and 1 if greater than
index 56f5dd5ba0d438cf01c0e35f9366338e36fc845c..3ec81323b4bb4d9c73fc892c827994e0ccb0d3f6 100644 (file)
@@ -264,3 +264,44 @@ Hugo: {{ hugo.Generator }}
        )
 
 }
+
+func TestPartialWithReturn(t *testing.T) {
+
+       b := newTestSitesBuilder(t).WithSimpleConfigFile()
+
+       b.WithTemplatesAdded(
+               "index.html", `
+Test Partials With Return Values:
+
+add42: 50: {{ partial "add42.tpl" 8 }}
+dollarContext: 60: {{ partial "dollarContext.tpl" 18 }}
+adder: 70: {{ partial "dict.tpl" (dict "adder" 28) }}
+complex: 80: {{ partial "complex.tpl" 38 }}
+`,
+               "partials/add42.tpl", `
+               {{ $v := add . 42 }}
+               {{ return $v }}
+               `,
+               "partials/dollarContext.tpl", `
+{{ $v := add $ 42 }}
+{{ return $v }}
+`,
+               "partials/dict.tpl", `
+{{ $v := add $.adder 42 }}
+{{ return $v }}
+`,
+               "partials/complex.tpl", `
+{{ return add . 42 }}
+`,
+       )
+
+       b.CreateSites().Build(BuildCfg{})
+
+       b.AssertFileContent("public/index.html",
+               "add42: 50: 50",
+               "dollarContext: 60: 60",
+               "adder: 70: 70",
+               "complex: 80: 80",
+       )
+
+}
index c83610a929f074147ce251d9b847f7e245e4ec61..e67b16bdae18bfed632ae7e6e28d37874cec5d20 100644 (file)
@@ -23,6 +23,12 @@ import (
        "strings"
        "sync"
        "time"
+
+       "github.com/gohugoio/hugo/compare"
+
+       "github.com/gohugoio/hugo/common/hreflect"
+
+       "github.com/spf13/cast"
 )
 
 // The Provider interface defines an interface for measuring metrics.
@@ -35,20 +41,20 @@ type Provider interface {
        WriteMetrics(w io.Writer)
 
        // TrackValue tracks the value for diff calculations etc.
-       TrackValue(key, value string)
+       TrackValue(key string, value interface{})
 
        // Reset clears the metric store.
        Reset()
 }
 
 type diff struct {
-       baseline string
+       baseline interface{}
        count    int
        simSum   int
 }
 
-func (d *diff) add(v string) *diff {
-       if d.baseline == "" {
+func (d *diff) add(v interface{}) *diff {
+       if !hreflect.IsTruthful(v) {
                d.baseline = v
                d.count = 1
                d.simSum = 100 // If we get only one it is very cache friendly.
@@ -90,7 +96,7 @@ func (s *Store) Reset() {
 }
 
 // TrackValue tracks the value for diff calculations etc.
-func (s *Store) TrackValue(key, value string) {
+func (s *Store) TrackValue(key string, value interface{}) {
        if !s.calculateHints {
                return
        }
@@ -191,13 +197,43 @@ func (b bySum) Less(i, j int) bool { return b[i].sum > b[j].sum }
 
 // howSimilar is a naive diff implementation that returns
 // a number between 0-100 indicating how similar a and b are.
-// 100 is when all words in a also exists in b.
-func howSimilar(a, b string) int {
-
+func howSimilar(a, b interface{}) int {
        if a == b {
                return 100
        }
 
+       as, err1 := cast.ToStringE(a)
+       bs, err2 := cast.ToStringE(b)
+
+       if err1 == nil && err2 == nil {
+               return howSimilarStrings(as, bs)
+       }
+
+       if err1 != err2 {
+               return 0
+       }
+
+       e1, ok1 := a.(compare.Eqer)
+       e2, ok2 := b.(compare.Eqer)
+       if ok1 && ok2 && e1.Eq(e2) {
+               return 100
+       }
+
+       // TODO(bep) implement ProbablyEq for Pages etc.
+       pe1, pok1 := a.(compare.ProbablyEqer)
+       pe2, pok2 := b.(compare.ProbablyEqer)
+       if pok1 && pok2 && pe1.ProbablyEq(pe2) {
+               return 90
+       }
+
+       return 0
+}
+
+// howSimilar is a naive diff implementation that returns
+// a number between 0-100 indicating how similar a and b are.
+// 100 is when all words in a also exists in b.
+func howSimilarStrings(a, b string) int {
+
        // Give some weight to the word positions.
        const partitionSize = 4
 
index b68256a9a5f791637e3e8e11fb0201e720b5674d..c2135bca5b9e9628a4a501dc72e87b1f5b5a77e3 100644 (file)
@@ -36,6 +36,13 @@ func init() {
                        },
                )
 
+               // TODO(bep) we need the return to be a valid identifier, but
+               // should consider another way of adding it.
+               ns.AddMethodMapping(func() string { return "" },
+                       []string{"return"},
+                       [][2]string{},
+               )
+
                ns.AddMethodMapping(ctx.IncludeCached,
                        []string{"partialCached"},
                        [][2]string{},
index 1e8a84954bc7d81fb028f7eee584ba3f9832f9df..2599a5d0133c6af2bcf860c6c1333e4e08234c3d 100644 (file)
@@ -18,10 +18,14 @@ package partials
 import (
        "fmt"
        "html/template"
+       "io"
+       "io/ioutil"
        "strings"
        "sync"
        texttemplate "text/template"
 
+       "github.com/gohugoio/hugo/tpl"
+
        bp "github.com/gohugoio/hugo/bufferpool"
        "github.com/gohugoio/hugo/deps"
 )
@@ -62,8 +66,22 @@ type Namespace struct {
        cachedPartials *partialCache
 }
 
-// Include executes the named partial and returns either a string,
-// when the partial is a text/template, or template.HTML when html/template.
+// contextWrapper makes room for a return value in a partial invocation.
+type contextWrapper struct {
+       Arg    interface{}
+       Result interface{}
+}
+
+// Set sets the return value and returns an empty string.
+func (c *contextWrapper) Set(in interface{}) string {
+       c.Result = in
+       return ""
+}
+
+// Include executes the named partial.
+// If the partial contains a return statement, that value will be returned.
+// Else, the rendered output will be returned:
+// A string if the partial is a text/template, or template.HTML when html/template.
 func (ns *Namespace) Include(name string, contextList ...interface{}) (interface{}, error) {
        if strings.HasPrefix(name, "partials/") {
                name = name[8:]
@@ -83,31 +101,54 @@ func (ns *Namespace) Include(name string, contextList ...interface{}) (interface
                // For legacy reasons.
                templ, found = ns.deps.Tmpl.Lookup(n + ".html")
        }
-       if found {
+
+       if !found {
+               return "", fmt.Errorf("partial %q not found", name)
+       }
+
+       var info tpl.Info
+       if ip, ok := templ.(tpl.TemplateInfoProvider); ok {
+               info = ip.TemplateInfo()
+       }
+
+       var w io.Writer
+
+       if info.HasReturn {
+               // Wrap the context sent to the template to capture the return value.
+               // Note that the template is rewritten to make sure that the dot (".")
+               // and the $ variable points to Arg.
+               context = &contextWrapper{
+                       Arg: context,
+               }
+
+               // We don't care about any template output.
+               w = ioutil.Discard
+       } else {
                b := bp.GetBuffer()
                defer bp.PutBuffer(b)
+               w = b
+       }
 
-               if err := templ.Execute(b, context); err != nil {
-                       return "", err
-               }
+       if err := templ.Execute(w, context); err != nil {
+               return "", err
+       }
 
-               if _, ok := templ.(*texttemplate.Template); ok {
-                       s := b.String()
-                       if ns.deps.Metrics != nil {
-                               ns.deps.Metrics.TrackValue(n, s)
-                       }
-                       return s, nil
-               }
+       var result interface{}
 
-               s := b.String()
-               if ns.deps.Metrics != nil {
-                       ns.deps.Metrics.TrackValue(n, s)
-               }
-               return template.HTML(s), nil
+       if ctx, ok := context.(*contextWrapper); ok {
+               result = ctx.Result
+       } else if _, ok := templ.(*texttemplate.Template); ok {
+               result = w.(fmt.Stringer).String()
+       } else {
+               result = template.HTML(w.(fmt.Stringer).String())
+       }
 
+       if ns.deps.Metrics != nil {
+               ns.deps.Metrics.TrackValue(n, result)
        }
 
-       return "", fmt.Errorf("partial %q not found", name)
+       return result, nil
+
 }
 
 // IncludeCached executes and caches partial templates.  An optional variant
index 8568f46f0ffa6f8de691353c872522bebaedb264..be056695895791c565f9a8b8ce55760e006e11f1 100644 (file)
@@ -22,10 +22,17 @@ type Info struct {
        // Set for shortcode templates with any {{ .Inner }}
        IsInner bool
 
+       // Set for partials with a return statement.
+       HasReturn bool
+
        // Config extracted from template.
        Config Config
 }
 
+func (info Info) IsZero() bool {
+       return info.Config.Version == 0
+}
+
 type Config struct {
        Version int
 }
index 7a1f849f40ab3d1beaffb8604e329b7741c1372b..6fedcb583e01c991b67da6fc27cba0094058cd62 100644 (file)
@@ -51,15 +51,17 @@ func (t *templateHandler) addAceTemplate(name, basePath, innerPath string, baseC
                return err
        }
 
-       isShort := isShortcode(name)
+       typ := resolveTemplateType(name)
 
-       info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+       info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
        if err != nil {
                return err
        }
 
-       if isShort {
+       if typ == templateShortcode {
                t.addShortcodeVariant(name, info, templ)
+       } else {
+               t.templateInfo[name] = info
        }
 
        return nil
index 8577fbeedc938cab39314475653c7a480abbd374..40fdeea5d10bef043a4709fcd990c8154706e03e 100644 (file)
@@ -139,6 +139,18 @@ func templateNameAndVariants(name string) (string, []string) {
        return name, variants
 }
 
+func resolveTemplateType(name string) templateType {
+       if isShortcode(name) {
+               return templateShortcode
+       }
+
+       if strings.Contains(name, "partials/") {
+               return templatePartial
+       }
+
+       return templateUndefined
+}
+
 func isShortcode(name string) bool {
        return strings.Contains(name, "shortcodes/")
 }
index d6deba2dfa20be133e5bc23e839ea9991611c44e..49b9e1c349ab3b2cdd635199dd7e076b72caafdd 100644 (file)
@@ -90,6 +90,11 @@ type templateHandler struct {
        // (language, output format etc.) of that shortcode.
        shortcodes map[string]*shortcodeTemplates
 
+       // templateInfo maps template name to some additional information about that template.
+       // Note that for shortcodes that same information is embedded in the
+       // shortcodeTemplates type.
+       templateInfo map[string]tpl.Info
+
        // text holds all the pure text templates.
        text *textTemplates
        html *htmlTemplates
@@ -172,18 +177,30 @@ func (t *templateHandler) Lookup(name string) (tpl.Template, bool) {
                // The templates are stored without the prefix identificator.
                name = strings.TrimPrefix(name, textTmplNamePrefix)
 
-               return t.text.Lookup(name)
+               return t.applyTemplateInfo(t.text.Lookup(name))
        }
 
        // Look in both
        if te, found := t.html.Lookup(name); found {
-               return te, true
+               return t.applyTemplateInfo(te, true)
        }
 
-       return t.text.Lookup(name)
+       return t.applyTemplateInfo(t.text.Lookup(name))
 
 }
 
+func (t *templateHandler) applyTemplateInfo(templ tpl.Template, found bool) (tpl.Template, bool) {
+       if adapter, ok := templ.(*tpl.TemplateAdapter); ok {
+               if adapter.Info.IsZero() {
+                       if info, found := t.templateInfo[templ.Name()]; found {
+                               adapter.Info = info
+                       }
+               }
+       }
+
+       return templ, found
+}
+
 // This currently only applies to shortcodes and what we get here is the
 // shortcode name.
 func (t *templateHandler) LookupVariant(name string, variants tpl.TemplateVariants) (tpl.Template, bool, bool) {
@@ -243,12 +260,13 @@ func (t *templateHandler) setFuncMapInTemplate(in interface{}, funcs map[string]
 
 func (t *templateHandler) clone(d *deps.Deps) *templateHandler {
        c := &templateHandler{
-               Deps:       d,
-               layoutsFs:  d.BaseFs.Layouts.Fs,
-               shortcodes: make(map[string]*shortcodeTemplates),
-               html:       &htmlTemplates{t: template.Must(t.html.t.Clone()), overlays: make(map[string]*template.Template), templatesCommon: t.html.templatesCommon},
-               text:       &textTemplates{textTemplate: &textTemplate{t: texttemplate.Must(t.text.t.Clone())}, overlays: make(map[string]*texttemplate.Template), templatesCommon: t.text.templatesCommon},
-               errors:     make([]*templateErr, 0),
+               Deps:         d,
+               layoutsFs:    d.BaseFs.Layouts.Fs,
+               shortcodes:   make(map[string]*shortcodeTemplates),
+               templateInfo: t.templateInfo,
+               html:         &htmlTemplates{t: template.Must(t.html.t.Clone()), overlays: make(map[string]*template.Template), templatesCommon: t.html.templatesCommon},
+               text:         &textTemplates{textTemplate: &textTemplate{t: texttemplate.Must(t.text.t.Clone())}, overlays: make(map[string]*texttemplate.Template), templatesCommon: t.text.templatesCommon},
+               errors:       make([]*templateErr, 0),
        }
 
        for k, v := range t.shortcodes {
@@ -306,12 +324,13 @@ func newTemplateAdapter(deps *deps.Deps) *templateHandler {
                templatesCommon: common,
        }
        h := &templateHandler{
-               Deps:       deps,
-               layoutsFs:  deps.BaseFs.Layouts.Fs,
-               shortcodes: make(map[string]*shortcodeTemplates),
-               html:       htmlT,
-               text:       textT,
-               errors:     make([]*templateErr, 0),
+               Deps:         deps,
+               layoutsFs:    deps.BaseFs.Layouts.Fs,
+               shortcodes:   make(map[string]*shortcodeTemplates),
+               templateInfo: make(map[string]tpl.Info),
+               html:         htmlT,
+               text:         textT,
+               errors:       make([]*templateErr, 0),
        }
 
        common.handler = h
@@ -463,15 +482,17 @@ func (t *htmlTemplates) addTemplateIn(tt *template.Template, name, tpl string) e
                return err
        }
 
-       isShort := isShortcode(name)
+       typ := resolveTemplateType(name)
 
-       info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+       info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
        if err != nil {
                return err
        }
 
-       if isShort {
+       if typ == templateShortcode {
                t.handler.addShortcodeVariant(name, info, templ)
+       } else {
+               t.handler.templateInfo[name] = info
        }
 
        return nil
@@ -511,7 +532,7 @@ func (t *textTemplate) parseIn(tt *texttemplate.Template, name, tpl string) (*te
                return nil, err
        }
 
-       if _, err := applyTemplateTransformersToTextTemplate(false, templ); err != nil {
+       if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, templ); err != nil {
                return nil, err
        }
        return templ, nil
@@ -524,15 +545,17 @@ func (t *textTemplates) addTemplateIn(tt *texttemplate.Template, name, tpl strin
                return err
        }
 
-       isShort := isShortcode(name)
+       typ := resolveTemplateType(name)
 
-       info, err := applyTemplateTransformersToTextTemplate(isShort, templ)
+       info, err := applyTemplateTransformersToTextTemplate(typ, templ)
        if err != nil {
                return err
        }
 
-       if isShort {
+       if typ == templateShortcode {
                t.handler.addShortcodeVariant(name, info, templ)
+       } else {
+               t.handler.templateInfo[name] = info
        }
 
        return nil
@@ -737,7 +760,7 @@ func (t *htmlTemplates) handleMaster(name, overlayFilename, masterFilename strin
        // * https://github.com/golang/go/issues/16101
        // * https://github.com/gohugoio/hugo/issues/2549
        overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
-       if _, err := applyTemplateTransformersToHMLTTemplate(false, overlayTpl); err != nil {
+       if _, err := applyTemplateTransformersToHMLTTemplate(templateUndefined, overlayTpl); err != nil {
                return err
        }
 
@@ -777,7 +800,7 @@ func (t *textTemplates) handleMaster(name, overlayFilename, masterFilename strin
        }
 
        overlayTpl = overlayTpl.Lookup(overlayTpl.Name())
-       if _, err := applyTemplateTransformersToTextTemplate(false, overlayTpl); err != nil {
+       if _, err := applyTemplateTransformersToTextTemplate(templateUndefined, overlayTpl); err != nil {
                return err
        }
        t.overlays[name] = overlayTpl
@@ -847,15 +870,17 @@ func (t *templateHandler) addTemplateFile(name, baseTemplatePath, path string) e
                        return err
                }
 
-               isShort := isShortcode(name)
+               typ := resolveTemplateType(name)
 
-               info, err := applyTemplateTransformersToHMLTTemplate(isShort, templ)
+               info, err := applyTemplateTransformersToHMLTTemplate(typ, templ)
                if err != nil {
                        return err
                }
 
-               if isShort {
+               if typ == templateShortcode {
                        t.addShortcodeVariant(templateName, info, templ)
+               } else {
+                       t.templateInfo[name] = info
                }
 
                return nil
index 28898c55baa9741ee46cc96d1c4b5005a7116df0..57fafcd88f65b06870382ed2a748419a0581d92e 100644 (file)
@@ -39,6 +39,14 @@ var reservedContainers = map[string]bool{
        "Data": true,
 }
 
+type templateType int
+
+const (
+       templateUndefined templateType = iota
+       templateShortcode
+       templatePartial
+)
+
 type templateContext struct {
        decl     decl
        visited  map[string]bool
@@ -47,14 +55,16 @@ type templateContext struct {
        // The last error encountered.
        err error
 
-       // Only needed for shortcodes
-       isShortcode bool
+       typ templateType
 
        // Set when we're done checking for config header.
        configChecked bool
 
        // Contains some info about the template
        tpl.Info
+
+       // Store away the return node in partials.
+       returnNode *parse.CommandNode
 }
 
 func (c templateContext) getIfNotVisited(name string) *parse.Tree {
@@ -84,12 +94,12 @@ func createParseTreeLookup(templ *template.Template) func(nn string) *parse.Tree
        }
 }
 
-func applyTemplateTransformersToHMLTTemplate(isShortcode bool, templ *template.Template) (tpl.Info, error) {
-       return applyTemplateTransformers(isShortcode, templ.Tree, createParseTreeLookup(templ))
+func applyTemplateTransformersToHMLTTemplate(typ templateType, templ *template.Template) (tpl.Info, error) {
+       return applyTemplateTransformers(typ, templ.Tree, createParseTreeLookup(templ))
 }
 
-func applyTemplateTransformersToTextTemplate(isShortcode bool, templ *texttemplate.Template) (tpl.Info, error) {
-       return applyTemplateTransformers(isShortcode, templ.Tree,
+func applyTemplateTransformersToTextTemplate(typ templateType, templ *texttemplate.Template) (tpl.Info, error) {
+       return applyTemplateTransformers(typ, templ.Tree,
                func(nn string) *parse.Tree {
                        tt := templ.Lookup(nn)
                        if tt != nil {
@@ -99,19 +109,54 @@ func applyTemplateTransformersToTextTemplate(isShortcode bool, templ *texttempla
                })
 }
 
-func applyTemplateTransformers(isShortcode bool, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
+func applyTemplateTransformers(typ templateType, templ *parse.Tree, lookupFn func(name string) *parse.Tree) (tpl.Info, error) {
        if templ == nil {
                return tpl.Info{}, errors.New("expected template, but none provided")
        }
 
        c := newTemplateContext(lookupFn)
-       c.isShortcode = isShortcode
+       c.typ = typ
+
+       _, err := c.applyTransformations(templ.Root)
 
-       err := c.applyTransformations(templ.Root)
+       if err == nil && c.returnNode != nil {
+               // This is a partial with a return statement.
+               c.Info.HasReturn = true
+               templ.Root = c.wrapInPartialReturnWrapper(templ.Root)
+       }
 
        return c.Info, err
 }
 
+const (
+       partialReturnWrapperTempl = `{{ $_hugo_dot := $ }}{{ $ := .Arg }}{{ with .Arg }}{{ $_hugo_dot.Set ("PLACEHOLDER") }}{{ end }}`
+)
+
+var partialReturnWrapper *parse.ListNode
+
+func init() {
+       templ, err := texttemplate.New("").Parse(partialReturnWrapperTempl)
+       if err != nil {
+               panic(err)
+       }
+       partialReturnWrapper = templ.Tree.Root
+}
+
+func (c *templateContext) wrapInPartialReturnWrapper(n *parse.ListNode) *parse.ListNode {
+       wrapper := partialReturnWrapper.CopyList()
+       withNode := wrapper.Nodes[2].(*parse.WithNode)
+       retn := withNode.List.Nodes[0]
+       setCmd := retn.(*parse.ActionNode).Pipe.Cmds[0]
+       setPipe := setCmd.Args[1].(*parse.PipeNode)
+       // Replace PLACEHOLDER with the real return value.
+       // Note that this is a PipeNode, so it will be wrapped in parens.
+       setPipe.Cmds = []*parse.CommandNode{c.returnNode}
+       withNode.List.Nodes = append(n.Nodes, retn)
+
+       return wrapper
+
+}
+
 // The truth logic in Go's template package is broken for certain values
 // for the if and with keywords. This works around that problem by wrapping
 // the node passed to if/with in a getif conditional.
@@ -141,7 +186,7 @@ func (c *templateContext) wrapWithGetIf(p *parse.PipeNode) {
 // 1) Make all .Params.CamelCase and similar into lowercase.
 // 2) Wraps every with and if pipe in getif
 // 3) Collects some information about the template content.
-func (c *templateContext) applyTransformations(n parse.Node) error {
+func (c *templateContext) applyTransformations(n parse.Node) (bool, error) {
        switch x := n.(type) {
        case *parse.ListNode:
                if x != nil {
@@ -169,12 +214,16 @@ func (c *templateContext) applyTransformations(n parse.Node) error {
                        c.decl[x.Decl[0].Ident[0]] = x.Cmds[0].String()
                }
 
-               for _, cmd := range x.Cmds {
-                       c.applyTransformations(cmd)
+               for i, cmd := range x.Cmds {
+                       keep, _ := c.applyTransformations(cmd)
+                       if !keep {
+                               x.Cmds = append(x.Cmds[:i], x.Cmds[i+1:]...)
+                       }
                }
 
        case *parse.CommandNode:
                c.collectInner(x)
+               keep := c.collectReturnNode(x)
 
                for _, elem := range x.Args {
                        switch an := elem.(type) {
@@ -191,9 +240,10 @@ func (c *templateContext) applyTransformations(n parse.Node) error {
                                }
                        }
                }
+               return keep, c.err
        }
 
-       return c.err
+       return true, c.err
 }
 
 func (c *templateContext) applyTransformationsToNodes(nodes ...parse.Node) {
@@ -229,7 +279,7 @@ func (c *templateContext) hasIdent(idents []string, ident string) bool {
 // on the form:
 //    {{ $_hugo_config:= `{ "version": 1 }` }}
 func (c *templateContext) collectConfig(n *parse.PipeNode) {
-       if !c.isShortcode {
+       if c.typ != templateShortcode {
                return
        }
        if c.configChecked {
@@ -271,7 +321,7 @@ func (c *templateContext) collectConfig(n *parse.PipeNode) {
 // collectInner determines if the given CommandNode represents a
 // shortcode call to its .Inner.
 func (c *templateContext) collectInner(n *parse.CommandNode) {
-       if !c.isShortcode {
+       if c.typ != templateShortcode {
                return
        }
        if c.Info.IsInner || len(n.Args) == 0 {
@@ -295,6 +345,28 @@ func (c *templateContext) collectInner(n *parse.CommandNode) {
 
 }
 
+func (c *templateContext) collectReturnNode(n *parse.CommandNode) bool {
+       if c.typ != templatePartial || c.returnNode != nil {
+               return true
+       }
+
+       if len(n.Args) < 2 {
+               return true
+       }
+
+       ident, ok := n.Args[0].(*parse.IdentifierNode)
+       if !ok || ident.Ident != "return" {
+               return true
+       }
+
+       c.returnNode = n
+       // Remove the "return" identifiers
+       c.returnNode.Args = c.returnNode.Args[1:]
+
+       return false
+
+}
+
 // indexOfReplacementStart will return the index of where to start doing replacement,
 // -1 if none needed.
 func (d decl) indexOfReplacementStart(idents []string) int {
index 8d8b423683311e13d483d2e32a3fbc1005939efd..9ed29d27f4b2889766652766b286ff63ccc78f63 100644 (file)
@@ -180,7 +180,7 @@ PARAMS SITE GLOBAL3: {{ $site.Params.LOWER }}
 func TestParamsKeysToLower(t *testing.T) {
        t.Parallel()
 
-       _, err := applyTemplateTransformers(false, nil, nil)
+       _, err := applyTemplateTransformers(templateUndefined, nil, nil)
        require.Error(t, err)
 
        templ, err := template.New("foo").Funcs(testFuncs).Parse(paramsTempl)
@@ -484,7 +484,7 @@ func TestCollectInfo(t *testing.T) {
                        require.NoError(t, err)
 
                        c := newTemplateContext(createParseTreeLookup(templ))
-                       c.isShortcode = true
+                       c.typ = templateShortcode
                        c.applyTransformations(templ.Tree.Root)
 
                        assert.Equal(test.expected, c.Info)
@@ -492,3 +492,46 @@ func TestCollectInfo(t *testing.T) {
        }
 
 }
+
+func TestPartialReturn(t *testing.T) {
+
+       tests := []struct {
+               name      string
+               tplString string
+               expected  bool
+       }{
+               {"Basic", `
+{{ $a := "Hugo Rocks!" }}
+{{ return $a }}
+`, true},
+               {"Expression", `
+{{ return add 32 }}
+`, true},
+       }
+
+       echo := func(in interface{}) interface{} {
+               return in
+       }
+
+       funcs := template.FuncMap{
+               "return": echo,
+               "add":    echo,
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       assert := require.New(t)
+
+                       templ, err := template.New("foo").Funcs(funcs).Parse(test.tplString)
+                       require.NoError(t, err)
+
+                       _, err = applyTemplateTransformers(templatePartial, templ.Tree, createParseTreeLookup(templ))
+
+                       // Just check that it doesn't fail in this test. We have functional tests
+                       // in hugoblib.
+                       assert.NoError(err)
+
+               })
+       }
+
+}