#include <linux/filter.h>
 #include <linux/if_vlan.h>
 #include <linux/bpf.h>
-
+#include <asm/extable.h>
 #include <asm/set_memory.h>
 #include <asm/nospec-branch.h>
 
        [AUX_REG] = 3,    /* R11 temp register */
 };
 
+static const int reg2pt_regs[] = {
+       [BPF_REG_0] = offsetof(struct pt_regs, ax),
+       [BPF_REG_1] = offsetof(struct pt_regs, di),
+       [BPF_REG_2] = offsetof(struct pt_regs, si),
+       [BPF_REG_3] = offsetof(struct pt_regs, dx),
+       [BPF_REG_4] = offsetof(struct pt_regs, cx),
+       [BPF_REG_5] = offsetof(struct pt_regs, r8),
+       [BPF_REG_6] = offsetof(struct pt_regs, bx),
+       [BPF_REG_7] = offsetof(struct pt_regs, r13),
+       [BPF_REG_8] = offsetof(struct pt_regs, r14),
+       [BPF_REG_9] = offsetof(struct pt_regs, r15),
+};
+
 /*
  * is_ereg() == true if BPF register 'reg' maps to x86-64 r8..r15
  * which need extra byte of encoding.
        *pprog = prog;
 }
 
+
+static bool ex_handler_bpf(const struct exception_table_entry *x,
+                          struct pt_regs *regs, int trapnr,
+                          unsigned long error_code, unsigned long fault_addr)
+{
+       u32 reg = x->fixup >> 8;
+
+       /* jump over faulting load and clear dest register */
+       *(unsigned long *)((void *)regs + reg) = 0;
+       regs->ip += x->fixup & 0xff;
+       return true;
+}
+
 static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image,
                  int oldproglen, struct jit_context *ctx)
 {
        int insn_cnt = bpf_prog->len;
        bool seen_exit = false;
        u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY];
-       int i, cnt = 0;
+       int i, cnt = 0, excnt = 0;
        int proglen = 0;
        u8 *prog = temp;
 
 
                        /* LDX: dst_reg = *(u8*)(src_reg + off) */
                case BPF_LDX | BPF_MEM | BPF_B:
+               case BPF_LDX | BPF_PROBE_MEM | BPF_B:
                        /* Emit 'movzx rax, byte ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB6);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_H:
+               case BPF_LDX | BPF_PROBE_MEM | BPF_H:
                        /* Emit 'movzx rax, word ptr [rax + off]' */
                        EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xB7);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_W:
+               case BPF_LDX | BPF_PROBE_MEM | BPF_W:
                        /* Emit 'mov eax, dword ptr [rax+0x14]' */
                        if (is_ereg(dst_reg) || is_ereg(src_reg))
                                EMIT2(add_2mod(0x40, src_reg, dst_reg), 0x8B);
                                EMIT1(0x8B);
                        goto ldx;
                case BPF_LDX | BPF_MEM | BPF_DW:
+               case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
                        /* Emit 'mov rax, qword ptr [rax+0x14]' */
                        EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x8B);
 ldx:                   /*
                        else
                                EMIT1_off32(add_2reg(0x80, src_reg, dst_reg),
                                            insn->off);
+                       if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
+                               struct exception_table_entry *ex;
+                               u8 *_insn = image + proglen;
+                               s64 delta;
+
+                               if (!bpf_prog->aux->extable)
+                                       break;
+
+                               if (excnt >= bpf_prog->aux->num_exentries) {
+                                       pr_err("ex gen bug\n");
+                                       return -EFAULT;
+                               }
+                               ex = &bpf_prog->aux->extable[excnt++];
+
+                               delta = _insn - (u8 *)&ex->insn;
+                               if (!is_simm32(delta)) {
+                                       pr_err("extable->insn doesn't fit into 32-bit\n");
+                                       return -EFAULT;
+                               }
+                               ex->insn = delta;
+
+                               delta = (u8 *)ex_handler_bpf - (u8 *)&ex->handler;
+                               if (!is_simm32(delta)) {
+                                       pr_err("extable->handler doesn't fit into 32-bit\n");
+                                       return -EFAULT;
+                               }
+                               ex->handler = delta;
+
+                               if (dst_reg > BPF_REG_9) {
+                                       pr_err("verifier error\n");
+                                       return -EFAULT;
+                               }
+                               /*
+                                * Compute size of x86 insn and its target dest x86 register.
+                                * ex_handler_bpf() will use lower 8 bits to adjust
+                                * pt_regs->ip to jump over this x86 instruction
+                                * and upper bits to figure out which pt_regs to zero out.
+                                * End result: x86 insn "mov rbx, qword ptr [rax+0x14]"
+                                * of 4 bytes will be ignored and rbx will be zero inited.
+                                */
+                               ex->fixup = (prog - temp) | (reg2pt_regs[dst_reg] << 8);
+                       }
                        break;
 
                        /* STX XADD: lock *(u32*)(dst_reg + off) += src_reg */
                addrs[i] = proglen;
                prog = temp;
        }
+
+       if (image && excnt != bpf_prog->aux->num_exentries) {
+               pr_err("extable is not populated\n");
+               return -EFAULT;
+       }
        return proglen;
 }
 
                        break;
                }
                if (proglen == oldproglen) {
-                       header = bpf_jit_binary_alloc(proglen, &image,
-                                                     1, jit_fill_hole);
+                       /*
+                        * The number of entries in extable is the number of BPF_LDX
+                        * insns that access kernel memory via "pointer to BTF type".
+                        * The verifier changed their opcode from LDX|MEM|size
+                        * to LDX|PROBE_MEM|size to make JITing easier.
+                        */
+                       u32 align = __alignof__(struct exception_table_entry);
+                       u32 extable_size = prog->aux->num_exentries *
+                               sizeof(struct exception_table_entry);
+
+                       /* allocate module memory for x86 insns and extable */
+                       header = bpf_jit_binary_alloc(roundup(proglen, align) + extable_size,
+                                                     &image, align, jit_fill_hole);
                        if (!header) {
                                prog = orig_prog;
                                goto out_addrs;
                        }
+                       prog->aux->extable = (void *) image + roundup(proglen, align);
                }
                oldproglen = proglen;
                cond_resched();