tpl/collections: Make Pages etc. work with the in func
authorBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Thu, 18 Apr 2019 15:06:54 +0000 (17:06 +0200)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Thu, 18 Apr 2019 21:42:01 +0000 (23:42 +0200)
Fixes #5875

tpl/collections/collections.go
tpl/collections/collections_test.go

index 15f17ff6df202b85a9e17bf9bbf31842f8f2f8db..69d0f7042232f45c23ba8bacfad7555f997af72e 100644 (file)
@@ -250,27 +250,26 @@ func (ns *Namespace) In(l interface{}, v interface{}) bool {
        lv := reflect.ValueOf(l)
        vv := reflect.ValueOf(v)
 
+       if !vv.Type().Comparable() {
+               // TODO(bep) consider adding error to the signature.
+               return false
+       }
+
+       // Normalize numeric types to float64 etc.
+       vvk := normalize(vv)
+
        switch lv.Kind() {
        case reflect.Array, reflect.Slice:
                for i := 0; i < lv.Len(); i++ {
-                       lvv := lv.Index(i)
-                       lvv, isNil := indirect(lvv)
-                       if isNil {
+                       lvv, isNil := indirectInterface(lv.Index(i))
+                       if isNil || !lvv.Type().Comparable() {
                                continue
                        }
-                       switch lvv.Kind() {
-                       case reflect.String:
-                               if vv.Type() == lvv.Type() && vv.String() == lvv.String() {
-                                       return true
-                               }
-                       default:
-                               if isNumber(vv.Kind()) && isNumber(lvv.Kind()) {
-                                       f1, err1 := numberToFloat(vv)
-                                       f2, err2 := numberToFloat(lvv)
-                                       if err1 == nil && err2 == nil && f1 == f2 {
-                                               return true
-                                       }
-                               }
+
+                       lvvk := normalize(lvv)
+
+                       if lvvk == vvk {
+                               return true
                        }
                }
        case reflect.String:
index 741dd074dd0c9fe3e3ecc63725e03f67653ec27c..c87490b2c099a0e71c9e2ace177cfcf9aad2cb4e 100644 (file)
@@ -276,6 +276,7 @@ func TestFirst(t *testing.T) {
 
 func TestIn(t *testing.T) {
        t.Parallel()
+       assert := require.New(t)
 
        ns := New(&deps.Deps{})
 
@@ -302,12 +303,18 @@ func TestIn(t *testing.T) {
                {"this substring should be found", "substring", true},
                {"this substring should not be found", "subseastring", false},
                {nil, "foo", false},
+               // Pointers
+               {pagesPtr{p1, p2, p3, p2}, p2, true},
+               {pagesPtr{p1, p2, p3, p2}, p4, false},
+               // Structs
+               {pagesVals{p3v, p2v, p3v, p2v}, p2v, true},
+               {pagesVals{p3v, p2v, p3v, p2v}, p4v, false},
        } {
 
                errMsg := fmt.Sprintf("[%d] %v", i, test)
 
                result := ns.In(test.l1, test.l2)
-               assert.Equal(t, test.expect, result, errMsg)
+               assert.Equal(test.expect, result, errMsg)
        }
 }