tpl/collections: Add support for interfaces to intersect
authorCameron Moore <moorereason@gmail.com>
Tue, 2 May 2017 02:25:39 +0000 (21:25 -0500)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Thu, 18 May 2017 07:13:44 +0000 (10:13 +0300)
Fixes #1952

tpl/collections/collections.go
tpl/collections/collections_test.go
tpl/collections/where.go

index 0843fb7bc866fe66d013fda7ef7b427261bf429d..081515ae57f233bd3f92895aa1deebee1cf95c7f 100644 (file)
@@ -298,21 +298,49 @@ func (ns *Namespace) Intersect(l1, l2 interface{}) (interface{}, error) {
                                        l2vv := l2v.Index(j)
                                        switch l1vv.Kind() {
                                        case reflect.String:
-                                               if l1vv.Type() == l2vv.Type() && l1vv.String() == l2vv.String() && !ns.In(r.Interface(), l2vv.Interface()) {
-                                                       r = reflect.Append(r, l2vv)
+                                               l2t, err := toString(l2vv)
+                                               if err == nil && l1vv.String() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
+                                                       r = reflect.Append(r, l1vv)
                                                }
                                        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-                                               switch l2vv.Kind() {
-                                               case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-                                                       if l1vv.Int() == l2vv.Int() && !ns.In(r.Interface(), l2vv.Interface()) {
-                                                               r = reflect.Append(r, l2vv)
-                                                       }
+                                               l2t, err := toInt(l2vv)
+                                               if err == nil && l1vv.Int() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
+                                                       r = reflect.Append(r, l1vv)
                                                }
                                        case reflect.Float32, reflect.Float64:
-                                               switch l2vv.Kind() {
-                                               case reflect.Float32, reflect.Float64:
-                                                       if l1vv.Float() == l2vv.Float() && !ns.In(r.Interface(), l2vv.Interface()) {
-                                                               r = reflect.Append(r, l2vv)
+                                               l2t, err := toFloat(l2vv)
+                                               if err == nil && l1vv.Float() == l2t && !ns.In(r.Interface(), l1vv.Interface()) {
+                                                       r = reflect.Append(r, l1vv)
+                                               }
+                                       case reflect.Interface:
+                                               switch l1vvActual := l1vv.Interface().(type) {
+                                               case string:
+                                                       switch l2vvActual := l2vv.Interface().(type) {
+                                                       case string:
+                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
+                                                                       r = reflect.Append(r, l1vv)
+                                                               }
+                                                       }
+                                               case int, int8, int16, int32, int64:
+                                                       switch l2vvActual := l2vv.Interface().(type) {
+                                                       case int, int8, int16, int32, int64:
+                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
+                                                                       r = reflect.Append(r, l1vv)
+                                                               }
+                                                       }
+                                               case uint, uint8, uint16, uint32, uint64:
+                                                       switch l2vvActual := l2vv.Interface().(type) {
+                                                       case uint, uint8, uint16, uint32, uint64:
+                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
+                                                                       r = reflect.Append(r, l1vv)
+                                                               }
+                                                       }
+                                               case float32, float64:
+                                                       switch l2vvActual := l2vv.Interface().(type) {
+                                                       case float32, float64:
+                                                               if l1vvActual == l2vvActual && !ns.In(r.Interface(), l1vvActual) {
+                                                                       r = reflect.Append(r, l1vv)
+                                                               }
                                                        }
                                                }
                                        }
index eefbcef6cef83ad2ea68d84a949e9b7dcacbc66e..07055de8665b2e4f6e0170b08cc4b7350ade3c3a 100644 (file)
@@ -260,7 +260,9 @@ func TestIntersect(t *testing.T) {
                {[]string{"a", "b"}, []string{"a", "b", "c"}, []string{"a", "b"}},
                {[]string{"a", "b", "c"}, []string{"d", "e"}, []string{}},
                {[]string{}, []string{}, []string{}},
-               {nil, nil, make([]interface{}, 0)},
+               {[]string{"a", "b"}, nil, []interface{}{}},
+               {nil, []string{"a", "b"}, []interface{}{}},
+               {nil, nil, []interface{}{}},
                {[]string{"1", "2"}, []int{1, 2}, []string{}},
                {[]int{1, 2}, []string{"1", "2"}, []int{}},
                {[]int{1, 2, 4}, []int{2, 4}, []int{2, 4}},
@@ -270,6 +272,36 @@ func TestIntersect(t *testing.T) {
                // errors
                {"not array or slice", []string{"a"}, false},
                {[]string{"a"}, "not array or slice", false},
+
+               // []interface{} ∩ []interface{}
+               {[]interface{}{"a", "b", "c"}, []interface{}{"a", "b", "b"}, []interface{}{"a", "b"}},
+               {[]interface{}{1, 2, 3}, []interface{}{1, 2, 2}, []interface{}{1, 2}},
+               {[]interface{}{int8(1), int8(2), int8(3)}, []interface{}{int8(1), int8(2), int8(2)}, []interface{}{int8(1), int8(2)}},
+               {[]interface{}{int16(1), int16(2), int16(3)}, []interface{}{int16(1), int16(2), int16(2)}, []interface{}{int16(1), int16(2)}},
+               {[]interface{}{int32(1), int32(2), int32(3)}, []interface{}{int32(1), int32(2), int32(2)}, []interface{}{int32(1), int32(2)}},
+               {[]interface{}{int64(1), int64(2), int64(3)}, []interface{}{int64(1), int64(2), int64(2)}, []interface{}{int64(1), int64(2)}},
+               {[]interface{}{float32(1), float32(2), float32(3)}, []interface{}{float32(1), float32(2), float32(2)}, []interface{}{float32(1), float32(2)}},
+               {[]interface{}{float64(1), float64(2), float64(3)}, []interface{}{float64(1), float64(2), float64(2)}, []interface{}{float64(1), float64(2)}},
+
+               // []interface{} ∩ []T
+               {[]interface{}{"a", "b", "c"}, []string{"a", "b", "b"}, []interface{}{"a", "b"}},
+               {[]interface{}{1, 2, 3}, []int{1, 2, 2}, []interface{}{1, 2}},
+               {[]interface{}{int8(1), int8(2), int8(3)}, []int8{1, 2, 2}, []interface{}{int8(1), int8(2)}},
+               {[]interface{}{int16(1), int16(2), int16(3)}, []int16{1, 2, 2}, []interface{}{int16(1), int16(2)}},
+               {[]interface{}{int32(1), int32(2), int32(3)}, []int32{1, 2, 2}, []interface{}{int32(1), int32(2)}},
+               {[]interface{}{int64(1), int64(2), int64(3)}, []int64{1, 2, 2}, []interface{}{int64(1), int64(2)}},
+               {[]interface{}{float32(1), float32(2), float32(3)}, []float32{1, 2, 2}, []interface{}{float32(1), float32(2)}},
+               {[]interface{}{float64(1), float64(2), float64(3)}, []float64{1, 2, 2}, []interface{}{float64(1), float64(2)}},
+
+               // []T ∩ []interface{}
+               {[]string{"a", "b", "c"}, []interface{}{"a", "b", "b"}, []string{"a", "b"}},
+               {[]int{1, 2, 3}, []interface{}{1, 2, 2}, []int{1, 2}},
+               {[]int8{1, 2, 3}, []interface{}{int8(1), int8(2), int8(2)}, []int8{1, 2}},
+               {[]int16{1, 2, 3}, []interface{}{int16(1), int16(2), int16(2)}, []int16{1, 2}},
+               {[]int32{1, 2, 3}, []interface{}{int32(1), int32(2), int32(2)}, []int32{1, 2}},
+               {[]int64{1, 2, 3}, []interface{}{int64(1), int64(2), int64(2)}, []int64{1, 2}},
+               {[]float32{1, 2, 3}, []interface{}{float32(1), float32(2), float32(2)}, []float32{1, 2}},
+               {[]float64{1, 2, 3}, []interface{}{float64(1), float64(2), float64(2)}, []float64{1, 2}},
        } {
                errMsg := fmt.Sprintf("[%d] %v", i, test)
 
index f34494eb304ce80ad0f06e884e31d41431db11f5..e9528fb867bbc4016543d4e2cd02c61e6e113215 100644 (file)
@@ -124,16 +124,15 @@ func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error
                        iv := v.Int()
                        ivp = &iv
                        for i := 0; i < mv.Len(); i++ {
-                               if anInt := toInt(mv.Index(i)); anInt != -1 {
+                               if anInt, err := toInt(mv.Index(i)); err == nil {
                                        ima = append(ima, anInt)
                                }
-
                        }
                case reflect.String:
                        sv := v.String()
                        svp = &sv
                        for i := 0; i < mv.Len(); i++ {
-                               if aString := toString(mv.Index(i)); aString != "" {
+                               if aString, err := toString(mv.Index(i)); err == nil {
                                        sma = append(sma, aString)
                                }
                        }
@@ -382,26 +381,37 @@ func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op
        return rv.Interface(), nil
 }
 
+// toFloat returns the int value if possible.
+func toFloat(v reflect.Value) (float64, error) {
+       switch v.Kind() {
+       case reflect.Float32, reflect.Float64:
+               return v.Float(), nil
+       case reflect.Interface:
+               return toFloat(v.Elem())
+       }
+       return -1, errors.New("unable to convert value to float")
+}
+
 // toInt returns the int value if possible, -1 if not.
-func toInt(v reflect.Value) int64 {
+func toInt(v reflect.Value) (int64, error) {
        switch v.Kind() {
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-               return v.Int()
+               return v.Int(), nil
        case reflect.Interface:
                return toInt(v.Elem())
        }
-       return -1
+       return -1, errors.New("unable to convert value to int")
 }
 
 // toString returns the string value if possible, "" if not.
-func toString(v reflect.Value) string {
+func toString(v reflect.Value) (string, error) {
        switch v.Kind() {
        case reflect.String:
-               return v.String()
+               return v.String(), nil
        case reflect.Interface:
                return toString(v.Elem())
        }
-       return ""
+       return "", errors.New("unable to convert value to string")
 }
 
 var (