Add 'where' template function
authorTatsushi Demachi <tdemachi@gmail.com>
Sat, 16 Aug 2014 04:12:34 +0000 (13:12 +0900)
committerspf13 <steve.francia@gmail.com>
Mon, 18 Aug 2014 15:31:17 +0000 (11:31 -0400)
hugolib/template.go
hugolib/template_test.go

index 1bf3fe110079d7cdca8ac5ccf274edefb1a71362..23cb7a680d2c62adfa315ac8171b77528ed6eb24 100644 (file)
@@ -109,6 +109,71 @@ func First(limit int, seq interface{}) (interface{}, error) {
        return seqv.Slice(0, limit).Interface(), nil
 }
 
+func Where(seq, key, match interface{}) (interface{}, error) {
+       seqv := reflect.ValueOf(seq)
+       kv := reflect.ValueOf(key)
+       mv := reflect.ValueOf(match)
+
+       // this is better than my first pass; ripped from text/template/exec.go indirect():
+       for ; seqv.Kind() == reflect.Ptr || seqv.Kind() == reflect.Interface; seqv = seqv.Elem() {
+               if seqv.IsNil() {
+                       return nil, errors.New("can't iterate over a nil value")
+               }
+               if seqv.Kind() == reflect.Interface && seqv.NumMethod() > 0 {
+                       break
+               }
+       }
+
+       switch seqv.Kind() {
+       case reflect.Array, reflect.Slice:
+               r := reflect.MakeSlice(seqv.Type(), 0, 0)
+               for i := 0; i < seqv.Len(); i++ {
+                       var vvv reflect.Value
+                       vv := seqv.Index(i)
+                       switch vv.Kind() {
+                       case reflect.Map:
+                               if kv.Type() == vv.Type().Key() && vv.MapIndex(kv).IsValid() {
+                                       vvv = vv.MapIndex(kv)
+                               }
+                       case reflect.Struct:
+                               if kv.Kind() == reflect.String && vv.FieldByName(kv.String()).IsValid() {
+                                       vvv = vv.FieldByName(kv.String())
+                               }
+                       case reflect.Ptr:
+                               if !vv.IsNil() {
+                                       ev := vv.Elem()
+                                       switch ev.Kind() {
+                                       case reflect.Map:
+                                               if kv.Type() == ev.Type().Key() && ev.MapIndex(kv).IsValid() {
+                                                       vvv = ev.MapIndex(kv)
+                                               }
+                                       case reflect.Struct:
+                                               if kv.Kind() == reflect.String && ev.FieldByName(kv.String()).IsValid() {
+                                                       vvv = ev.FieldByName(kv.String())
+                                               }
+                                       }
+                               }
+                       }
+
+                       if vvv.IsValid() && mv.Type() == vvv.Type() {
+                               switch mv.Kind() {
+                               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+                                       if mv.Int() == vvv.Int() {
+                                               r = reflect.Append(r, vv)
+                                       }
+                               case reflect.String:
+                                       if mv.String() == vvv.String() {
+                                               r = reflect.Append(r, vv)
+                                       }
+                               }
+                       }
+               }
+               return r.Interface(), nil
+       default:
+               return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
+       }
+}
+
 func IsSet(a interface{}, key interface{}) bool {
        av := reflect.ValueOf(a)
        kv := reflect.ValueOf(key)
@@ -211,6 +276,7 @@ func NewTemplate() Template {
                "echoParam":   ReturnWhenSet,
                "safeHtml":    SafeHtml,
                "first":       First,
+               "where":       Where,
                "highlight":   Highlight,
                "add":         func(a, b int) int { return a + b },
                "sub":         func(a, b int) int { return a - b },
index 029e2a49f706cee009a12d1aeeada10d191e3f05..9a34e99ded2e01c86228fc16f563da91ce27040b 100644 (file)
@@ -55,3 +55,30 @@ func TestFirst(t *testing.T) {
                }
        }
 }
+
+func TestWhere(t *testing.T) {
+       type X struct {
+               A, B string
+       }
+       for i, this := range []struct {
+               sequence interface{}
+               key      interface{}
+               match    interface{}
+               expect   interface{}
+       }{
+               {[]map[int]string{{1: "a", 2: "m"}, {1: "c", 2: "d"}, {1: "e", 3: "m"}}, 2, "m", []map[int]string{{1: "a", 2: "m"}}},
+               {[]map[string]int{{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "x": 4}}, "b", 4, []map[string]int{{"a": 3, "b": 4}}},
+               {[]X{{"a", "b"}, {"c", "d"}, {"e", "f"}}, "B", "f", []X{{"e", "f"}}},
+               {[]*map[int]string{&map[int]string{1: "a", 2: "m"}, &map[int]string{1: "c", 2: "d"}, &map[int]string{1: "e", 3: "m"}}, 2, "m", []*map[int]string{&map[int]string{1: "a", 2: "m"}}},
+               {[]*X{&X{"a", "b"}, &X{"c", "d"}, &X{"e", "f"}}, "B", "f", []*X{&X{"e", "f"}}},
+       } {
+               results, err := Where(this.sequence, this.key, this.match)
+               if err != nil {
+                       t.Errorf("[%d] failed: %s", i, err)
+                       continue
+               }
+               if !reflect.DeepEqual(results, this.expect) {
+                       t.Errorf("[%d] Where clause matching %v with %v, got %v but expected %v", i, this.key, this.match, results, this.expect)
+               }
+       }
+}