bpf, riscv: Implement PROBE_MEM32 pseudo instructions
author Puranjay Mohan <puranjay12@gmail.com>
Thu, 4 Apr 2024 11:42:02 +0000 (11:42 +0000)
committer Daniel Borkmann <daniel@iogearbox.net>
Thu, 4 Apr 2024 14:48:10 +0000 (16:48 +0200)
Add support for [LDX | STX | ST], PROBE_MEM32, [B | H | W | DW]
instructions. They are similar to PROBE_MEM instructions with the
following differences:

- PROBE_MEM32 supports store.
- PROBE_MEM32 relies on the verifier to clear upper 32-bit of the
  src/dst register
- PROBE_MEM32 adds 64-bit kern_vm_start address (which is stored in S7
  in the prologue). Due to bpf_arena constructions such S7 + reg +
  off16 access is guaranteed to be within arena virtual range, so no
  address check at run-time.
- S7 is a free callee-saved register, so it is used to store kern_vm_start
- PROBE_MEM32 allows STX and ST. If they fault the store is a nop. When
  LDX faults the destination register is zeroed.

To support these on riscv, we do tmp = S7 + src/dst reg and then use
tmp as the new src/dst register. This allows us to reuse most of the
code for normal [LDX | STX | ST].

Signed-off-by: Puranjay Mohan <puranjay12@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Tested-by: Björn Töpel <bjorn@rivosinc.com>
Tested-by: Pu Lehui <pulehui@huawei.com>
Reviewed-by: Pu Lehui <pulehui@huawei.com>
Acked-by: Björn Töpel <bjorn@kernel.org>
Link: https://lore.kernel.org/bpf/20240404114203.105970-2-puranjay12@gmail.com
arch/riscv/net/bpf_jit.h
arch/riscv/net/bpf_jit_comp64.c
arch/riscv/net/bpf_jit_core.c

index f4b6b3b9edda3668c2075e0c42f8a8da2d92cac5..8a47da08dd9c1a506377f50092c6f353c5e3bca2 100644 (file)
@@ -81,6 +81,7 @@ struct rv_jit_context {
        int nexentries;
        unsigned long flags;
        int stack_size;
+       u64 arena_vm_start;
 };
 
 /* Convert from ninsns to bytes. */
index 1adf2f39ce59cbb691b7f89ae9fc7a5127642ca4..a4c8e1e6c1e233aa71340530b279285c9a09b6ca 100644 (file)
@@ -18,6 +18,7 @@
 
 #define RV_REG_TCC RV_REG_A6
 #define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
+#define RV_REG_ARENA RV_REG_S7 /* For storing arena_vm_start */
 
 static const int regmap[] = {
        [BPF_REG_0] =   RV_REG_A5,
@@ -255,6 +256,10 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
                emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
                store_offset -= 8;
        }
+       if (ctx->arena_vm_start) {
+               emit_ld(RV_REG_ARENA, store_offset, RV_REG_SP, ctx);
+               store_offset -= 8;
+       }
 
        emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
        /* Set return value. */
@@ -548,6 +553,7 @@ static void emit_atomic(u8 rd, u8 rs, s16 off, s32 imm, bool is64,
 
 #define BPF_FIXUP_OFFSET_MASK   GENMASK(26, 0)
 #define BPF_FIXUP_REG_MASK      GENMASK(31, 27)
+#define REG_DONT_CLEAR_MARKER  0       /* RV_REG_ZERO unused in pt_regmap */
 
 bool ex_handler_bpf(const struct exception_table_entry *ex,
                    struct pt_regs *regs)
@@ -555,7 +561,8 @@ bool ex_handler_bpf(const struct exception_table_entry *ex,
        off_t offset = FIELD_GET(BPF_FIXUP_OFFSET_MASK, ex->fixup);
        int regs_offset = FIELD_GET(BPF_FIXUP_REG_MASK, ex->fixup);
 
-       *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
+       if (regs_offset != REG_DONT_CLEAR_MARKER)
+               *(unsigned long *)((void *)regs + pt_regmap[regs_offset]) = 0;
        regs->epc = (unsigned long)&ex->fixup - offset;
 
        return true;
@@ -572,7 +579,8 @@ static int add_exception_handler(const struct bpf_insn *insn,
        off_t fixup_offset;
 
        if (!ctx->insns || !ctx->ro_insns || !ctx->prog->aux->extable ||
-           (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX))
+           (BPF_MODE(insn->code) != BPF_PROBE_MEM && BPF_MODE(insn->code) != BPF_PROBE_MEMSX &&
+            BPF_MODE(insn->code) != BPF_PROBE_MEM32))
                return 0;
 
        if (WARN_ON_ONCE(ctx->nexentries >= ctx->prog->aux->num_exentries))
@@ -1539,6 +1547,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
        case BPF_LDX | BPF_PROBE_MEMSX | BPF_B:
        case BPF_LDX | BPF_PROBE_MEMSX | BPF_H:
        case BPF_LDX | BPF_PROBE_MEMSX | BPF_W:
+       /* LDX | PROBE_MEM32: dst = *(unsigned size *)(src + RV_REG_ARENA + off) */
+       case BPF_LDX | BPF_PROBE_MEM32 | BPF_B:
+       case BPF_LDX | BPF_PROBE_MEM32 | BPF_H:
+       case BPF_LDX | BPF_PROBE_MEM32 | BPF_W:
+       case BPF_LDX | BPF_PROBE_MEM32 | BPF_DW:
        {
                int insn_len, insns_start;
                bool sign_ext;
@@ -1546,6 +1559,11 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
                sign_ext = BPF_MODE(insn->code) == BPF_MEMSX ||
                           BPF_MODE(insn->code) == BPF_PROBE_MEMSX;
 
+               if (BPF_MODE(insn->code) == BPF_PROBE_MEM32) {
+                       emit_add(RV_REG_T2, rs, RV_REG_ARENA, ctx);
+                       rs = RV_REG_T2;
+               }
+
                switch (BPF_SIZE(code)) {
                case BPF_B:
                        if (is_12b_int(off)) {
@@ -1682,6 +1700,86 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
                emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
                break;
 
+       case BPF_ST | BPF_PROBE_MEM32 | BPF_B:
+       case BPF_ST | BPF_PROBE_MEM32 | BPF_H:
+       case BPF_ST | BPF_PROBE_MEM32 | BPF_W:
+       case BPF_ST | BPF_PROBE_MEM32 | BPF_DW:
+       {
+               int insn_len, insns_start;
+
+               emit_add(RV_REG_T3, rd, RV_REG_ARENA, ctx);
+               rd = RV_REG_T3;
+
+               /* Load imm to a register then store it */
+               emit_imm(RV_REG_T1, imm, ctx);
+
+               switch (BPF_SIZE(code)) {
+               case BPF_B:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit(rv_sb(rd, off, RV_REG_T1), ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T2, off, ctx);
+                       emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit(rv_sb(RV_REG_T2, 0, RV_REG_T1), ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_H:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit(rv_sh(rd, off, RV_REG_T1), ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T2, off, ctx);
+                       emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit(rv_sh(RV_REG_T2, 0, RV_REG_T1), ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_W:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit_sw(rd, off, RV_REG_T1, ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T2, off, ctx);
+                       emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit_sw(RV_REG_T2, 0, RV_REG_T1, ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_DW:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit_sd(rd, off, RV_REG_T1, ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T2, off, ctx);
+                       emit_add(RV_REG_T2, RV_REG_T2, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit_sd(RV_REG_T2, 0, RV_REG_T1, ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               }
+
+               ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
+                                           insn_len);
+               if (ret)
+                       return ret;
+
+               break;
+       }
+
        /* STX: *(size *)(dst + off) = src */
        case BPF_STX | BPF_MEM | BPF_B:
                if (is_12b_int(off)) {
@@ -1728,6 +1826,84 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
                emit_atomic(rd, rs, off, imm,
                            BPF_SIZE(code) == BPF_DW, ctx);
                break;
+
+       case BPF_STX | BPF_PROBE_MEM32 | BPF_B:
+       case BPF_STX | BPF_PROBE_MEM32 | BPF_H:
+       case BPF_STX | BPF_PROBE_MEM32 | BPF_W:
+       case BPF_STX | BPF_PROBE_MEM32 | BPF_DW:
+       {
+               int insn_len, insns_start;
+
+               emit_add(RV_REG_T2, rd, RV_REG_ARENA, ctx);
+               rd = RV_REG_T2;
+
+               switch (BPF_SIZE(code)) {
+               case BPF_B:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit(rv_sb(rd, off, rs), ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T1, off, ctx);
+                       emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit(rv_sb(RV_REG_T1, 0, rs), ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_H:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit(rv_sh(rd, off, rs), ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T1, off, ctx);
+                       emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit(rv_sh(RV_REG_T1, 0, rs), ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_W:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit_sw(rd, off, rs, ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T1, off, ctx);
+                       emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit_sw(RV_REG_T1, 0, rs, ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               case BPF_DW:
+                       if (is_12b_int(off)) {
+                               insns_start = ctx->ninsns;
+                               emit_sd(rd, off, rs, ctx);
+                               insn_len = ctx->ninsns - insns_start;
+                               break;
+                       }
+
+                       emit_imm(RV_REG_T1, off, ctx);
+                       emit_add(RV_REG_T1, RV_REG_T1, rd, ctx);
+                       insns_start = ctx->ninsns;
+                       emit_sd(RV_REG_T1, 0, rs, ctx);
+                       insn_len = ctx->ninsns - insns_start;
+                       break;
+               }
+
+               ret = add_exception_handler(insn, ctx, REG_DONT_CLEAR_MARKER,
+                                           insn_len);
+               if (ret)
+                       return ret;
+
+               break;
+       }
+
        default:
                pr_err("bpf-jit: unknown opcode %02x\n", code);
                return -EINVAL;
@@ -1759,6 +1935,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
                stack_adjust += 8;
        if (seen_reg(RV_REG_S6, ctx))
                stack_adjust += 8;
+       if (ctx->arena_vm_start)
+               stack_adjust += 8;
 
        stack_adjust = round_up(stack_adjust, 16);
        stack_adjust += bpf_stack_adjust;
@@ -1810,6 +1988,10 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
                emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
                store_offset -= 8;
        }
+       if (ctx->arena_vm_start) {
+               emit_sd(RV_REG_SP, store_offset, RV_REG_ARENA, ctx);
+               store_offset -= 8;
+       }
 
        emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
 
@@ -1823,6 +2005,9 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx, bool is_subprog)
                emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
 
        ctx->stack_size = stack_adjust;
+
+       if (ctx->arena_vm_start)
+               emit_imm(RV_REG_ARENA, ctx->arena_vm_start, ctx);
 }
 
 void bpf_jit_build_epilogue(struct rv_jit_context *ctx)
index 6b3acac30c06199480a8397b2f59886184e3d1c3..9ab739b9f9a2042c00ce56e29bf8f49e4c37c52a 100644 (file)
@@ -80,6 +80,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                goto skip_init_ctx;
        }
 
+       ctx->arena_vm_start = bpf_arena_get_kern_vm_start(prog->aux->arena);
        ctx->prog = prog;
        ctx->offset = kcalloc(prog->len, sizeof(int), GFP_KERNEL);
        if (!ctx->offset) {