Make 'where' template function accepts dot chaining key argument
authorTatsushi Demachi <tdemachi@gmail.com>
Mon, 29 Dec 2014 02:33:12 +0000 (11:33 +0900)
committerbep <bjorn.erik.pedersen@gmail.com>
Mon, 29 Dec 2014 11:53:41 +0000 (12:53 +0100)
'where' template function used to accept only each element's struct
field name, method name and map key name as its second argument. This
extends it to accept dot chaining key like 'Params.foo.bar' as the
argument. It evaluates sub elements of each array elements and checks it
matches the third argument value.

Typical use case would be for filtering Pages by user defined front
matter value. For example, to filter pages which have 'Params.foo.bar'
and its value is 'baz', it is used like

    {{ range where .Data.Pages "Params.foo.bar" "baz" }}
        {{ .Content }}
    {{ end }}

It ignores all leading and trailing dots so it can also be used with
".Params.foo.bar"

docs/content/templates/functions.md
tpl/template.go
tpl/template_test.go

index 819a5c7a9f410556f15fc7e79a5e98685712c838..4138ffde9caf4e926a33dcce5a663241c4f25c29 100644 (file)
@@ -66,6 +66,19 @@ e.g.
        {{ .Content}}
     {{ end }}
 
+It can be used with dot chaining second argument to refer a nested element of a value.
+
+e.g.
+
+    // Front matter on some pages
+    +++
+    series: golang
+    +++
+
+    {{ range where .Site.Recent "Params.series" "golang" }}
+       {{ .Content}}
+    {{ end }}
+
 *where and first can be stacked*
 
 e.g.
index aef6c3ba66f89c45dc2097180d9ae77c6b4e984d..1b8107f37ff9fc8055546bcd98badc935c8e607b 100644 (file)
@@ -289,6 +289,19 @@ func In(l interface{}, v interface{}) bool {
        return false
 }
 
+// indirect is taken from 'text/template/exec.go'
+func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
+       for ; v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface; v = v.Elem() {
+               if v.IsNil() {
+                       return v, true
+               }
+               if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
+                       break
+               }
+       }
+       return v, false
+}
+
 // First is exposed to templates, to iterate over the first N items in a
 // rangeable list.
 func First(limit interface{}, seq interface{}) (interface{}, error) {
@@ -326,76 +339,122 @@ func First(limit interface{}, seq interface{}) (interface{}, error) {
        return seqv.Slice(0, limitv).Interface(), nil
 }
 
-func Where(seq, key, match interface{}) (interface{}, error) {
+var (
+       zero reflect.Value
+       errorType = reflect.TypeOf((*error)(nil)).Elem()
+)
+
+func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
+       if !obj.IsValid() {
+               return zero, errors.New("can't evaluate an invalid value")
+       }
+       typ := obj.Type()
+       obj, isNil := indirect(obj)
+
+       // first, check whether obj has a method. In this case, obj is
+       // an interface, a struct or its pointer. If obj is a struct,
+       // to check all T and *T method, use obj pointer type Value
+       objPtr := obj
+       if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
+               objPtr = objPtr.Addr()
+       }
+       mt, ok := objPtr.Type().MethodByName(elemName)
+       if ok {
+               if mt.PkgPath != "" {
+                       return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
+               }
+               // struct pointer has one receiver argument and interface doesn't have an argument
+               if mt.Type.NumIn() > 1 || mt.Type.NumOut() == 0 || mt.Type.NumOut() > 2 {
+                       return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
+               }
+               if mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType) {
+                       return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
+               }
+               if mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType) {
+                       return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
+               }
+               res := objPtr.Method(mt.Index).Call([]reflect.Value{})
+               if len(res) == 2 && !res[1].IsNil() {
+                       return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
+               }
+               return res[0], nil
+       }
+
+       // elemName isn't a method so next start to check whether it is
+       // a struct field or a map value. In both cases, it mustn't be
+       // a nil value
+       if isNil {
+               return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
+       }
+       switch obj.Kind() {
+       case reflect.Struct:
+               ft, ok := obj.Type().FieldByName(elemName)
+               if ok {
+                       if ft.PkgPath != "" {
+                               return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
+                       }
+                       return obj.FieldByIndex(ft.Index), nil
+               }
+               return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
+       case reflect.Map:
+               kv := reflect.ValueOf(elemName)
+               if kv.Type().AssignableTo(obj.Type().Key()) {
+                       return obj.MapIndex(kv), nil
+               }
+               return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
+       }
+       return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
+}
+
+func Where(seq, key, match interface{}) (r interface{}, err error) {
        seqv := reflect.ValueOf(seq)
        kv := reflect.ValueOf(key)
        mv := reflect.ValueOf(match)
 
-       // this is better than my first pass; ripped from text/template/exec.go indirect():
-       for ; seqv.Kind() == reflect.Ptr || seqv.Kind() == reflect.Interface; seqv = seqv.Elem() {
-               if seqv.IsNil() {
-                       return nil, errors.New("can't iterate over a nil value")
-               }
-               if seqv.Kind() == reflect.Interface && seqv.NumMethod() > 0 {
-                       break
-               }
+       seqv, isNil := indirect(seqv)
+       if isNil {
+               return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
+       }
+
+       var path []string
+       if kv.Kind() == reflect.String {
+               path = strings.Split(strings.Trim(kv.String(), "."), ".")
        }
 
        switch seqv.Kind() {
        case reflect.Array, reflect.Slice:
-               r := reflect.MakeSlice(seqv.Type(), 0, 0)
+               rv := reflect.MakeSlice(seqv.Type(), 0, 0)
                for i := 0; i < seqv.Len(); i++ {
                        var vvv reflect.Value
-                       vv := seqv.Index(i)
-                       switch vv.Kind() {
-                       case reflect.Map:
-                               if kv.Type() == vv.Type().Key() && vv.MapIndex(kv).IsValid() {
-                                       vvv = vv.MapIndex(kv)
-                               }
-                       case reflect.Struct:
-                               if kv.Kind() == reflect.String {
-                                       method := vv.MethodByName(kv.String())
-                                       if method.IsValid() && method.Type().NumIn() == 0 && method.Type().NumOut() > 0 {
-                                               vvv = method.Call(nil)[0]
-                                       } else if vv.FieldByName(kv.String()).IsValid() {
-                                               vvv = vv.FieldByName(kv.String())
+                       rvv := seqv.Index(i)
+                       if kv.Kind() == reflect.String {
+                               vvv = rvv
+                               for _, elemName := range path {
+                                       vvv, err = evaluateSubElem(vvv, elemName)
+                                       if err != nil {
+                                               return nil, err
                                        }
                                }
-                       case reflect.Ptr:
-                               if !vv.IsNil() {
-                                       ev := vv.Elem()
-                                       switch ev.Kind() {
-                                       case reflect.Map:
-                                               if kv.Type() == ev.Type().Key() && ev.MapIndex(kv).IsValid() {
-                                                       vvv = ev.MapIndex(kv)
-                                               }
-                                       case reflect.Struct:
-                                               if kv.Kind() == reflect.String {
-                                                       method := vv.MethodByName(kv.String())
-                                                       if method.IsValid() && method.Type().NumIn() == 0 && method.Type().NumOut() > 0 {
-                                                               vvv = method.Call(nil)[0]
-                                                       } else if ev.FieldByName(kv.String()).IsValid() {
-                                                               vvv = ev.FieldByName(kv.String())
-                                                       }
-                                               }
-                                       }
+                       } else {
+                               vv, _ := indirect(rvv)
+                               if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
+                                       vvv = vv.MapIndex(kv)
                                }
                        }
-
                        if vvv.IsValid() && mv.Type() == vvv.Type() {
                                switch mv.Kind() {
                                case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                                        if mv.Int() == vvv.Int() {
-                                               r = reflect.Append(r, vv)
+                                               rv = reflect.Append(rv, rvv)
                                        }
                                case reflect.String:
                                        if mv.String() == vvv.String() {
-                                               r = reflect.Append(r, vv)
+                                               rv = reflect.Append(rv, rvv)
                                        }
                                }
                        }
                }
-               return r.Interface(), nil
+               return rv.Interface(), nil
        default:
                return nil, errors.New("can't iterate over " + reflect.ValueOf(seq).Type().String())
        }
index 30d721b6ca015273807cd0fc5a83d54bcfe79df3..00327ef7683e60cb5f15b0a02ecfa831d20afbc0 100644 (file)
@@ -1,6 +1,8 @@
 package tpl
 
 import (
+       "errors"
+       "fmt"
        "html/template"
        "reflect"
        "testing"
@@ -305,8 +307,85 @@ func (x TstX) TstRv() string {
        return "r" + x.B
 }
 
+func (x TstX) unexportedMethod() string {
+       return x.unexported
+}
+
+func (x TstX) MethodWithArg(s string) string {
+       return s
+}
+
+func (x TstX) MethodReturnNothing() {}
+
+func (x TstX) MethodReturnErrorOnly() error {
+       return errors.New("something error occured")
+}
+
+func (x TstX) MethodReturnTwoValues() (string, string) {
+       return "foo", "bar"
+}
+
+func (x TstX) MethodReturnValueWithError() (string, error) {
+       return "", errors.New("something error occured")
+}
+
+func (x TstX) String() string {
+       return fmt.Sprintf("A: %s, B: %s", x.A, x.B)
+}
+
 type TstX struct {
        A, B string
+       unexported string
+}
+
+func TestEvaluateSubElem(t *testing.T) {
+       tstx := TstX{A: "foo", B: "bar"}
+       var inner struct {
+               S fmt.Stringer
+       }
+       inner.S = tstx
+       interfaceValue := reflect.ValueOf(&inner).Elem().Field(0)
+
+       for i, this := range []struct {
+               value  reflect.Value
+               key    string
+               expect interface{}
+       }{
+               {reflect.ValueOf(tstx), "A", "foo"},
+               {reflect.ValueOf(&tstx), "TstRp", "rfoo"},
+               {reflect.ValueOf(tstx), "TstRv", "rbar"},
+               //{reflect.ValueOf(map[int]string{1: "foo", 2: "bar"}), 1, "foo"},
+               {reflect.ValueOf(map[string]string{"key1": "foo", "key2": "bar"}), "key1", "foo"},
+               {interfaceValue, "String", "A: foo, B: bar"},
+               {reflect.Value{}, "foo", false},
+               //{reflect.ValueOf(map[int]string{1: "foo", 2: "bar"}), 1.2, false},
+               {reflect.ValueOf(tstx), "unexported", false},
+               {reflect.ValueOf(tstx), "unexportedMethod", false},
+               {reflect.ValueOf(tstx), "MethodWithArg", false},
+               {reflect.ValueOf(tstx), "MethodReturnNothing", false},
+               {reflect.ValueOf(tstx), "MethodReturnErrorOnly", false},
+               {reflect.ValueOf(tstx), "MethodReturnTwoValues", false},
+               {reflect.ValueOf(tstx), "MethodReturnValueWithError", false},
+               {reflect.ValueOf((*TstX)(nil)), "A", false},
+               {reflect.ValueOf(tstx), "C", false},
+               {reflect.ValueOf(map[int]string{1: "foo", 2: "bar"}), "1", false},
+               {reflect.ValueOf([]string{"foo", "bar"}), "1", false},
+       } {
+               result, err := evaluateSubElem(this.value, this.key)
+               if b, ok := this.expect.(bool); ok && !b {
+                       if err == nil {
+                               t.Errorf("[%d] evaluateSubElem didn't return an expected error", i)
+                       }
+               } else {
+                       if err != nil {
+                               t.Errorf("[%d] failed: %s", i, err)
+                               continue
+                       }
+                       if result.Kind() != reflect.String || result.String() != this.expect {
+                               t.Errorf("[%d] evaluateSubElem with %v got %v but expected %v", i, this.key, result, this.expect)
+                       }
+               }
+       }
 }
 
 func TestWhere(t *testing.T) {
@@ -314,6 +393,10 @@ func TestWhere(t *testing.T) {
        //page1 := &Page{contentType: "v", Source: Source{File: *source.NewFile("/x/y/z/source.md")}}
        //page2 := &Page{contentType: "w", Source: Source{File: *source.NewFile("/y/z/a/source.md")}}
 
+       type Mid struct {
+               Tst TstX
+       }
+
        for i, this := range []struct {
                sequence interface{}
                key      interface{}
@@ -322,21 +405,37 @@ func TestWhere(t *testing.T) {
        }{
                {[]map[int]string{{1: "a", 2: "m"}, {1: "c", 2: "d"}, {1: "e", 3: "m"}}, 2, "m", []map[int]string{{1: "a", 2: "m"}}},
                {[]map[string]int{{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "x": 4}}, "b", 4, []map[string]int{{"a": 3, "b": 4}}},
-               {[]TstX{{"a", "b"}, {"c", "d"}, {"e", "f"}}, "B", "f", []TstX{{"e", "f"}}},
+               {[]TstX{{A: "a", B: "b"}, {A: "c", B: "d"}, {A: "e", B: "f"}}, "B", "f", []TstX{{A: "e", B: "f"}}},
                {[]*map[int]string{&map[int]string{1: "a", 2: "m"}, &map[int]string{1: "c", 2: "d"}, &map[int]string{1: "e", 3: "m"}}, 2, "m", []*map[int]string{&map[int]string{1: "a", 2: "m"}}},
-               {[]*TstX{&TstX{"a", "b"}, &TstX{"c", "d"}, &TstX{"e", "f"}}, "B", "f", []*TstX{&TstX{"e", "f"}}},
-               {[]*TstX{&TstX{"a", "b"}, &TstX{"c", "d"}, &TstX{"e", "c"}}, "TstRp", "rc", []*TstX{&TstX{"c", "d"}}},
-               {[]TstX{TstX{"a", "b"}, TstX{"c", "d"}, TstX{"e", "c"}}, "TstRv", "rc", []TstX{TstX{"e", "c"}}},
+               {[]*TstX{&TstX{A: "a", B: "b"}, &TstX{A: "c", B: "d"}, &TstX{A: "e", B: "f"}}, "B", "f", []*TstX{&TstX{A: "e", B: "f"}}},
+               {[]*TstX{&TstX{A: "a", B: "b"}, &TstX{A: "c", B: "d"}, &TstX{A: "e", B: "c"}}, "TstRp", "rc", []*TstX{&TstX{A: "c", B: "d"}}},
+               {[]TstX{TstX{A: "a", B: "b"}, TstX{A: "c", B: "d"}, TstX{A: "e", B: "c"}}, "TstRv", "rc", []TstX{TstX{A: "e", B: "c"}}},
+               {[]map[string]TstX{{"foo": TstX{A: "a", B: "b"}}, {"foo": TstX{A: "c", B: "d"}}, {"foo": TstX{A: "e", B: "f"}}}, "foo.B", "d", []map[string]TstX{{"foo": TstX{A: "c", B: "d"}}}},
+               {[]map[string]TstX{{"foo": TstX{A: "a", B: "b"}}, {"foo": TstX{A: "c", B: "d"}}, {"foo": TstX{A: "e", B: "f"}}}, ".foo.B", "d", []map[string]TstX{{"foo": TstX{A: "c", B: "d"}}}},
+               {[]map[string]TstX{{"foo": TstX{A: "a", B: "b"}}, {"foo": TstX{A: "c", B: "d"}}, {"foo": TstX{A: "e", B: "f"}}}, "foo.TstRv", "rd", []map[string]TstX{{"foo": TstX{A: "c", B: "d"}}}},
+               {[]map[string]*TstX{{"foo": &TstX{A: "a", B: "b"}}, {"foo": &TstX{A: "c", B: "d"}}, {"foo": &TstX{A: "e", B: "f"}}}, "foo.TstRp", "rc", []map[string]*TstX{{"foo": &TstX{A: "c", B: "d"}}}},
+               {[]map[string]Mid{{"foo": Mid{Tst: TstX{A: "a", B: "b"}}}, {"foo": Mid{Tst: TstX{A: "c", B: "d"}}}, {"foo": Mid{Tst: TstX{A: "e", B: "f"}}}}, "foo.Tst.B", "d", []map[string]Mid{{"foo": Mid{Tst: TstX{A: "c", B: "d"}}}}},
+               {[]map[string]Mid{{"foo": Mid{Tst: TstX{A: "a", B: "b"}}}, {"foo": Mid{Tst: TstX{A: "c", B: "d"}}}, {"foo": Mid{Tst: TstX{A: "e", B: "f"}}}}, "foo.Tst.TstRv", "rd", []map[string]Mid{{"foo": Mid{Tst: TstX{A: "c", B: "d"}}}}},
+               {[]map[string]*Mid{{"foo": &Mid{Tst: TstX{A: "a", B: "b"}}}, {"foo": &Mid{Tst: TstX{A: "c", B: "d"}}}, {"foo": &Mid{Tst: TstX{A: "e", B: "f"}}}}, "foo.Tst.TstRp", "rc", []map[string]*Mid{{"foo": &Mid{Tst: TstX{A: "c", B: "d"}}}}},
+               {(*[]TstX)(nil), "A", "a", false},
+               {TstX{A: "a", B: "b"}, "A", "a", false},
+               {[]map[string]*TstX{{"foo": nil}}, "foo.B", "d", false},
                //{[]*Page{page1, page2}, "Type", "v", []*Page{page1}},
                //{[]*Page{page1, page2}, "Section", "y", []*Page{page2}},
        } {
                results, err := Where(this.sequence, this.key, this.match)
-               if err != nil {
-                       t.Errorf("[%d] failed: %s", i, err)
-                       continue
-               }
-               if !reflect.DeepEqual(results, this.expect) {
-                       t.Errorf("[%d] Where clause matching %v with %v, got %v but expected %v", i, this.key, this.match, results, this.expect)
+               if b, ok := this.expect.(bool); ok && !b {
+                       if err == nil {
+                               t.Errorf("[%d] Where didn't return an expected error", i)
+                       }
+               } else {
+                       if err != nil {
+                               t.Errorf("[%d] failed: %s", i, err)
+                               continue
+                       }
+                       if !reflect.DeepEqual(results, this.expect) {
+                               t.Errorf("[%d] Where clause matching %v with %v, got %v but expected %v", i, this.key, this.match, results, this.expect)
+                       }
                }
        }
 }