From 488776b6498d1377718133d42daa87ce1236215d Mon Sep 17 00:00:00 2001 From: =?utf8?q?Bj=C3=B8rn=20Erik=20Pedersen?= Date: Tue, 6 Nov 2018 13:04:11 +0100 Subject: [PATCH] tpl/collections: Add collections.SymDiff Fixes #5410 --- tpl/collections/complement.go | 20 ++------ tpl/collections/init.go | 7 +++ tpl/collections/reflect_helpers.go | 44 +++++++++++++++- tpl/collections/symdiff.go | 71 ++++++++++++++++++++++++++ tpl/collections/symdiff_test.go | 80 ++++++++++++++++++++++++++++++ 5 files changed, 204 insertions(+), 18 deletions(-) create mode 100644 tpl/collections/symdiff.go create mode 100644 tpl/collections/symdiff_test.go diff --git a/tpl/collections/complement.go b/tpl/collections/complement.go index d975faba..a5633f8b 100644 --- a/tpl/collections/complement.go +++ b/tpl/collections/complement.go @@ -33,23 +33,9 @@ func (ns *Namespace) Complement(seqs ...interface{}) (interface{}, error) { universe := seqs[len(seqs)-1] as := seqs[:len(seqs)-1] - aset := make(map[interface{}]struct{}) - - for _, av := range as { - v := reflect.ValueOf(av) - switch v.Kind() { - case reflect.Array, reflect.Slice: - for i := 0; i < v.Len(); i++ { - ev, _ := indirectInterface(v.Index(i)) - if !ev.Type().Comparable() { - return nil, errors.New("elements in complement must be comparable") - } - - aset[normalize(ev)] = struct{}{} - } - default: - return nil, fmt.Errorf("arguments to complement must be slices or arrays") - } + aset, err := collectIdentities(as...) + if err != nil { + return nil, err } v := reflect.ValueOf(universe) diff --git a/tpl/collections/init.go b/tpl/collections/init.go index 569932c0..8dbef75c 100644 --- a/tpl/collections/init.go +++ b/tpl/collections/init.go @@ -46,6 +46,13 @@ func init() { }, ) + ns.AddMethodMapping(ctx.SymDiff, + []string{"symdiff"}, + [][2]string{ + {`{{ slice 1 2 3 | symdiff (slice 3 4) }}`, `[1 2 4]`}, + }, + ) + ns.AddMethodMapping(ctx.Delimit, []string{"delimit"}, [][2]string{ diff --git a/tpl/collections/reflect_helpers.go b/tpl/collections/reflect_helpers.go index 07439647..85aa389c 100644 --- a/tpl/collections/reflect_helpers.go +++ b/tpl/collections/reflect_helpers.go @@ -14,10 +14,11 @@ package collections import ( - "errors" "fmt" "reflect" "time" + + "github.com/pkg/errors" ) var ( @@ -59,6 +60,47 @@ func normalize(v reflect.Value) interface{} { return v.Interface() } +// collects identities from the slices in seqs into a set. Numeric values are normalized, +// pointers unwrapped. +func collectIdentities(seqs ...interface{}) (map[interface{}]bool, error) { + seen := make(map[interface{}]bool) + for _, seq := range seqs { + v := reflect.ValueOf(seq) + switch v.Kind() { + case reflect.Array, reflect.Slice: + for i := 0; i < v.Len(); i++ { + ev, _ := indirectInterface(v.Index(i)) + if !ev.Type().Comparable() { + return nil, errors.New("elements must be comparable") + } + + seen[normalize(ev)] = true + } + default: + return nil, fmt.Errorf("arguments must be slices or arrays") + } + } + + return seen, nil +} + +// We have some different numeric and string types that we try to behave like +// they were the same. +func convertValue(v reflect.Value, to reflect.Type) (reflect.Value, error) { + if v.Type().AssignableTo(to) { + return v, nil + } + switch kind := to.Kind(); { + case kind == reflect.String: + s, err := toString(v) + return reflect.ValueOf(s), err + case isNumber(kind): + return convertNumber(v, kind) + default: + return reflect.Value{}, errors.Errorf("%s is not assignable to %s", v.Type(), to) + } +} + // There are potential overflows in this function, but the downconversion of // int64 etc. into int8 etc. is coming from the synthetic unit tests for Union etc. // TODO(bep) We should consider normalizing the slices to int64 etc. diff --git a/tpl/collections/symdiff.go b/tpl/collections/symdiff.go new file mode 100644 index 00000000..1c58257e --- /dev/null +++ b/tpl/collections/symdiff.go @@ -0,0 +1,71 @@ +// Copyright 2018 The Hugo Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package collections + +import ( + "fmt" + "reflect" + + "github.com/pkg/errors" +) + +// SymDiff returns the symmetric difference of s1 and s2. +// Arguments must be either a slice or an array of comparable types. +func (ns *Namespace) SymDiff(s2, s1 interface{}) (interface{}, error) { + ids1, err := collectIdentities(s1) + if err != nil { + return nil, err + } + ids2, err := collectIdentities(s2) + if err != nil { + return nil, err + } + + var slice reflect.Value + var sliceElemType reflect.Type + + for i, s := range []interface{}{s1, s2} { + v := reflect.ValueOf(s) + + switch v.Kind() { + case reflect.Array, reflect.Slice: + if i == 0 { + sliceType := v.Type() + sliceElemType = sliceType.Elem() + slice = reflect.MakeSlice(sliceType, 0, 0) + } + + for i := 0; i < v.Len(); i++ { + ev, _ := indirectInterface(v.Index(i)) + if !ev.Type().Comparable() { + return nil, errors.New("symdiff: elements must be comparable") + } + key := normalize(ev) + // Append if the key is not in their intersection. + if ids1[key] != ids2[key] { + v, err := convertValue(ev, sliceElemType) + if err != nil { + return nil, errors.WithMessage(err, "symdiff: failed to convert value") + } + slice = reflect.Append(slice, v) + } + } + default: + return nil, fmt.Errorf("arguments to symdiff must be slices or arrays") + } + } + + return slice.Interface(), nil + +} diff --git a/tpl/collections/symdiff_test.go b/tpl/collections/symdiff_test.go new file mode 100644 index 00000000..d4499973 --- /dev/null +++ b/tpl/collections/symdiff_test.go @@ -0,0 +1,80 @@ +// Copyright 2018 The Hugo Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package collections + +import ( + "fmt" + "reflect" + "testing" + + "github.com/gohugoio/hugo/deps" + + "github.com/stretchr/testify/require" +) + +func TestSymDiff(t *testing.T) { + t.Parallel() + + assert := require.New(t) + + ns := New(&deps.Deps{}) + + s1 := []TstX{TstX{A: "a"}, TstX{A: "b"}} + s2 := []TstX{TstX{A: "a"}, TstX{A: "e"}} + + xa, xd := &TstX{A: "a"}, &TstX{A: "d"} + + sp1 := []*TstX{xa, &TstX{A: "b"}, xd, &TstX{A: "e"}} + sp2 := []*TstX{&TstX{A: "b"}, &TstX{A: "e"}} + + for i, test := range []struct { + s1 interface{} + s2 interface{} + expected interface{} + }{ + {[]string{"a", "x", "b", "c"}, []string{"a", "b", "y", "c"}, []string{"x", "y"}}, + {[]string{"a", "b", "c"}, []string{"a", "b", "c"}, []string{}}, + {[]interface{}{"a", "b", nil}, []interface{}{"a"}, []interface{}{"b", nil}}, + {[]int{1, 2, 3}, []int{3, 4}, []int{1, 2, 4}}, + {[]int{1, 2, 3}, []int64{3, 4}, []int{1, 2, 4}}, + {s1, s2, []TstX{TstX{A: "b"}, TstX{A: "e"}}}, + {sp1, sp2, []*TstX{xa, xd}}, + + // Errors + {"error", "error", false}, + {[]int{1, 2, 3}, []string{"3", "4"}, false}, + } { + + errMsg := fmt.Sprintf("[%d]", i) + + result, err := ns.SymDiff(test.s2, test.s1) + + if b, ok := test.expected.(bool); ok && !b { + require.Error(t, err, errMsg) + continue + } + + require.NoError(t, err, errMsg) + + if !reflect.DeepEqual(test.expected, result) { + t.Fatalf("%s got\n%T: %v\nexpected\n%T: %v", errMsg, result, result, test.expected, test.expected) + } + } + + _, err := ns.Complement() + assert.Error(err) + _, err = ns.Complement([]string{"a", "b"}) + assert.Error(err) + +} -- 2.30.2