Extend template's basic math functions to accept float, uint and string values
authorTatsushi Demachi <tdemachi@gmail.com>
Fri, 19 Sep 2014 16:33:02 +0000 (01:33 +0900)
committerspf13 <steve.francia@gmail.com>
Mon, 22 Sep 2014 13:01:40 +0000 (09:01 -0400)
hugolib/template.go
hugolib/template_test.go

index d805aeef7487c4b5d967fd09bf7809798f4056b7..d2cfccf3d6286e2d9f213fb1a2e23a2b144a28ec 100644 (file)
@@ -233,6 +233,125 @@ func SafeHtml(text string) template.HTML {
        return template.HTML(text)
 }
 
+func doArithmetic(a, b interface{}, op rune) (interface{}, error) {
+       av := reflect.ValueOf(a)
+       bv := reflect.ValueOf(b)
+       var ai, bi int64
+       var af, bf float64
+       var au, bu uint64
+       switch av.Kind() {
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               ai = av.Int()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bi = bv.Int()
+               case reflect.Float32, reflect.Float64:
+                       af = float64(ai) // may overflow
+                       ai = 0
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bu = bv.Uint()
+                       if ai >= 0 {
+                               au = uint64(ai)
+                               ai = 0
+                       } else {
+                               bi = int64(bu) // may overflow
+                               bu = 0
+                       }
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.Float32, reflect.Float64:
+               af = av.Float()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bf = float64(bv.Int()) // may overflow
+               case reflect.Float32, reflect.Float64:
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bf = float64(bv.Uint()) // may overflow
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+               au = av.Uint()
+               switch bv.Kind() {
+               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                       bi = bv.Int()
+                       if bi >= 0 {
+                               bu = uint64(bi)
+                               bi = 0
+                       } else {
+                               ai = int64(au) // may overflow
+                               au = 0
+                       }
+               case reflect.Float32, reflect.Float64:
+                       af = float64(au) // may overflow
+                       au = 0
+                       bf = bv.Float()
+               case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+                       bu = bv.Uint()
+               default:
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       case reflect.String:
+               as := av.String()
+               if bv.Kind() == reflect.String && op == '+' {
+                       bs := bv.String()
+                       return as + bs, nil
+               } else {
+                       return nil, errors.New("Can't apply the operator to the values")
+               }
+       default:
+               return nil, errors.New("Can't apply the operator to the values")
+       }
+
+       switch op {
+       case '+':
+               if ai != 0 || bi != 0 {
+                       return ai + bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af + bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au + bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '-':
+               if ai != 0 || bi != 0 {
+                       return ai - bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af - bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au - bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '*':
+               if ai != 0 || bi != 0 {
+                       return ai * bi, nil
+               } else if af != 0 || bf != 0 {
+                       return af * bf, nil
+               } else if au != 0 || bu != 0 {
+                       return au * bu, nil
+               } else {
+                       return 0, nil
+               }
+       case '/':
+               if bi != 0 {
+                       return ai / bi, nil
+               } else if bf != 0 {
+                       return af / bf, nil
+               } else if bu != 0 {
+                       return au / bu, nil
+               } else {
+                       return nil, errors.New("Can't divide the value by 0")
+               }
+       default:
+               return nil, errors.New("There is no such an operation")
+       }
+}
+
 type Template interface {
        ExecuteTemplate(wr io.Writer, name string, data interface{}) error
        Lookup(name string) *template.Template
@@ -278,11 +397,11 @@ func NewTemplate() Template {
                "first":       First,
                "where":       Where,
                "highlight":   Highlight,
-               "add":         func(a, b int) int { return a + b },
-               "sub":         func(a, b int) int { return a - b },
-               "div":         func(a, b int) int { return a / b },
+               "add":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '+') },
+               "sub":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '-') },
+               "div":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '/') },
                "mod":         func(a, b int) int { return a % b },
-               "mul":         func(a, b int) int { return a * b },
+               "mul":         func(a, b interface{}) (interface{}, error) { return doArithmetic(a, b, '*') },
                "modBool":     func(a, b int) bool { return a%b == 0 },
                "lower":       func(a string) string { return strings.ToLower(a) },
                "upper":       func(a string) string { return strings.ToUpper(a) },
index 9a34e99ded2e01c86228fc16f563da91ce27040b..57c132f967c3939af244e97753c88a41c960956f 100644 (file)
@@ -35,6 +35,94 @@ func TestGt(t *testing.T) {
        }
 }
 
+func TestDoArithmetic(t *testing.T) {
+       for i, this := range []struct {
+               a      interface{}
+               b      interface{}
+               op     rune
+               expect interface{}
+       }{
+               {3, 2, '+', int64(5)},
+               {3, 2, '-', int64(1)},
+               {3, 2, '*', int64(6)},
+               {3, 2, '/', int64(1)},
+               {3.0, 2, '+', float64(5)},
+               {3.0, 2, '-', float64(1)},
+               {3.0, 2, '*', float64(6)},
+               {3.0, 2, '/', float64(1.5)},
+               {3, 2.0, '+', float64(5)},
+               {3, 2.0, '-', float64(1)},
+               {3, 2.0, '*', float64(6)},
+               {3, 2.0, '/', float64(1.5)},
+               {3.0, 2.0, '+', float64(5)},
+               {3.0, 2.0, '-', float64(1)},
+               {3.0, 2.0, '*', float64(6)},
+               {3.0, 2.0, '/', float64(1.5)},
+               {uint(3), uint(2), '+', uint64(5)},
+               {uint(3), uint(2), '-', uint64(1)},
+               {uint(3), uint(2), '*', uint64(6)},
+               {uint(3), uint(2), '/', uint64(1)},
+               {uint(3), 2, '+', uint64(5)},
+               {uint(3), 2, '-', uint64(1)},
+               {uint(3), 2, '*', uint64(6)},
+               {uint(3), 2, '/', uint64(1)},
+               {3, uint(2), '+', uint64(5)},
+               {3, uint(2), '-', uint64(1)},
+               {3, uint(2), '*', uint64(6)},
+               {3, uint(2), '/', uint64(1)},
+               {uint(3), -2, '+', int64(1)},
+               {uint(3), -2, '-', int64(5)},
+               {uint(3), -2, '*', int64(-6)},
+               {uint(3), -2, '/', int64(-1)},
+               {-3, uint(2), '+', int64(-1)},
+               {-3, uint(2), '-', int64(-5)},
+               {-3, uint(2), '*', int64(-6)},
+               {-3, uint(2), '/', int64(-1)},
+               {uint(3), 2.0, '+', float64(5)},
+               {uint(3), 2.0, '-', float64(1)},
+               {uint(3), 2.0, '*', float64(6)},
+               {uint(3), 2.0, '/', float64(1.5)},
+               {3.0, uint(2), '+', float64(5)},
+               {3.0, uint(2), '-', float64(1)},
+               {3.0, uint(2), '*', float64(6)},
+               {3.0, uint(2), '/', float64(1.5)},
+               {0, 0, '+', 0},
+               {0, 0, '-', 0},
+               {0, 0, '*', 0},
+               {"foo", "bar", '+', "foobar"},
+               {3, 0, '/', false},
+               {3.0, 0, '/', false},
+               {3, 0.0, '/', false},
+               {uint(3), uint(0), '/', false},
+               {3, uint(0), '/', false},
+               {-3, uint(0), '/', false},
+               {uint(3), 0, '/', false},
+               {3.0, uint(0), '/', false},
+               {uint(3), 0.0, '/', false},
+               {3, "foo", '+', false},
+               {3.0, "foo", '+', false},
+               {uint(3), "foo", '+', false},
+               {"foo", 3, '+', false},
+               {"foo", "bar", '-', false},
+               {3, 2, '%', false},
+       } {
+               result, err := doArithmetic(this.a, this.b, this.op)
+               if b, ok := this.expect.(bool); ok && !b {
+                       if err == nil {
+                               t.Errorf("[%d] doArithmetic didn't return an expected error")
+                       }
+               } else {
+                       if err != nil {
+                               t.Errorf("[%d] failed: %s", i, err)
+                               continue
+                       }
+                       if !reflect.DeepEqual(result, this.expect) {
+                               t.Errorf("[%d] doArithmetic got %v but expected %v", i, result, this.expect)
+                       }
+               }
+       }
+}
+
 func TestFirst(t *testing.T) {
        for i, this := range []struct {
                count    int