tpl/collections: Add Pages support to Intersect and Union
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 3 Jul 2017 08:32:10 +0000 (10:32 +0200)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Mon, 3 Jul 2017 19:48:03 +0000 (21:48 +0200)
This enables `AND` (`intersect`)  and `OR` (`union`)  filters when combined with `where`.

Example:

```go
{{ $pages := where .Site.RegularPages "Type" "not in" (slice "page" "about") }}
{{ $pages := $pages | union (where .Site.RegularPages "Params.pinned" true) }}
{{ $pages := $pages | intersect (where .Site.RegularPages "Params.images" "!=" nil) }}
```

The above fetches regular pages not of `page` or `about` type unless they are pinned. And finally, we exclude all pages with no `images` set in Page params.

Fixes #3174

tpl/collections/apply.go
tpl/collections/collections.go
tpl/collections/collections_test.go
tpl/collections/reflect_helpers.go
tpl/collections/where.go

index c3c3a297ba317a5309aa9fa0c18b35c07ebb1b20..0b2b006219783b16d54d9e44c74b767a07cb2c8d 100644 (file)
@@ -148,3 +148,15 @@ func indirect(v reflect.Value) (rv reflect.Value, isNil bool) {
        }
        return v, false
 }
+
+func indirectInterface(v reflect.Value) (rv reflect.Value, isNil bool) {
+       for ; v.Kind() == reflect.Interface; v = v.Elem() {
+               if v.IsNil() {
+                       return v, true
+               }
+               if v.Kind() == reflect.Interface && v.NumMethod() > 0 {
+                       break
+               }
+       }
+       return v, false
+}
index ab3d08f5e11de00ade312b3317a59693fcc2c20c..103cb38604dd206f616826386eff7caa6b36c99d 100644 (file)
@@ -256,7 +256,9 @@ func (ns *Namespace) In(l interface{}, v interface{}) bool {
                                }
                        default:
                                if isNumber(vv.Kind()) && isNumber(lvv.Kind()) {
-                                       if numberToFloat(vv) == numberToFloat(lvv) {
+                                       f1, err1 := numberToFloat(vv)
+                                       f2, err2 := numberToFloat(lvv)
+                                       if err1 == nil && err2 == nil && f1 == f2 {
                                                return true
                                        }
                                }
@@ -277,69 +279,24 @@ func (ns *Namespace) Intersect(l1, l2 interface{}) (interface{}, error) {
                return make([]interface{}, 0), nil
        }
 
+       var ins *intersector
+
        l1v := reflect.ValueOf(l1)
        l2v := reflect.ValueOf(l2)
 
        switch l1v.Kind() {
        case reflect.Array, reflect.Slice:
+               ins = &intersector{r: reflect.MakeSlice(l1v.Type(), 0, 0), seen: make(map[interface{}]bool)}
                switch l2v.Kind() {
                case reflect.Array, reflect.Slice:
-                       r := reflect.MakeSlice(l1v.Type(), 0, 0)
                        for i := 0; i < l1v.Len(); i++ {
                                l1vv := l1v.Index(i)
                                for j := 0; j < l2v.Len(); j++ {
                                        l2vv := l2v.Index(j)
-                                       switch l1vv.Kind() {
-                                       case reflect.String:
-                                               l2t, err := toString(l2vv)
-                                               if err == nil && l1vv.String() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
-                                                       r = reflect.Append(r, l1vv)
-                                               }
-                                       case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-                                               l2t, err := toInt(l2vv)
-                                               if err == nil && l1vv.Int() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
-                                                       r = reflect.Append(r, l1vv)
-                                               }
-                                       case reflect.Float32, reflect.Float64:
-                                               l2t, err := toFloat(l2vv)
-                                               if err == nil && l1vv.Float() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
-                                                       r = reflect.Append(r, l1vv)
-                                               }
-                                       case reflect.Interface:
-                                               switch l1vvActual := l1vv.Interface().(type) {
-                                               case string:
-                                                       switch l2vvActual := l2vv.Interface().(type) {
-                                                       case string:
-                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
-                                                                       r = reflect.Append(r, l1vv)
-                                                               }
-                                                       }
-                                               case int, int8, int16, int32, int64:
-                                                       switch l2vvActual := l2vv.Interface().(type) {
-                                                       case int, int8, int16, int32, int64:
-                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
-                                                                       r = reflect.Append(r, l1vv)
-                                                               }
-                                                       }
-                                               case uint, uint8, uint16, uint32, uint64:
-                                                       switch l2vvActual := l2vv.Interface().(type) {
-                                                       case uint, uint8, uint16, uint32, uint64:
-                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
-                                                                       r = reflect.Append(r, l1vv)
-                                                               }
-                                                       }
-                                               case float32, float64:
-                                                       switch l2vvActual := l2vv.Interface().(type) {
-                                                       case float32, float64:
-                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
-                                                                       r = reflect.Append(r, l1vv)
-                                                               }
-                                                       }
-                                               }
-                                       }
+                                       ins.handleValuePair(l1vv, l2vv)
                                }
                        }
-                       return r.Interface(), nil
+                       return ins.r.Interface(), nil
                default:
                        return nil, errors.New("can't iterate over " + reflect.ValueOf(l2).Type().String())
                }
@@ -531,6 +488,41 @@ func (ns *Namespace) Slice(args ...interface{}) []interface{} {
        return args
 }
 
+type intersector struct {
+       r    reflect.Value
+       seen map[interface{}]bool
+}
+
+func (i *intersector) appendIfNotSeen(v reflect.Value) {
+       vi := v.Interface()
+       if !i.seen[vi] {
+               i.r = reflect.Append(i.r, v)
+               i.seen[vi] = true
+       }
+}
+
+func (ins *intersector) handleValuePair(l1vv, l2vv reflect.Value) {
+       switch kind := l1vv.Kind(); {
+       case kind == reflect.String:
+               l2t, err := toString(l2vv)
+               if err == nil && l1vv.String() == l2t {
+                       ins.appendIfNotSeen(l1vv)
+               }
+       case isNumber(kind):
+               f1, err1 := numberToFloat(l1vv)
+               f2, err2 := numberToFloat(l2vv)
+               if err1 == nil && err2 == nil && f1 == f2 {
+                       ins.appendIfNotSeen(l1vv)
+               }
+       case kind == reflect.Ptr, kind == reflect.Struct:
+               if l1vv.Interface() == l2vv.Interface() {
+                       ins.appendIfNotSeen(l1vv)
+               }
+       case kind == reflect.Interface:
+               ins.handleValuePair(reflect.ValueOf(l1vv.Interface()), l2vv)
+       }
+}
+
 // Union returns the union of the given sets, l1 and l2. l1 and
 // l2 must be of the same type and may be either arrays or slices.
 // If l1 and l2 aren't of the same type then l1 will be returned.
@@ -547,105 +539,54 @@ func (ns *Namespace) Union(l1, l2 interface{}) (interface{}, error) {
        l1v := reflect.ValueOf(l1)
        l2v := reflect.ValueOf(l2)
 
+       var ins *intersector
+
        switch l1v.Kind() {
        case reflect.Array, reflect.Slice:
                switch l2v.Kind() {
                case reflect.Array, reflect.Slice:
-                       r := reflect.MakeSlice(l1v.Type(), 0, 0)
+                       ins = &intersector{r: reflect.MakeSlice(l1v.Type(), 0, 0), seen: make(map[interface{}]bool)}
 
                        if l1v.Type() != l2v.Type() &&
                                l1v.Type().Elem().Kind() != reflect.Interface &&
                                l2v.Type().Elem().Kind() != reflect.Interface {
-                               return r.Interface(), nil
+                               return ins.r.Interface(), nil
                        }
 
-                       var l1vv reflect.Value
+                       var (
+                               l1vv  reflect.Value
+                               isNil bool
+                       )
+
                        for i := 0; i < l1v.Len(); i++ {
-                               l1vv = l1v.Index(i)
-                               if !ns.In(r.Interface(), l1vv.Interface()) {
-                                       r = reflect.Append(r, l1vv)
+                               l1vv, isNil = indirectInterface(l1v.Index(i))
+                               if !isNil {
+                                       ins.appendIfNotSeen(l1vv)
                                }
                        }
 
                        for j := 0; j < l2v.Len(); j++ {
                                l2vv := l2v.Index(j)
 
-                               switch l1vv.Kind() {
-                               case reflect.String:
+                               switch kind := l1vv.Kind(); {
+                               case kind == reflect.String:
                                        l2t, err := toString(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(l2t))
-                                       }
-                               case reflect.Int:
-                                       l2t, err := toInt(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(int(l2t)))
-                                       }
-                               case reflect.Int8:
-                                       l2t, err := toInt(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(int8(l2t)))
-                                       }
-                               case reflect.Int16:
-                                       l2t, err := toInt(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(int16(l2t)))
+                                       if err == nil {
+                                               ins.appendIfNotSeen(reflect.ValueOf(l2t))
                                        }
-                               case reflect.Int32:
-                                       l2t, err := toInt(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(int32(l2t)))
-                                       }
-                               case reflect.Int64:
-                                       l2t, err := toInt(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(l2t))
-                                       }
-                               case reflect.Float32:
-                                       l2t, err := toFloat(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), float32(l2t)) {
-                                               r = reflect.Append(r, reflect.ValueOf(float32(l2t)))
-                                       }
-                               case reflect.Float64:
-                                       l2t, err := toFloat(l2vv)
-                                       if err == nil && !ns.In(r.Interface(), l2t) {
-                                               r = reflect.Append(r, reflect.ValueOf(l2t))
-                                       }
-                               case reflect.Interface:
-                                       switch l1vv.Interface().(type) {
-                                       case string:
-                                               switch l2vvActual := l2vv.Interface().(type) {
-                                               case string:
-                                                       if !ns.In(r.Interface(), l2vvActual) {
-                                                               r = reflect.Append(r, l2vv)
-                                                       }
-                                               }
-                                       case int, int8, int16, int32, int64:
-                                               switch l2vvActual := l2vv.Interface().(type) {
-                                               case int, int8, int16, int32, int64:
-                                                       if !ns.In(r.Interface(), l2vvActual) {
-                                                               r = reflect.Append(r, l2vv)
-                                                       }
-                                               }
-                                       case uint, uint8, uint16, uint32, uint64:
-                                               switch l2vvActual := l2vv.Interface().(type) {
-                                               case uint, uint8, uint16, uint32, uint64:
-                                                       if !ns.In(r.Interface(), l2vvActual) {
-                                                               r = reflect.Append(r, l2vv)
-                                                       }
-                                               }
-                                       case float32, float64:
-                                               switch l2vvActual := l2vv.Interface().(type) {
-                                               case float32, float64:
-                                                       if !ns.In(r.Interface(), l2vvActual) {
-                                                               r = reflect.Append(r, l2vv)
-                                                       }
-                                               }
+                               case isNumber(kind):
+                                       var err error
+                                       l2vv, err = convertNumber(l2vv, kind)
+                                       if err == nil {
+                                               ins.appendIfNotSeen(l2vv)
                                        }
+                               case kind == reflect.Interface, kind == reflect.Struct, kind == reflect.Ptr:
+                                       ins.appendIfNotSeen(l2vv)
+
                                }
                        }
 
-                       return r.Interface(), nil
+                       return ins.r.Interface(), nil
                default:
                        return nil, errors.New("can't iterate over " + reflect.ValueOf(l2).Type().String())
                }
index ea23a1de773290e9acd2b81f3d0a86e3d4c9f3fe..46bef9483063f8b30fd0fb8c1611a5a65496f35f 100644 (file)
@@ -258,11 +258,34 @@ func TestIn(t *testing.T) {
        }
 }
 
+type page struct {
+       Title string
+}
+
+func (p page) String() string {
+       return "p-" + p.Title
+}
+
+type pagesPtr []*page
+type pagesVals []page
+
 func TestIntersect(t *testing.T) {
        t.Parallel()
 
        ns := New(&deps.Deps{})
 
+       var (
+               p1 = &page{"A"}
+               p2 = &page{"B"}
+               p3 = &page{"C"}
+               p4 = &page{"D"}
+
+               p1v = page{"A"}
+               p2v = page{"B"}
+               p3v = page{"C"}
+               p4v = page{"D"}
+       )
+
        for i, test := range []struct {
                l1, l2 interface{}
                expect interface{}
@@ -280,6 +303,7 @@ func TestIntersect(t *testing.T) {
                {[]int{2, 4}, []int{1, 2, 4}, []int{2, 4}},
                {[]int{1, 2, 4}, []int{3, 6}, []int{}},
                {[]float64{2.2, 4.4}, []float64{1.1, 2.2, 4.4}, []float64{2.2, 4.4}},
+
                // errors
                {"not array or slice", []string{"a"}, false},
                {[]string{"a"}, "not array or slice", false},
@@ -314,8 +338,15 @@ func TestIntersect(t *testing.T) {
                {[]int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(2)}, []int64{1, 2}},
                {[]float32{1, 2, 3}, []interface{}{float32(1), float32(2), float32(2)}, []float32{1, 2}},
                {[]float64{1, 2, 3}, []interface{}{float64(1), float64(2), float64(2)}, []float64{1, 2}},
+
+               // Structs
+               {pagesPtr{p1, p4, p2, p3}, pagesPtr{p4, p2, p2}, pagesPtr{p4, p2}},
+               {pagesVals{p1v, p4v, p2v, p3v}, pagesVals{p1v, p3v, p3v}, pagesVals{p1v, p3v}},
+               {[]interface{}{p1, p4, p2, p3}, []interface{}{p4, p2, p2}, []interface{}{p4, p2}},
+               {[]interface{}{p1v, p4v, p2v, p3v}, []interface{}{p1v, p3v, p3v}, []interface{}{p1v, p3v}},
        } {
-               errMsg := fmt.Sprintf("[%d] %v", i, test)
+
+               errMsg := fmt.Sprintf("[%d]", test)
 
                result, err := ns.Intersect(test.l1, test.l2)
 
@@ -325,7 +356,9 @@ func TestIntersect(t *testing.T) {
                }
 
                assert.NoError(t, err, errMsg)
-               assert.Equal(t, test.expect, result, errMsg)
+               if !reflect.DeepEqual(result, test.expect) {
+                       t.Fatalf("[%d] Got\n%v expected\n%v", i, result, test.expect)
+               }
        }
 }
 
@@ -569,6 +602,18 @@ func TestUnion(t *testing.T) {
 
        ns := New(&deps.Deps{})
 
+       var (
+               p1 = &page{"A"}
+               p2 = &page{"B"}
+               //              p3 = &page{"C"}
+               p4 = &page{"D"}
+
+               p1v = page{"A"}
+               //p2v = page{"B"}
+               p3v = page{"C"}
+               //p4v = page{"D"}
+       )
+
        for i, test := range []struct {
                l1     interface{}
                l2     interface{}
@@ -604,6 +649,7 @@ func TestUnion(t *testing.T) {
                {[]int16{2, 4}, []interface{}{1, 2, 4}, []int16{2, 4, 1}, false},
                {[]int32{2, 4}, []interface{}{1, 2, 4}, []int32{2, 4, 1}, false},
                {[]int64{2, 4}, []interface{}{1, 2, 4}, []int64{2, 4, 1}, false},
+
                {[]float64{2.2, 4.4}, []interface{}{1.1, 2.2, 4.4}, []float64{2.2, 4.4, 1.1}, false},
                {[]float32{2.2, 4.4}, []interface{}{1.1, 2.2, 4.4}, []float32{2.2, 4.4, 1.1}, false},
 
@@ -611,14 +657,21 @@ func TestUnion(t *testing.T) {
                {[]interface{}{"a", "b", "c", "c"}, []string{"a", "b", "d"}, []interface{}{"a", "b", "c", "d"}, false},
                {[]interface{}{}, []string{}, []interface{}{}, false},
                {[]interface{}{1, 2}, []int{2, 3}, []interface{}{1, 2, 3}, false},
-               {[]interface{}{1, 2}, []int8{2, 3}, []interface{}{1, 2, int8(3)}, false},
+               {[]interface{}{1, 2}, []int8{2, 3}, []interface{}{1, 2, 3}, false}, // 28
                {[]interface{}{uint(1), uint(2)}, []uint{2, 3}, []interface{}{uint(1), uint(2), uint(3)}, false},
                {[]interface{}{1.1, 2.2}, []float64{2.2, 3.3}, []interface{}{1.1, 2.2, 3.3}, false},
 
+               // Structs
+               {pagesPtr{p1, p4}, pagesPtr{p4, p2, p2}, pagesPtr{p1, p4, p2}, false},
+               {pagesVals{p1v}, pagesVals{p3v, p3v}, pagesVals{p1v, p3v}, false},
+               {[]interface{}{p1, p4}, []interface{}{p4, p2, p2}, []interface{}{p1, p4, p2}, false},
+               {[]interface{}{p1v}, []interface{}{p3v, p3v}, []interface{}{p1v, p3v}, false},
+
                // errors
                {"not array or slice", []string{"a"}, false, true},
                {[]string{"a"}, "not array or slice", false, true},
        } {
+
                errMsg := fmt.Sprintf("[%d] %v", i, test)
 
                result, err := ns.Union(test.l1, test.l2)
@@ -628,7 +681,9 @@ func TestUnion(t *testing.T) {
                }
 
                assert.NoError(t, err, errMsg)
-               assert.Equal(t, test.expect, result, errMsg)
+               if !reflect.DeepEqual(result, test.expect) {
+                       t.Fatalf("[%d] Got\n%v expected\n%v", i, result, test.expect)
+               }
        }
 }
 
index f07ea978c5f0e59ecd95d9c9c78bae8156d1569b..69eaa68c4139a65e244037c3595e3bdff2ddf729 100644 (file)
 package collections
 
 import (
+       "errors"
+       "fmt"
        "reflect"
+       "time"
 )
 
-func numberToFloat(v reflect.Value) float64 {
+var (
+       zero      reflect.Value
+       errorType = reflect.TypeOf((*error)(nil)).Elem()
+       timeType  = reflect.TypeOf((*time.Time)(nil)).Elem()
+)
+
+func numberToFloat(v reflect.Value) (float64, error) {
        switch kind := v.Kind(); {
        case isFloat(kind):
-               return v.Float()
+               return v.Float(), nil
        case isInt(kind):
-               return float64(v.Int())
-       case isUInt(kind):
-               return float64(v.Uint())
+               return float64(v.Int()), nil
+       case isUint(kind):
+               return float64(v.Uint()), nil
        case kind == reflect.Interface:
                return numberToFloat(v.Elem())
        default:
-               panic("Invalid type in numberToFloat")
+               return 0, fmt.Errorf("Invalid kind %s in numberToFloat", kind)
+       }
+}
+
+// There are potential overflows in this function, but the downconversion of
+// int64 etc. into int8 etc. is coming from the synthetic unit tests for Union etc.
+// TODO(bep) We should consider normalizing the slices to int64 etc.
+func convertNumber(v reflect.Value, to reflect.Kind) (reflect.Value, error) {
+       var n reflect.Value
+       if isFloat(to) {
+               f, err := toFloat(v)
+               if err != nil {
+                       return n, err
+               }
+               switch to {
+               case reflect.Float32:
+                       n = reflect.ValueOf(float32(f))
+               default:
+                       n = reflect.ValueOf(float64(f))
+               }
+       } else if isInt(to) {
+               i, err := toInt(v)
+               if err != nil {
+                       return n, err
+               }
+               switch to {
+               case reflect.Int:
+                       n = reflect.ValueOf(int(i))
+               case reflect.Int8:
+                       n = reflect.ValueOf(int8(i))
+               case reflect.Int16:
+                       n = reflect.ValueOf(int16(i))
+               case reflect.Int32:
+                       n = reflect.ValueOf(int32(i))
+               case reflect.Int64:
+                       n = reflect.ValueOf(int64(i))
+               }
+       } else if isUint(to) {
+               i, err := toUint(v)
+               if err != nil {
+                       return n, err
+               }
+               switch to {
+               case reflect.Uint:
+                       n = reflect.ValueOf(uint(i))
+               case reflect.Uint8:
+                       n = reflect.ValueOf(uint8(i))
+               case reflect.Uint16:
+                       n = reflect.ValueOf(uint16(i))
+               case reflect.Uint32:
+                       n = reflect.ValueOf(uint32(i))
+               case reflect.Uint64:
+                       n = reflect.ValueOf(uint64(i))
+               }
+
+       }
+
+       if !n.IsValid() {
+               return n, errors.New("invalid values")
        }
+
+       return n, nil
+
 }
 
 func isNumber(kind reflect.Kind) bool {
-       return isInt(kind) || isUInt(kind) || isFloat(kind)
+       return isInt(kind) || isUint(kind) || isFloat(kind)
 }
 
 func isInt(kind reflect.Kind) bool {
@@ -45,7 +115,7 @@ func isInt(kind reflect.Kind) bool {
        }
 }
 
-func isUInt(kind reflect.Kind) bool {
+func isUint(kind reflect.Kind) bool {
        switch kind {
        case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
                return true
index d8045b301e90f06ddba94247816faf95b9f06e1b..37be00509e5ba883ebdbc73b231cb3fbeca98546 100644 (file)
@@ -18,7 +18,6 @@ import (
        "fmt"
        "reflect"
        "strings"
-       "time"
 )
 
 // Where returns a filtered subset of a given data type.
@@ -404,6 +403,16 @@ func toInt(v reflect.Value) (int64, error) {
        return -1, errors.New("unable to convert value to int")
 }
 
+func toUint(v reflect.Value) (uint64, error) {
+       switch v.Kind() {
+       case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+               return v.Uint(), nil
+       case reflect.Interface:
+               return toUint(v.Elem())
+       }
+       return 0, errors.New("unable to convert value to uint")
+}
+
 // toString returns the string value if possible, "" if not.
 func toString(v reflect.Value) (string, error) {
        switch v.Kind() {
@@ -415,12 +424,6 @@ func toString(v reflect.Value) (string, error) {
        return "", errors.New("unable to convert value to string")
 }
 
-var (
-       zero      reflect.Value
-       errorType = reflect.TypeOf((*error)(nil)).Elem()
-       timeType  = reflect.TypeOf((*time.Time)(nil)).Elem()
-)
-
 func toTimeUnix(v reflect.Value) int64 {
        if v.Kind() == reflect.Interface {
                return toTimeUnix(v.Elem())