bpf: generalize is_scalar_branch_taken() logic
authorAndrii Nakryiko <andrii@kernel.org>
Sun, 12 Nov 2023 01:05:58 +0000 (17:05 -0800)
committerAlexei Starovoitov <ast@kernel.org>
Wed, 15 Nov 2023 20:03:41 +0000 (12:03 -0800)
Generalize is_branch_taken logic for SCALAR_VALUE register to handle
cases when both registers are not constants. Previously supported
<range> vs <scalar> cases are a natural subset of more generic <range>
vs <range> set of cases.

Generalized logic relies on straightforward segment intersection checks.

Acked-by: Eduard Zingerman <eddyz87@gmail.com>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Acked-by: Shung-Hsi Yu <shung-hsi.yu@suse.com>
Link: https://lore.kernel.org/r/20231112010609.848406-3-andrii@kernel.org
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/verifier.c

index 39ce141c55d36b48d406a7dc11c280d48580adc5..f459ad99256e12bca1b84f33ce287c106b851d5c 100644 (file)
@@ -14261,82 +14261,99 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
                                  u8 opcode, bool is_jmp32)
 {
        struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
+       struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
        u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
        u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
        s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
        s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
-       u64 uval = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
-       s64 sval = is_jmp32 ? (s32)uval : (s64)uval;
+       u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
+       u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
+       s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
+       s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
 
        switch (opcode) {
        case BPF_JEQ:
-               if (tnum_is_const(t1))
-                       return !!tnum_equals_const(t1, uval);
-               else if (uval < umin1 || uval > umax1)
+               /* constants, umin/umax and smin/smax checks would be
+                * redundant in this case because they all should match
+                */
+               if (tnum_is_const(t1) && tnum_is_const(t2))
+                       return t1.value == t2.value;
+               /* non-overlapping ranges */
+               if (umin1 > umax2 || umax1 < umin2)
                        return 0;
-               else if (sval < smin1 || sval > smax1)
+               if (smin1 > smax2 || smax1 < smin2)
                        return 0;
                break;
        case BPF_JNE:
-               if (tnum_is_const(t1))
-                       return !tnum_equals_const(t1, uval);
-               else if (uval < umin1 || uval > umax1)
+               /* constants, umin/umax and smin/smax checks would be
+                * redundant in this case because they all should match
+                */
+               if (tnum_is_const(t1) && tnum_is_const(t2))
+                       return t1.value != t2.value;
+               /* non-overlapping ranges */
+               if (umin1 > umax2 || umax1 < umin2)
                        return 1;
-               else if (sval < smin1 || sval > smax1)
+               if (smin1 > smax2 || smax1 < smin2)
                        return 1;
                break;
        case BPF_JSET:
-               if ((~t1.mask & t1.value) & uval)
+               if (!is_reg_const(reg2, is_jmp32)) {
+                       swap(reg1, reg2);
+                       swap(t1, t2);
+               }
+               if (!is_reg_const(reg2, is_jmp32))
+                       return -1;
+               if ((~t1.mask & t1.value) & t2.value)
                        return 1;
-               if (!((t1.mask | t1.value) & uval))
+               if (!((t1.mask | t1.value) & t2.value))
                        return 0;
                break;
        case BPF_JGT:
-               if (umin1 > uval )
+               if (umin1 > umax2)
                        return 1;
-               else if (umax1 <= uval)
+               else if (umax1 <= umin2)
                        return 0;
                break;
        case BPF_JSGT:
-               if (smin1 > sval)
+               if (smin1 > smax2)
                        return 1;
-               else if (smax1 <= sval)
+               else if (smax1 <= smin2)
                        return 0;
                break;
        case BPF_JLT:
-               if (umax1 < uval)
+               if (umax1 < umin2)
                        return 1;
-               else if (umin1 >= uval)
+               else if (umin1 >= umax2)
                        return 0;
                break;
        case BPF_JSLT:
-               if (smax1 < sval)
+               if (smax1 < smin2)
                        return 1;
-               else if (smin1 >= sval)
+               else if (smin1 >= smax2)
                        return 0;
                break;
        case BPF_JGE:
-               if (umin1 >= uval)
+               if (umin1 >= umax2)
                        return 1;
-               else if (umax1 < uval)
+               else if (umax1 < umin2)
                        return 0;
                break;
        case BPF_JSGE:
-               if (smin1 >= sval)
+               if (smin1 >= smax2)
                        return 1;
-               else if (smax1 < sval)
+               else if (smax1 < smin2)
                        return 0;
                break;
        case BPF_JLE:
-               if (umax1 <= uval)
+               if (umax1 <= umin2)
                        return 1;
-               else if (umin1 > uval)
+               else if (umin1 > umax2)
                        return 0;
                break;
        case BPF_JSLE:
-               if (smax1 <= sval)
+               if (smax1 <= smin2)
                        return 1;
-               else if (smin1 > sval)
+               else if (smin1 > smax2)
                        return 0;
                break;
        }
@@ -14415,28 +14432,28 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
 static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
                           u8 opcode, bool is_jmp32)
 {
-       u64 val;
-
        if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
                return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
 
-       /* try to make sure reg2 is a constant SCALAR_VALUE */
-       if (!is_reg_const(reg2, is_jmp32)) {
-               opcode = flip_opcode(opcode);
-               swap(reg1, reg2);
-       }
-       /* for now we expect reg2 to be a constant to make any useful decisions */
-       if (!is_reg_const(reg2, is_jmp32))
-               return -1;
-       val = reg_const_value(reg2, is_jmp32);
+       if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) {
+               u64 val;
+
+               /* arrange that reg2 is a scalar, and reg1 is a pointer */
+               if (!is_reg_const(reg2, is_jmp32)) {
+                       opcode = flip_opcode(opcode);
+                       swap(reg1, reg2);
+               }
+               /* and ensure that reg2 is a constant */
+               if (!is_reg_const(reg2, is_jmp32))
+                       return -1;
 
-       if (__is_pointer_value(false, reg1)) {
                if (!reg_not_null(reg1))
                        return -1;
 
                /* If pointer is valid tests against zero will fail so we can
                 * use this to direct branch taken.
                 */
+               val = reg_const_value(reg2, is_jmp32);
                if (val != 0)
                        return -1;
 
@@ -14450,6 +14467,7 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
                }
        }
 
+       /* now deal with two scalars, but not necessarily constants */
        return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
 }