tpl: Extend where to iterate over maps
authorCameron Moore <moorereason@gmail.com>
Wed, 13 Apr 2016 01:31:14 +0000 (20:31 -0500)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Wed, 13 Apr 2016 09:43:06 +0000 (11:43 +0200)
Refactor and extend where to iterate over maps.

tpl/template_funcs.go
tpl/template_funcs_test.go

index 9131cc1ecbd4686e0b32af73d4c95f8b330a7f07..f6913510eaa4f38ed27e98fadc0d4539ab5d3b99 100644 (file)
@@ -37,13 +37,11 @@ import (
        "time"
        "unicode/utf8"
 
-       "github.com/spf13/afero"
-       "github.com/spf13/hugo/hugofs"
-
        "github.com/bep/inflect"
-
+       "github.com/spf13/afero"
        "github.com/spf13/cast"
        "github.com/spf13/hugo/helpers"
+       "github.com/spf13/hugo/hugofs"
        jww "github.com/spf13/jwalterweatherman"
 )
 
@@ -771,64 +769,125 @@ func checkCondition(v, mv reflect.Value, op string) (bool, error) {
        return false, nil
 }
 
-// where returns a filtered subset of a given data type.
-func where(seq, key interface{}, args ...interface{}) (r interface{}, err error) {
-       seqv := reflect.ValueOf(seq)
-       kv := reflect.ValueOf(key)
-
-       var mv reflect.Value
-       var op string
+// parseWhereArgs parses the end arguments to the where function.  Return a
+// match value and an operator, if one is defined.
+func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
        switch len(args) {
        case 1:
                mv = reflect.ValueOf(args[0])
        case 2:
                var ok bool
                if op, ok = args[0].(string); !ok {
-                       return nil, errors.New("operator argument must be string type")
+                       err = errors.New("operator argument must be string type")
+                       return
                }
                op = strings.TrimSpace(strings.ToLower(op))
                mv = reflect.ValueOf(args[1])
        default:
-               return nil, errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
+               err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
        }
+       return
+}
 
-       seqv, isNil := indirect(seqv)
+// checkWhereArray handles the where-matching logic when the seqv value is an
+// Array or Slice.
+func checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
+       rv := reflect.MakeSlice(seqv.Type(), 0, 0)
+       for i := 0; i < seqv.Len(); i++ {
+               var vvv reflect.Value
+               rvv := seqv.Index(i)
+               if kv.Kind() == reflect.String {
+                       vvv = rvv
+                       for _, elemName := range path {
+                               var err error
+                               vvv, err = evaluateSubElem(vvv, elemName)
+                               if err != nil {
+                                       return nil, err
+                               }
+                       }
+               } else {
+                       vv, _ := indirect(rvv)
+                       if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
+                               vvv = vv.MapIndex(kv)
+                       }
+               }
+
+               if ok, err := checkCondition(vvv, mv, op); ok {
+                       rv = reflect.Append(rv, rvv)
+               } else if err != nil {
+                       return nil, err
+               }
+       }
+       return rv.Interface(), nil
+}
+
+// checkWhereMap handles the where-matching logic when the seqv value is a Map.
+func checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
+       rv := reflect.MakeMap(seqv.Type())
+       keys := seqv.MapKeys()
+       for _, k := range keys {
+               elemv := seqv.MapIndex(k)
+               switch elemv.Kind() {
+               case reflect.Array, reflect.Slice:
+                       r, err := checkWhereArray(elemv, kv, mv, path, op)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       switch rr := reflect.ValueOf(r); rr.Kind() {
+                       case reflect.Slice:
+                               if rr.Len() > 0 {
+                                       rv.SetMapIndex(k, elemv)
+                               }
+                       }
+               case reflect.Interface:
+                       elemvv, isNil := indirect(elemv)
+                       if isNil {
+                               continue
+                       }
+
+                       switch elemvv.Kind() {
+                       case reflect.Array, reflect.Slice:
+                               r, err := checkWhereArray(elemvv, kv, mv, path, op)
+                               if err != nil {
+                                       return nil, err
+                               }
+
+                               switch rr := reflect.ValueOf(r); rr.Kind() {
+                               case reflect.Slice:
+                                       if rr.Len() > 0 {
+                                               rv.SetMapIndex(k, elemv)
+                                       }
+                               }
+                       }
+               }
+       }
+       return rv, nil
+}
+
+// where returns a filtered subset of a given data type.
+func where(seq, key interface{}, args ...interface{}) (interface{}, error) {
+       seqv, isNil := indirect(reflect.ValueOf(seq))
        if isNil {
                return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
        }
 
+       mv, op, err := parseWhereArgs(args...)
+       if err != nil {
+               return nil, err
+       }
+
        var path []string
+       kv := reflect.ValueOf(key)
        if kv.Kind() == reflect.String {
                path = strings.Split(strings.Trim(kv.String(), "."), ".")
        }
 
        switch seqv.Kind() {
        case reflect.Array, reflect.Slice:
-               rv := reflect.MakeSlice(seqv.Type(), 0, 0)
-               for i := 0; i < seqv.Len(); i++ {
-                       var vvv reflect.Value
-                       rvv := seqv.Index(i)
-                       if kv.Kind() == reflect.String {
-                               vvv = rvv
-                               for _, elemName := range path {
-                                       vvv, err = evaluateSubElem(vvv, elemName)
-                                       if err != nil {
-                                               return nil, err
-                                       }
-                               }
-                       } else {
-                               vv, _ := indirect(rvv)
-                               if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
-                                       vvv = vv.MapIndex(kv)
-                               }
-                       }
-                       if ok, err := checkCondition(vvv, mv, op); ok {
-                               rv = reflect.Append(rv, rvv)
-                       } else if err != nil {
-                               return nil, err
-                       }
-               }
-               return rv.Interface(), nil
+               return checkWhereArray(seqv, kv, mv, path, op)
+       case reflect.Map:
+               return checkWhereMap(seqv, kv, mv, path, op)
        default:
                return nil, fmt.Errorf("can't iterate over %v", seq)
        }
index 8d604e81757bcf14abfe743fa061eb5e2d7ec68e..5dbcf6cf6f66cc8a996fee1c0ee7cf0e8cc91187 100644 (file)
@@ -18,11 +18,6 @@ import (
        "encoding/base64"
        "errors"
        "fmt"
-       "github.com/spf13/afero"
-       "github.com/spf13/cast"
-       "github.com/spf13/hugo/hugofs"
-       "github.com/spf13/viper"
-       "github.com/stretchr/testify/assert"
        "html/template"
        "math/rand"
        "path"
@@ -32,6 +27,12 @@ import (
        "strings"
        "testing"
        "time"
+
+       "github.com/spf13/afero"
+       "github.com/spf13/cast"
+       "github.com/spf13/hugo/hugofs"
+       "github.com/spf13/viper"
+       "github.com/stretchr/testify/assert"
 )
 
 type tstNoStringer struct {
@@ -1298,6 +1299,17 @@ func TestWhere(t *testing.T) {
                        key: "B", op: "op", match: "f",
                        expect: false,
                },
+               {
+                       sequence: map[string]interface{}{
+                               "foo": []interface{}{map[interface{}]interface{}{"a": 1, "b": 2}},
+                               "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+                               "zap": []interface{}{map[interface{}]interface{}{"a": 5, "b": 6}},
+                       },
+                       key: "b", op: "in", match: slice(3, 4, 5),
+                       expect: map[string]interface{}{
+                               "bar": []interface{}{map[interface{}]interface{}{"a": 3, "b": 4}},
+                       },
+               },
        } {
                var results interface{}
                var err error