Add writable context to Node
authorbep <bjorn.erik.pedersen@gmail.com>
Fri, 26 Dec 2014 20:18:26 +0000 (21:18 +0100)
committerbep <bjorn.erik.pedersen@gmail.com>
Sat, 31 Jan 2015 21:01:30 +0000 (22:01 +0100)
The variable scope in the Go templates makes it hard, if possible at all, to write templates with counter variables or similar state.

This commit fixes that by adding a writable context to Node, backed by a map: Scratch.

This context has three methods, Get, Set and Add. The Add is tailored for counter variables, but can be used for any built-in numeric values or strings.

helpers/general.go
helpers/general_test.go
hugolib/node.go
hugolib/scratch.go [new file with mode: 0644]
hugolib/scratch_test.go [new file with mode: 0644]
tpl/template_test.go

index f2ac253bed476dcedbdea33d3950945e3d4dc541..32666defab0d13462cf5f6c5e80492cb9e457962 100644 (file)
@@ -17,10 +17,12 @@ import (
        "bytes"
        "crypto/md5"
        "encoding/hex"
+       "errors"
        "fmt"
        "io"
        "net"
        "path/filepath"
+       "reflect"
        "strings"
 
        bp "github.com/spf13/hugo/bufferpool"
@@ -118,3 +120,124 @@ func Md5String(f string) string {
        h.Write([]byte(f))
        return hex.EncodeToString(h.Sum([]byte{}))
 }
+
+// DoArithmetic performs arithmetic operations (+,-,*,/) using reflection to
+// determine the type of the two terms.
+func DoArithmetic(a, b interface{}, op rune) (interface{}, error) {
+       av := reflect.ValueOf(a)
+       bv := reflect.ValueOf(b)
+       var ai, bi int64
+       var af, bf float64
+       var au, bu uint64
+       switch av.Kind() {
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               ai = av.Int()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bi = bv.Int()
+               case reflect.Float32, reflect.Float64:
+                       af = float64(ai) // may overflow
+                       ai = 0
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bu = bv.Uint()
+                       if ai >= 0 {
+                               au = uint64(ai)
+                               ai = 0
+                       } else {
+                               bi = int64(bu) // may overflow
+                               bu = 0
+                       }
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.Float32, reflect.Float64:
+               af = av.Float()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bf = float64(bv.Int()) // may overflow
+               case reflect.Float32, reflect.Float64:
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bf = float64(bv.Uint()) // may overflow
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+               au = av.Uint()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bi = bv.Int()
+                       if bi >= 0 {
+                               bu = uint64(bi)
+                               bi = 0
+                       } else {
+                               ai = int64(au) // may overflow
+                               au = 0
+                       }
+               case reflect.Float32, reflect.Float64:
+                       af = float64(au) // may overflow
+                       au = 0
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bu = bv.Uint()
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.String:
+               as := av.String()
+               if bv.Kind() == reflect.String && op == '+' {
+                       bs := bv.String()
+                       return as + bs, nil
+               } else {
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       default:
+               return nil, errors.New("Can't apply the operator to the values")
+       }
+
+       switch op {
+       case '+':
+               if ai != 0 || bi != 0 {
+                       return ai + bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af + bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au + bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '-':
+               if ai != 0 || bi != 0 {
+                       return ai - bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af - bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au - bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '*':
+               if ai != 0 || bi != 0 {
+                       return ai * bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af * bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au * bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '/':
+               if bi != 0 {
+                       return ai / bi, nil
+               } else if bf != 0 {
+                       return af / bf, nil
+               } else if bu != 0 {
+                       return au / bu, nil
+               } else {
+                       return nil, errors.New("Can't divide the value by 0")
+               }
+       default:
+               return nil, errors.New("There is no such an operation")
+       }
+}
index fef073f050c997d85c9d896691d67fa0d928bbf8..527ba6facdf6257f4dad5deba3c3d8648af306fb 100644 (file)
@@ -2,6 +2,7 @@ package helpers
 
 import (
        "github.com/stretchr/testify/assert"
+       "reflect"
        "strings"
        "testing"
 )
@@ -128,3 +129,91 @@ func TestMd5StringEmpty(t *testing.T) {
                Md5String(in)
        }
 }
+
+func TestDoArithmetic(t *testing.T) {
+       for i, this := range []struct {
+               a      interface{}
+               b      interface{}
+               op     rune
+               expect interface{}
+       }{
+               {3, 2, '+', int64(5)},
+               {3, 2, '-', int64(1)},
+               {3, 2, '*', int64(6)},
+               {3, 2, '/', int64(1)},
+               {3.0, 2, '+', float64(5)},
+               {3.0, 2, '-', float64(1)},
+               {3.0, 2, '*', float64(6)},
+               {3.0, 2, '/', float64(1.5)},
+               {3, 2.0, '+', float64(5)},
+               {3, 2.0, '-', float64(1)},
+               {3, 2.0, '*', float64(6)},
+               {3, 2.0, '/', float64(1.5)},
+               {3.0, 2.0, '+', float64(5)},
+               {3.0, 2.0, '-', float64(1)},
+               {3.0, 2.0, '*', float64(6)},
+               {3.0, 2.0, '/', float64(1.5)},
+               {uint(3), uint(2), '+', uint64(5)},
+               {uint(3), uint(2), '-', uint64(1)},
+               {uint(3), uint(2), '*', uint64(6)},
+               {uint(3), uint(2), '/', uint64(1)},
+               {uint(3), 2, '+', uint64(5)},
+               {uint(3), 2, '-', uint64(1)},
+               {uint(3), 2, '*', uint64(6)},
+               {uint(3), 2, '/', uint64(1)},
+               {3, uint(2), '+', uint64(5)},
+               {3, uint(2), '-', uint64(1)},
+               {3, uint(2), '*', uint64(6)},
+               {3, uint(2), '/', uint64(1)},
+               {uint(3), -2, '+', int64(1)},
+               {uint(3), -2, '-', int64(5)},
+               {uint(3), -2, '*', int64(-6)},
+               {uint(3), -2, '/', int64(-1)},
+               {-3, uint(2), '+', int64(-1)},
+               {-3, uint(2), '-', int64(-5)},
+               {-3, uint(2), '*', int64(-6)},
+               {-3, uint(2), '/', int64(-1)},
+               {uint(3), 2.0, '+', float64(5)},
+               {uint(3), 2.0, '-', float64(1)},
+               {uint(3), 2.0, '*', float64(6)},
+               {uint(3), 2.0, '/', float64(1.5)},
+               {3.0, uint(2), '+', float64(5)},
+               {3.0, uint(2), '-', float64(1)},
+               {3.0, uint(2), '*', float64(6)},
+               {3.0, uint(2), '/', float64(1.5)},
+               {0, 0, '+', 0},
+               {0, 0, '-', 0},
+               {0, 0, '*', 0},
+               {"foo", "bar", '+', "foobar"},
+               {3, 0, '/', false},
+               {3.0, 0, '/', false},
+               {3, 0.0, '/', false},
+               {uint(3), uint(0), '/', false},
+               {3, uint(0), '/', false},
+               {-3, uint(0), '/', false},
+               {uint(3), 0, '/', false},
+               {3.0, uint(0), '/', false},
+               {uint(3), 0.0, '/', false},
+               {3, "foo", '+', false},
+               {3.0, "foo", '+', false},
+               {uint(3), "foo", '+', false},
+               {"foo", 3, '+', false},
+               {"foo", "bar", '-', false},
+               {3, 2, '%', false},
+       } {
+               result, err := DoArithmetic(this.a, this.b, this.op)
+               if b, ok := this.expect.(bool); ok && !b {
+                       if err == nil {
+                               t.Errorf("[%d] doArithmetic didn't return an expected error")
+                       }
+               } else {
+                       if err != nil {
+                               t.Errorf("[%d] failed: %s", i, err)
+                               continue
+                       }
+                       if !reflect.DeepEqual(result, this.expect) {
+                               t.Errorf("[%d] doArithmetic got %v but expected %v", i, result, this.expect)
+                       }
+               }
+       }
+}
index 1916e8b03d7da8c0c0078c7a036e71753e4f8d0d..604b5475a62b45b0afcbb00d28b4e4214cba8eee 100644 (file)
@@ -33,6 +33,7 @@ type Node struct {
        UrlPath
        paginator     *pager
        paginatorInit sync.Once
+       scratch       *Scratch
 }
 
 func (n *Node) Now() time.Time {
@@ -124,3 +125,11 @@ type UrlPath struct {
        Slug      string
        Section   string
 }
+
+// Scratch returns the writable context associated with this Node.
+func (n *Node) Scratch() *Scratch {
+       if n.scratch == nil {
+               n.scratch = newScratch()
+       }
+       return n.scratch
+}
diff --git a/hugolib/scratch.go b/hugolib/scratch.go
new file mode 100644 (file)
index 0000000..0f5c4b4
--- /dev/null
@@ -0,0 +1,57 @@
+// Copyright © 2013-14 Steve Francia <spf@spf13.com>.
+//
+// Licensed under the Simple Public License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://opensource.org/licenses/Simple-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hugolib
+
+import (
+       "github.com/spf13/hugo/helpers"
+)
+
+// Scratch is a writable context used for stateful operations in Page/Node rendering.
+type Scratch struct {
+       values map[string]interface{}
+}
+
+// Add will add (using the + operator) the addend to the existing addend (if found).
+// Supports numeric values and strings.
+func (c *Scratch) Add(key string, newAddend interface{}) (string, error) {
+       var newVal interface{}
+       existingAddend, found := c.values[key]
+       if found {
+               var err error
+               newVal, err = helpers.DoArithmetic(existingAddend, newAddend, '+')
+               if err != nil {
+                       return "", err
+               }
+       } else {
+               newVal = newAddend
+       }
+       c.values[key] = newVal
+       return "", nil // have to return something to make it work with the Go templates
+}
+
+// Set stores a value with the given key in the Node context.
+// This value can later be retrieved with Get.
+func (c *Scratch) Set(key string, value interface{}) string {
+       c.values[key] = value
+       return ""
+}
+
+// Get returns a value previously set by Add or Set
+func (c *Scratch) Get(key string) interface{} {
+       return c.values[key]
+}
+
+func newScratch() *Scratch {
+       return &Scratch{values: make(map[string]interface{})}
+}
diff --git a/hugolib/scratch_test.go b/hugolib/scratch_test.go
new file mode 100644 (file)
index 0000000..adff2c8
--- /dev/null
@@ -0,0 +1,49 @@
+package hugolib
+
+import (
+       "github.com/stretchr/testify/assert"
+       "testing"
+)
+
+func TestScratchAdd(t *testing.T) {
+       scratch := newScratch()
+       scratch.Add("int1", 10)
+       scratch.Add("int1", 20)
+       scratch.Add("int2", 20)
+
+       assert.Equal(t, 30, scratch.Get("int1"))
+       assert.Equal(t, 20, scratch.Get("int2"))
+
+       scratch.Add("float1", float64(10.5))
+       scratch.Add("float1", float64(20.1))
+
+       assert.Equal(t, float64(30.6), scratch.Get("float1"))
+
+       scratch.Add("string1", "Hello ")
+       scratch.Add("string1", "big ")
+       scratch.Add("string1", "World!")
+
+       assert.Equal(t, "Hello big World!", scratch.Get("string1"))
+
+       scratch.Add("scratch", scratch)
+       _, err := scratch.Add("scratch", scratch)
+
+       if err == nil {
+               t.Errorf("Expected error from invalid arithmetic")
+       }
+
+}
+
+func TestScratchSet(t *testing.T) {
+       scratch := newScratch()
+       scratch.Set("key", "val")
+       assert.Equal(t, "val", scratch.Get("key"))
+}
+
+func TestScratchGet(t *testing.T) {
+       scratch := newScratch()
+       nothing := scratch.Get("nothing")
+       if nothing != nil {
+               t.Errorf("Should not return anything, but got %v", nothing)
+       }
+}
index 4477d0d26a29507c0a8536eb7eb3c2939bf29035..6fe1c93288a9b7aaca2b58a2ee83befb8690a8ac 100644 (file)
@@ -101,94 +101,6 @@ func doTestCompare(t *testing.T, tp tstCompareType, funcUnderTest func(a, b inte
        }
 }
 
-func TestDoArithmetic(t *testing.T) {
-       for i, this := range []struct {
-               a      interface{}
-               b      interface{}
-               op     rune
-               expect interface{}
-       }{
-               {3, 2, '+', int64(5)},
-               {3, 2, '-', int64(1)},
-               {3, 2, '*', int64(6)},
-               {3, 2, '/', int64(1)},
-               {3.0, 2, '+', float64(5)},
-               {3.0, 2, '-', float64(1)},
-               {3.0, 2, '*', float64(6)},
-               {3.0, 2, '/', float64(1.5)},
-               {3, 2.0, '+', float64(5)},
-               {3, 2.0, '-', float64(1)},
-               {3, 2.0, '*', float64(6)},
-               {3, 2.0, '/', float64(1.5)},
-               {3.0, 2.0, '+', float64(5)},
-               {3.0, 2.0, '-', float64(1)},
-               {3.0, 2.0, '*', float64(6)},
-               {3.0, 2.0, '/', float64(1.5)},
-               {uint(3), uint(2), '+', uint64(5)},
-               {uint(3), uint(2), '-', uint64(1)},
-               {uint(3), uint(2), '*', uint64(6)},
-               {uint(3), uint(2), '/', uint64(1)},
-               {uint(3), 2, '+', uint64(5)},
-               {uint(3), 2, '-', uint64(1)},
-               {uint(3), 2, '*', uint64(6)},
-               {uint(3), 2, '/', uint64(1)},
-               {3, uint(2), '+', uint64(5)},
-               {3, uint(2), '-', uint64(1)},
-               {3, uint(2), '*', uint64(6)},
-               {3, uint(2), '/', uint64(1)},
-               {uint(3), -2, '+', int64(1)},
-               {uint(3), -2, '-', int64(5)},
-               {uint(3), -2, '*', int64(-6)},
-               {uint(3), -2, '/', int64(-1)},
-               {-3, uint(2), '+', int64(-1)},
-               {-3, uint(2), '-', int64(-5)},
-               {-3, uint(2), '*', int64(-6)},
-               {-3, uint(2), '/', int64(-1)},
-               {uint(3), 2.0, '+', float64(5)},
-               {uint(3), 2.0, '-', float64(1)},
-               {uint(3), 2.0, '*', float64(6)},
-               {uint(3), 2.0, '/', float64(1.5)},
-               {3.0, uint(2), '+', float64(5)},
-               {3.0, uint(2), '-', float64(1)},
-               {3.0, uint(2), '*', float64(6)},
-               {3.0, uint(2), '/', float64(1.5)},
-               {0, 0, '+', 0},
-               {0, 0, '-', 0},
-               {0, 0, '*', 0},
-               {"foo", "bar", '+', "foobar"},
-               {3, 0, '/', false},
-               {3.0, 0, '/', false},
-               {3, 0.0, '/', false},
-               {uint(3), uint(0), '/', false},
-               {3, uint(0), '/', false},
-               {-3, uint(0), '/', false},
-               {uint(3), 0, '/', false},
-               {3.0, uint(0), '/', false},
-               {uint(3), 0.0, '/', false},
-               {3, "foo", '+', false},
-               {3.0, "foo", '+', false},
-               {uint(3), "foo", '+', false},
-               {"foo", 3, '+', false},
-               {"foo", "bar", '-', false},
-               {3, 2, '%', false},
-       } {
-               result, err := doArithmetic(this.a, this.b, this.op)
-               if b, ok := this.expect.(bool); ok && !b {
-                       if err == nil {
-                               t.Errorf("[%d] doArithmetic didn't return an expected error", i)
-                       }
-               } else {
-                       if err != nil {
-                               t.Errorf("[%d] failed: %s", i, err)
-                               continue
-                       }
-                       if !reflect.DeepEqual(result, this.expect) {
-                               t.Errorf("[%d] doArithmetic got %v but expected %v", i, result, this.expect)
-                       }
-               }
-       }
-}
-
 func TestMod(t *testing.T) {
        for i, this := range []struct {
                a      interface{}