Make where accept slice
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sat, 5 Mar 2016 23:35:35 +0000 (00:35 +0100)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sun, 6 Mar 2016 12:15:07 +0000 (13:15 +0100)
Fixes #1926

tpl/reflect_helpers.go [new file with mode: 0644]
tpl/template_funcs.go
tpl/template_funcs_test.go

diff --git a/tpl/reflect_helpers.go b/tpl/reflect_helpers.go
new file mode 100644 (file)
index 0000000..f2ce722
--- /dev/null
@@ -0,0 +1,70 @@
+// Copyright 2016 The Hugo Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tpl
+
+import (
+       "reflect"
+       "time"
+)
+
+// toInt returns the int value if possible, -1 if not.
+func toInt(v reflect.Value) int64 {
+       switch v.Kind() {
+       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+               return v.Int()
+       case reflect.Interface:
+               return toInt(v.Elem())
+       }
+       return -1
+}
+
+// toString returns the string value if possible, "" if not.
+func toString(v reflect.Value) string {
+       switch v.Kind() {
+       case reflect.String:
+               return v.String()
+       case reflect.Interface:
+               return toString(v.Elem())
+       }
+       return ""
+}
+
+var (
+       zero      reflect.Value
+       errorType = reflect.TypeOf((*error)(nil)).Elem()
+       timeType  = reflect.TypeOf((*time.Time)(nil)).Elem()
+)
+
+func toTimeUnix(v reflect.Value) int64 {
+       if v.Kind() == reflect.Interface {
+               return toTimeUnix(v.Elem())
+       }
+       if v.Type() != timeType {
+               panic("coding error: argument must be time.Time type reflect Value")
+       }
+       return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int()
+}
+
+// indirect is taken from 'text/template/exec.go'
+func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
+       for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
+               if v.IsNil() {
+                       return v, true
+               }
+               if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
+                       break
+               }
+       }
+       return v, false
+}
index 9b2bbec41e07fa0b491855d42e98abae2a629502..4d88336d043e2d70b8ab7df822d2e1f0ed416f73 100644 (file)
@@ -132,7 +132,7 @@ func compareGetFloat(a interface{}, b interface{}) (float64, float64) {
        case reflect.Struct:
                switch av.Type() {
                case timeType:
-                       left = float64(timeUnix(av))
+                       left = float64(toTimeUnix(av))
                }
        }
 
@@ -155,7 +155,7 @@ func compareGetFloat(a interface{}, b interface{}) (float64, float64) {
        case reflect.Struct:
                switch bv.Type() {
                case timeType:
-                       right = float64(timeUnix(bv))
+                       right = float64(toTimeUnix(bv))
                }
        }
 
@@ -393,19 +393,6 @@ func in(l interface{}, v interface{}) bool {
        return false
 }
 
-// indirect is taken from 'text/template/exec.go'
-func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
-       for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
-               if v.IsNil() {
-                       return v, true
-               }
-               if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
-                       break
-               }
-       }
-       return v, false
-}
-
 // first returns the first N items in a rangeable list.
 func first(limit interface{}, seq interface{}) (interface{}, error) {
        if limit == nil || seq == nil {
@@ -539,19 +526,6 @@ func shuffle(seq interface{}) (interface{}, error) {
        return shuffled.Interface(), nil
 }
 
-var (
-       zero      reflect.Value
-       errorType = reflect.TypeOf((*error)(nil)).Elem()
-       timeType  = reflect.TypeOf((*time.Time)(nil)).Elem()
-)
-
-func timeUnix(v reflect.Value) int64 {
-       if v.Type() != timeType {
-               panic("coding error: argument must be time.Time type reflect Value")
-       }
-       return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int()
-}
-
 func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
        if !obj.IsValid() {
                return zero, errors.New("can't evaluate an invalid value")
@@ -662,9 +636,9 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
                case reflect.Struct:
                        switch v.Type() {
                        case timeType:
-                               iv := timeUnix(v)
+                               iv := toTimeUnix(v)
                                ivp = &iv
-                               imv := timeUnix(mv)
+                               imv := toTimeUnix(mv)
                                imvp = &imv
                        }
                }
@@ -672,7 +646,12 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
                if mv.Kind() != reflect.Array && mv.Kind() != reflect.Slice {
                        return false, nil
                }
-               if mv.Type().Elem() != v.Type() {
+
+               if mv.Len() == 0 {
+                       return false, nil
+               }
+
+               if v.Kind() != reflect.Interface && mv.Type().Elem().Kind() != reflect.Interface && mv.Type().Elem() != v.Type() {
                        return false, nil
                }
                switch v.Kind() {
@@ -680,21 +659,26 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
                        iv := v.Int()
                        ivp = &iv
                        for i := 0; i < mv.Len(); i++ {
-                               ima = append(ima, mv.Index(i).Int())
+                               if anInt := toInt(mv.Index(i)); anInt != -1 {
+                                       ima = append(ima, anInt)
+                               }
+
                        }
                case reflect.String:
                        sv := v.String()
                        svp = &sv
                        for i := 0; i < mv.Len(); i++ {
-                               sma = append(sma, mv.Index(i).String())
+                               if aString := toString(mv.Index(i)); aString != "" {
+                                       sma = append(sma, aString)
+                               }
                        }
                case reflect.Struct:
                        switch v.Type() {
                        case timeType:
-                               iv := timeUnix(v)
+                               iv := toTimeUnix(v)
                                ivp = &iv
                                for i := 0; i < mv.Len(); i++ {
-                                       ima = append(ima, timeUnix(mv.Index(i)))
+                                       ima = append(ima, toTimeUnix(mv.Index(i)))
                                }
                        }
                }
index a0f75da65f76048ab63318c6f6410fec19b34a8c..602fbc5dc46e7a4b5da1792927a69fa3a9d3f5e6 100644 (file)
@@ -848,7 +848,7 @@ func TestTimeUnix(t *testing.T) {
        tv := reflect.ValueOf(time.Unix(sec, 0))
        i := 1
 
-       res := timeUnix(tv)
+       res := toTimeUnix(tv)
        if sec != res {
                t.Errorf("[%d] timeUnix got %v but expected %v", i, res, sec)
        }
@@ -861,7 +861,7 @@ func TestTimeUnix(t *testing.T) {
                        }
                }()
                iv := reflect.ValueOf(sec)
-               timeUnix(iv)
+               toTimeUnix(iv)
        }(t)
 }
 
@@ -1036,14 +1036,18 @@ func TestCheckCondition(t *testing.T) {
 }
 
 func TestWhere(t *testing.T) {
-       // TODO(spf): Put these page tests back in
-       //page1 := &Page{contentType: "v", Source: Source{File: *source.NewFile("/x/y/z/source.md")}}
-       //page2 := &Page{contentType: "w", Source: Source{File: *source.NewFile("/y/z/a/source.md")}}
 
        type Mid struct {
                Tst TstX
        }
 
+       d1 := time.Now()
+       d2 := d1.Add(1 * time.Hour)
+       d3 := d2.Add(1 * time.Hour)
+       d4 := d3.Add(1 * time.Hour)
+       d5 := d4.Add(1 * time.Hour)
+       d6 := d5.Add(1 * time.Hour)
+
        for i, this := range []struct {
                sequence interface{}
                key      interface{}
@@ -1204,6 +1208,24 @@ func TestWhere(t *testing.T) {
                                {"a": 3, "b": 4},
                        },
                },
+               {
+                       sequence: []map[string]int{
+                               {"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6},
+                       },
+                       key: "b", op: "in", match: slice(3, 4, 5),
+                       expect: []map[string]int{
+                               {"a": 3, "b": 4},
+                       },
+               },
+               {
+                       sequence: []map[string]time.Time{
+                               {"a": d1, "b": d2}, {"a": d3, "b": d4}, {"a": d5, "b": d6},
+                       },
+                       key: "b", op: "in", match: slice(d3, d4, d5),
+                       expect: []map[string]time.Time{
+                               {"a": d3, "b": d4},
+                       },
+               },
                {
                        sequence: []TstX{
                                {A: "a", B: "b"}, {A: "c", B: "d"}, {A: "e", B: "f"},
@@ -1213,6 +1235,15 @@ func TestWhere(t *testing.T) {
                                {A: "a", B: "b"}, {A: "e", B: "f"},
                        },
                },
+               {
+                       sequence: []TstX{
+                               {A: "a", B: "b"}, {A: "c", B: "d"}, {A: "e", B: "f"},
+                       },
+                       key: "B", op: "not in", match: slice("c", t, "d", "e"),
+                       expect: []TstX{
+                               {A: "a", B: "b"}, {A: "e", B: "f"},
+                       },
+               },
                {
                        sequence: []map[string]int{
                                {"a": 1, "b": 2}, {"a": 3}, {"a": 5, "b": 6},
@@ -1273,11 +1304,10 @@ func TestWhere(t *testing.T) {
                        key: "B", op: "op", match: "f",
                        expect: false,
                },
-               //{[]*Page{page1, page2}, "Type", "v", []*Page{page1}},
-               //{[]*Page{page1, page2}, "Section", "y", []*Page{page2}},
        } {
                var results interface{}
                var err error
+
                if len(this.op) > 0 {
                        results, err = where(this.sequence, this.key, this.op, this.match)
                } else {