Validate comparison operator argument count
authorJoe Mooring <joe.mooring@veriphor.com>
Fri, 4 Feb 2022 11:01:54 +0000 (03:01 -0800)
committerBjørn Erik Pedersen <bjorn.erik.pedersen@gmail.com>
Sat, 5 Feb 2022 16:41:43 +0000 (17:41 +0100)
Fixes #9462

tpl/compare/compare.go
tpl/compare/compare_test.go

index 88b18f00cd4d6f011a753409574673427e6de4a7..8f3f536b58b69175ba12854a99c5463f15356c74 100644 (file)
@@ -95,10 +95,7 @@ func (n *Namespace) Eq(first interface{}, others ...interface{}) bool {
        if n.caseInsensitive {
                panic("caseInsensitive not implemented for Eq")
        }
-       if len(others) == 0 {
-               panic("missing arguments for comparison")
-       }
-
+       n.checkComparisonArgCount(1, others...)
        normalize := func(v interface{}) interface{} {
                if types.IsNil(v) {
                        return nil
@@ -145,6 +142,7 @@ func (n *Namespace) Eq(first interface{}, others ...interface{}) bool {
 
 // Ne returns the boolean truth of arg1 != arg2 && arg1 != arg3 && arg1 != arg4.
 func (n *Namespace) Ne(first interface{}, others ...interface{}) bool {
+       n.checkComparisonArgCount(1, others...)
        for _, other := range others {
                if n.Eq(first, other) {
                        return false
@@ -155,6 +153,7 @@ func (n *Namespace) Ne(first interface{}, others ...interface{}) bool {
 
 // Ge returns the boolean truth of arg1 >= arg2 && arg1 >= arg3 && arg1 >= arg4.
 func (n *Namespace) Ge(first interface{}, others ...interface{}) bool {
+       n.checkComparisonArgCount(1, others...)
        for _, other := range others {
                left, right := n.compareGet(first, other)
                if !(left >= right) {
@@ -166,6 +165,7 @@ func (n *Namespace) Ge(first interface{}, others ...interface{}) bool {
 
 // Gt returns the boolean truth of arg1 > arg2 && arg1 > arg3 && arg1 > arg4.
 func (n *Namespace) Gt(first interface{}, others ...interface{}) bool {
+       n.checkComparisonArgCount(1, others...)
        for _, other := range others {
                left, right := n.compareGet(first, other)
                if !(left > right) {
@@ -177,6 +177,7 @@ func (n *Namespace) Gt(first interface{}, others ...interface{}) bool {
 
 // Le returns the boolean truth of arg1 <= arg2 && arg1 <= arg3 && arg1 <= arg4.
 func (n *Namespace) Le(first interface{}, others ...interface{}) bool {
+       n.checkComparisonArgCount(1, others...)
        for _, other := range others {
                left, right := n.compareGet(first, other)
                if !(left <= right) {
@@ -188,6 +189,7 @@ func (n *Namespace) Le(first interface{}, others ...interface{}) bool {
 
 // Lt returns the boolean truth of arg1 < arg2 && arg1 < arg3 && arg1 < arg4.
 func (n *Namespace) Lt(first interface{}, others ...interface{}) bool {
+       n.checkComparisonArgCount(1, others...)
        for _, other := range others {
                left, right := n.compareGet(first, other)
                if !(left < right) {
@@ -197,6 +199,13 @@ func (n *Namespace) Lt(first interface{}, others ...interface{}) bool {
        return true
 }
 
+func (n *Namespace) checkComparisonArgCount(min int, others ...interface{}) bool {
+       if len(others) < min {
+               panic("missing arguments for comparison")
+       }
+       return true
+}
+
 // Conditional can be used as a ternary operator.
 // It returns a if condition, else b.
 func (n *Namespace) Conditional(condition bool, a, b interface{}) interface{} {
index 76fe2698a9f2a6a7fc63dc862d870d96cef2aa92..9ef32fd851df225753936ae778c77bddec5dc346 100644 (file)
@@ -440,3 +440,20 @@ func TestConditional(t *testing.T) {
        c.Assert(n.Conditional(true, a, b), qt.Equals, a)
        c.Assert(n.Conditional(false, a, b), qt.Equals, b)
 }
+
+// Issue 9462
+func TestComparisonArgCount(t *testing.T) {
+       t.Parallel()
+       c := qt.New(t)
+
+       ns := New(false)
+
+       panicMsg := "missing arguments for comparison"
+
+       c.Assert(func() { ns.Eq(1) }, qt.PanicMatches, panicMsg)
+       c.Assert(func() { ns.Ge(1) }, qt.PanicMatches, panicMsg)
+       c.Assert(func() { ns.Gt(1) }, qt.PanicMatches, panicMsg)
+       c.Assert(func() { ns.Le(1) }, qt.PanicMatches, panicMsg)
+       c.Assert(func() { ns.Lt(1) }, qt.PanicMatches, panicMsg)
+       c.Assert(func() { ns.Ne(1) }, qt.PanicMatches, panicMsg)
+}