Currently the bpf_sk_assign helper in tc BPF context refuses SO_REUSEPORT
sockets. This means we can't use the helper to steer traffic to Envoy,
which configures SO_REUSEPORT on its sockets. In turn, we're blocked
from removing TPROXY from our setup.
The reason that bpf_sk_assign refuses such sockets is that the
bpf_sk_lookup helpers don't execute SK_REUSEPORT programs. Instead,
one of the reuseport sockets is selected by hash. This could cause
dispatch to the "wrong" socket:
    sk = bpf_sk_lookup_tcp(...) // select SO_REUSEPORT by hash
    bpf_sk_assign(skb, sk) // SK_REUSEPORT wasn't executed
Fixing this isn't as simple as invoking SK_REUSEPORT from the lookup
helpers unfortunately. In the tc context, L2 headers are at the start
of the skb, while SK_REUSEPORT expects L3 headers instead.
Instead, we execute the SK_REUSEPORT program when the assigned socket
is pulled out of the skb, further up the stack. This creates some
trickiness with regards to refcounting as bpf_sk_assign will put both
refcounted and RCU freed sockets in skb->sk. reuseport sockets are RCU
freed. We can infer that the sk_assigned socket is RCU freed if the
reuseport lookup succeeds, but convincing yourself of this fact isn't
straight forward. Therefore we defensively check refcounting on the
sk_assign sock even though it's probably not required in practice.
Fixes: 8e368dc72e86 ("bpf: Fix use of sk->sk_reuseport from sk_assign")
Fixes: cf7fbe660f2d ("bpf: Add socket assign support")
Co-developed-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Cc: Joe Stringer <joe@cilium.io>
Link: https://lore.kernel.org/bpf/CACAyw98+qycmpQzKupquhkxbvWK4OFyDuuLMBNROnfWMZxUWeA@mail.gmail.com/
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Signed-off-by: Lorenz Bauer <lmb@isovalent.com>
Link: https://lore.kernel.org/r/20230720-so-reuseport-v6-7-7021b683cdae@isovalent.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
                                     daddr, hnum, dif, sdif);
 }
 
+static inline
+struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+                             const struct in6_addr *saddr, const __be16 sport,
+                             const struct in6_addr *daddr, const __be16 dport,
+                             bool *refcounted, inet6_ehashfn_t *ehashfn)
+{
+       struct sock *sk, *reuse_sk;
+       bool prefetched;
+
+       sk = skb_steal_sock(skb, refcounted, &prefetched);
+       if (!sk)
+               return NULL;
+
+       if (!prefetched)
+               return sk;
+
+       if (sk->sk_protocol == IPPROTO_TCP) {
+               if (sk->sk_state != TCP_LISTEN)
+                       return sk;
+       } else if (sk->sk_protocol == IPPROTO_UDP) {
+               if (sk->sk_state != TCP_CLOSE)
+                       return sk;
+       } else {
+               return sk;
+       }
+
+       reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
+                                         saddr, sport, daddr, ntohs(dport),
+                                         ehashfn);
+       if (!reuse_sk)
+               return sk;
+
+       /* We've chosen a new reuseport sock which is never refcounted. This
+        * implies that sk also isn't refcounted.
+        */
+       WARN_ON_ONCE(*refcounted);
+
+       return reuse_sk;
+}
+
 static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
                                              struct sk_buff *skb, int doff,
                                              const __be16 sport,
                                              int iif, int sdif,
                                              bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb, refcounted);
-
+       struct net *net = dev_net(skb_dst(skb)->dev);
+       const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+       struct sock *sk;
+
+       sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
+                             refcounted, inet6_ehashfn);
+       if (IS_ERR(sk))
+               return NULL;
        if (sk)
                return sk;
 
-       return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
-                             doff, &ipv6_hdr(skb)->saddr, sport,
-                             &ipv6_hdr(skb)->daddr, ntohs(dport),
+       return __inet6_lookup(net, hashinfo, skb,
+                             doff, &ip6h->saddr, sport,
+                             &ip6h->daddr, ntohs(dport),
                              iif, sdif, refcounted);
 }
 
 
        return sk;
 }
 
+static inline
+struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+                            const __be32 saddr, const __be16 sport,
+                            const __be32 daddr, const __be16 dport,
+                            bool *refcounted, inet_ehashfn_t *ehashfn)
+{
+       struct sock *sk, *reuse_sk;
+       bool prefetched;
+
+       sk = skb_steal_sock(skb, refcounted, &prefetched);
+       if (!sk)
+               return NULL;
+
+       if (!prefetched)
+               return sk;
+
+       if (sk->sk_protocol == IPPROTO_TCP) {
+               if (sk->sk_state != TCP_LISTEN)
+                       return sk;
+       } else if (sk->sk_protocol == IPPROTO_UDP) {
+               if (sk->sk_state != TCP_CLOSE)
+                       return sk;
+       } else {
+               return sk;
+       }
+
+       reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
+                                        saddr, sport, daddr, ntohs(dport),
+                                        ehashfn);
+       if (!reuse_sk)
+               return sk;
+
+       /* We've chosen a new reuseport sock which is never refcounted. This
+        * implies that sk also isn't refcounted.
+        */
+       WARN_ON_ONCE(*refcounted);
+
+       return reuse_sk;
+}
+
 static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
                                             struct sk_buff *skb,
                                             int doff,
                                             const int sdif,
                                             bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb, refcounted);
+       struct net *net = dev_net(skb_dst(skb)->dev);
        const struct iphdr *iph = ip_hdr(skb);
+       struct sock *sk;
 
+       sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
+                            refcounted, inet_ehashfn);
+       if (IS_ERR(sk))
+               return NULL;
        if (sk)
                return sk;
 
-       return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+       return __inet_lookup(net, hashinfo, skb,
                             doff, iph->saddr, sport,
                             iph->daddr, dport, inet_iif(skb), sdif,
                             refcounted);
 
  * skb_steal_sock - steal a socket from an sk_buff
  * @skb: sk_buff to steal the socket from
  * @refcounted: is set to true if the socket is reference-counted
+ * @prefetched: is set to true if the socket was assigned from bpf
  */
 static inline struct sock *
-skb_steal_sock(struct sk_buff *skb, bool *refcounted)
+skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
 {
        if (skb->sk) {
                struct sock *sk = skb->sk;
 
                *refcounted = true;
-               if (skb_sk_is_prefetched(skb))
+               *prefetched = skb_sk_is_prefetched(skb);
+               if (*prefetched)
                        *refcounted = sk_is_refcounted(sk);
                skb->destructor = NULL;
                skb->sk = NULL;
                return sk;
        }
+       *prefetched = false;
        *refcounted = false;
        return NULL;
 }
 
  *             **-EOPNOTSUPP** if the operation is not supported, for example
  *             a call from outside of TC ingress.
  *
- *             **-ESOCKTNOSUPPORT** if the socket type is not supported
- *             (reuseport).
- *
  * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
  *     Description
  *             Helper is overloaded depending on BPF program type. This
 
                return -EOPNOTSUPP;
        if (unlikely(dev_net(skb->dev) != sock_net(sk)))
                return -ENETUNREACH;
-       if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
-               return -ESOCKTNOSUPPORT;
        if (sk_unhashed(sk))
                return -EOPNOTSUPP;
        if (sk_is_refcounted(sk) &&
 
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       sk = skb_steal_sock(skb, &refcounted);
+       sk = inet_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+                            &refcounted, udp_ehashfn);
+       if (IS_ERR(sk))
+               goto no_sk;
+
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
        sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
        if (sk)
                return udp_unicast_rcv_skb(sk, skb, uh);
-
+no_sk:
        if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
                goto drop;
        nf_reset_ct(skb);
 
                goto csum_error;
 
        /* Check if the socket is already available, e.g. due to early demux */
-       sk = skb_steal_sock(skb, &refcounted);
+       sk = inet6_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+                             &refcounted, udp6_ehashfn);
+       if (IS_ERR(sk))
+               goto no_sk;
+
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
                        goto report_csum_error;
                return udp6_unicast_rcv_skb(sk, skb, uh);
        }
-
+no_sk:
        reason = SKB_DROP_REASON_NO_SOCKET;
 
        if (!uh->check)
 
  *             **-EOPNOTSUPP** if the operation is not supported, for example
  *             a call from outside of TC ingress.
  *
- *             **-ESOCKTNOSUPPORT** if the socket type is not supported
- *             (reuseport).
- *
  * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
  *     Description
  *             Helper is overloaded depending on BPF program type. This