static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
                            struct sk_buff *skb)
 {
-       int ret;
-
-       /* strparser clones the skb before handing it to a upper layer,
-        * meaning we have the same data, but sk is NULL. We do want an
-        * sk pointer though when we run the BPF program. So we set it
-        * here and then NULL it to ensure we don't trigger a BUG_ON()
-        * in skb/sk operations later if kfree_skb is called with a
-        * valid skb->sk pointer and no destructor assigned.
-        */
-       skb->sk = psock->sk;
        bpf_compute_data_end_sk_skb(skb);
-       ret = bpf_prog_run_pin_on_cpu(prog, skb);
-       skb->sk = NULL;
-       return ret;
+       return bpf_prog_run_pin_on_cpu(prog, skb);
 }
 
 static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
        schedule_work(&psock_other->work);
 }
 
-static void sk_psock_tls_verdict_apply(struct sk_buff *skb, int verdict)
+static void sk_psock_tls_verdict_apply(struct sk_buff *skb, struct sock *sk, int verdict)
 {
        switch (verdict) {
        case __SK_REDIRECT:
+               skb_set_owner_r(skb, sk);
                sk_psock_skb_redirect(skb);
                break;
        case __SK_PASS:
        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
+               /* We skip full set_owner_r here because if we do a SK_PASS
+                * or SK_DROP we can skip skb memory accounting and use the
+                * TLS context.
+                */
+               skb->sk = psock->sk;
                tcp_skb_bpf_redirect_clear(skb);
                ret = sk_psock_bpf_run(psock, prog, skb);
                ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
+               skb->sk = NULL;
        }
-       sk_psock_tls_verdict_apply(skb, ret);
+       sk_psock_tls_verdict_apply(skb, psock->sk, ret);
        rcu_read_unlock();
        return ret;
 }
                kfree_skb(skb);
                goto out;
        }
+       skb_set_owner_r(skb, sk);
        prog = READ_ONCE(psock->progs.skb_verdict);
        if (likely(prog)) {
                tcp_skb_bpf_redirect_clear(skb);
 
        rcu_read_lock();
        prog = READ_ONCE(psock->progs.skb_parser);
-       if (likely(prog))
+       if (likely(prog)) {
+               skb->sk = psock->sk;
                ret = sk_psock_bpf_run(psock, prog, skb);
+               skb->sk = NULL;
+       }
        rcu_read_unlock();
        return ret;
 }