#include "test_siphash.h"
 #include "test_tcp_custom_syncookie.h"
 
+#define MAX_PACKET_OFF 0xffff
+
 /* Hash is calculated for each client and split into ISN and TS.
  *
  *       MSB                                   LSB
 
 struct tcp_syncookie {
        struct __sk_buff *skb;
+       void *data;
        void *data_end;
        struct ethhdr *eth;
        struct iphdr *ipv4;
        struct ipv6hdr *ipv6;
        struct tcphdr *tcp;
-       union {
-               char *ptr;
-               __be32 *ptr32;
-       };
+       __be32 *ptr32;
        struct bpf_tcp_req_attrs attrs;
+       u32 off;
        u32 cookie;
        u64 first;
 };
 
 static int tcp_load_headers(struct tcp_syncookie *ctx)
 {
+       ctx->data = (void *)(long)ctx->skb->data;
        ctx->data_end = (void *)(long)ctx->skb->data_end;
        ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
 
        if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
                goto err;
 
+       ctx->data = (void *)(long)ctx->skb->data;
        ctx->data_end = (void *)(long)ctx->skb->data_end;
        ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
        if (ctx->ipv4) {
        return -1;
 }
 
-static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
+static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
 {
-       char opcode, opsize;
+       __u64 off = ctx->off;
+       __u8 *data;
 
-       if (ctx->ptr + 1 > ctx->data_end)
-               goto stop;
+       /* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
+       if (off > MAX_PACKET_OFF - sz)
+               return NULL;
+
+       data = ctx->data + off;
+       barrier_var(data);
+       if (data + sz >= ctx->data_end)
+               return NULL;
 
-       opcode = *ctx->ptr++;
+       ctx->off += sz;
+       return data;
+}
 
-       if (opcode == TCPOPT_EOL)
+static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
+{
+       __u8 *opcode, *opsize, *wscale;
+       __u32 *tsval, *tsecr;
+       __u16 *mss;
+       __u32 off;
+
+       off = ctx->off;
+       opcode = next(ctx, 1);
+       if (!opcode)
                goto stop;
 
-       if (opcode == TCPOPT_NOP)
+       if (*opcode == TCPOPT_EOL)
+               goto stop;
+
+       if (*opcode == TCPOPT_NOP)
                goto next;
 
-       if (ctx->ptr + 1 > ctx->data_end)
+       opsize = next(ctx, 1);
+       if (!opsize)
                goto stop;
 
-       opsize = *ctx->ptr++;
-
-       if (opsize < 2)
+       if (*opsize < 2)
                goto stop;
 
-       switch (opcode) {
+       switch (*opcode) {
        case TCPOPT_MSS:
-               if (opsize == TCPOLEN_MSS && ctx->tcp->syn &&
-                   ctx->ptr + (TCPOLEN_MSS - 2) < ctx->data_end)
-                       ctx->attrs.mss = get_unaligned_be16(ctx->ptr);
+               mss = next(ctx, 2);
+               if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
+                       ctx->attrs.mss = get_unaligned_be16(mss);
                break;
        case TCPOPT_WINDOW:
-               if (opsize == TCPOLEN_WINDOW && ctx->tcp->syn &&
-                   ctx->ptr + (TCPOLEN_WINDOW - 2) < ctx->data_end) {
+               wscale = next(ctx, 1);
+               if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
                        ctx->attrs.wscale_ok = 1;
-                       ctx->attrs.snd_wscale = *ctx->ptr;
+                       ctx->attrs.snd_wscale = *wscale;
                }
                break;
        case TCPOPT_TIMESTAMP:
-               if (opsize == TCPOLEN_TIMESTAMP &&
-                   ctx->ptr + (TCPOLEN_TIMESTAMP - 2) < ctx->data_end) {
-                       ctx->attrs.rcv_tsval = get_unaligned_be32(ctx->ptr);
-                       ctx->attrs.rcv_tsecr = get_unaligned_be32(ctx->ptr + 4);
+               tsval = next(ctx, 4);
+               tsecr = next(ctx, 4);
+               if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
+                       ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
+                       ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);
 
                        if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
                                ctx->attrs.tstamp_ok = 0;
                }
                break;
        case TCPOPT_SACK_PERM:
-               if (opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn &&
-                   ctx->ptr + (TCPOLEN_SACK_PERM - 2) < ctx->data_end)
+               if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
                        ctx->attrs.sack_ok = 1;
                break;
        }
 
-       ctx->ptr += opsize - 2;
+       ctx->off = off + *opsize;
 next:
        return 0;
 stop:
 
 static void tcp_parse_options(struct tcp_syncookie *ctx)
 {
-       ctx->ptr = (char *)(ctx->tcp + 1);
+       ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data,
 
        bpf_loop(40, tcp_parse_option, ctx, 0);
 }