mptcp: don't overwrite sock_ops in mptcp_is_tcpsk()
authorDavide Caratti <dcaratti@redhat.com>
Tue, 19 Dec 2023 21:31:04 +0000 (22:31 +0100)
committerDavid S. Miller <davem@davemloft.net>
Tue, 26 Dec 2023 22:33:21 +0000 (22:33 +0000)
Eric Dumazet suggests:

 > The fact that mptcp_is_tcpsk() was able to write over sock->ops was a
 > bit strange to me.
 > mptcp_is_tcpsk() should answer a question, with a read-only argument.

re-factor code to avoid overwriting sock_ops inside that function. Also,
change the helper name to reflect the semantics and to disambiguate from
its dual, sk_is_mptcp(). While at it, collapse mptcp_stream_accept() and
mptcp_accept() into a single function, where fallback / non-fallback are
separated into a single sk_is_mptcp() conditional.

Link: https://github.com/multipath-tcp/mptcp_net-next/issues/432
Suggested-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Davide Caratti <dcaratti@redhat.com>
Acked-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Matthieu Baerts <matttbe@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/mptcp/protocol.c

index 5cd5c3f535a82177cb2c2f23162ccf96ad9e7e09..91e5845d80a96530c37afc0fa47799ea6a8123a8 100644 (file)
@@ -55,28 +55,14 @@ static u64 mptcp_wnd_end(const struct mptcp_sock *msk)
        return READ_ONCE(msk->wnd_end);
 }
 
-static bool mptcp_is_tcpsk(struct sock *sk)
+static const struct proto_ops *mptcp_fallback_tcp_ops(const struct sock *sk)
 {
-       struct socket *sock = sk->sk_socket;
-
-       if (unlikely(sk->sk_prot == &tcp_prot)) {
-               /* we are being invoked after mptcp_accept() has
-                * accepted a non-mp-capable flow: sk is a tcp_sk,
-                * not an mptcp one.
-                *
-                * Hand the socket over to tcp so all further socket ops
-                * bypass mptcp.
-                */
-               WRITE_ONCE(sock->ops, &inet_stream_ops);
-               return true;
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
-       } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
-               WRITE_ONCE(sock->ops, &inet6_stream_ops);
-               return true;
+       if (sk->sk_prot == &tcpv6_prot)
+               return &inet6_stream_ops;
 #endif
-       }
-
-       return false;
+       WARN_ON_ONCE(sk->sk_prot != &tcp_prot);
+       return &inet_stream_ops;
 }
 
 static int __mptcp_socket_create(struct mptcp_sock *msk)
@@ -3258,44 +3244,6 @@ void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk)
        WRITE_ONCE(msk->wnd_end, msk->snd_nxt + tcp_sk(ssk)->snd_wnd);
 }
 
-static struct sock *mptcp_accept(struct sock *ssk, int flags, int *err,
-                                bool kern)
-{
-       struct sock *newsk;
-
-       pr_debug("ssk=%p, listener=%p", ssk, mptcp_subflow_ctx(ssk));
-       newsk = inet_csk_accept(ssk, flags, err, kern);
-       if (!newsk)
-               return NULL;
-
-       pr_debug("newsk=%p, subflow is mptcp=%d", newsk, sk_is_mptcp(newsk));
-       if (sk_is_mptcp(newsk)) {
-               struct mptcp_subflow_context *subflow;
-               struct sock *new_mptcp_sock;
-
-               subflow = mptcp_subflow_ctx(newsk);
-               new_mptcp_sock = subflow->conn;
-
-               /* is_mptcp should be false if subflow->conn is missing, see
-                * subflow_syn_recv_sock()
-                */
-               if (WARN_ON_ONCE(!new_mptcp_sock)) {
-                       tcp_sk(newsk)->is_mptcp = 0;
-                       goto out;
-               }
-
-               newsk = new_mptcp_sock;
-               MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
-       } else {
-               MPTCP_INC_STATS(sock_net(ssk),
-                               MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK);
-       }
-
-out:
-       newsk->sk_kern_sock = kern;
-       return newsk;
-}
-
 void mptcp_destroy_common(struct mptcp_sock *msk, unsigned int flags)
 {
        struct mptcp_subflow_context *subflow, *tmp;
@@ -3739,7 +3687,6 @@ static struct proto mptcp_prot = {
        .connect        = mptcp_connect,
        .disconnect     = mptcp_disconnect,
        .close          = mptcp_close,
-       .accept         = mptcp_accept,
        .setsockopt     = mptcp_setsockopt,
        .getsockopt     = mptcp_getsockopt,
        .shutdown       = mptcp_shutdown,
@@ -3849,18 +3796,36 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
        if (!ssk)
                return -EINVAL;
 
-       newsk = mptcp_accept(ssk, flags, &err, kern);
+       pr_debug("ssk=%p, listener=%p", ssk, mptcp_subflow_ctx(ssk));
+       newsk = inet_csk_accept(ssk, flags, &err, kern);
        if (!newsk)
                return err;
 
-       lock_sock(newsk);
-
-       __inet_accept(sock, newsock, newsk);
-       if (!mptcp_is_tcpsk(newsock->sk)) {
-               struct mptcp_sock *msk = mptcp_sk(newsk);
+       pr_debug("newsk=%p, subflow is mptcp=%d", newsk, sk_is_mptcp(newsk));
+       if (sk_is_mptcp(newsk)) {
                struct mptcp_subflow_context *subflow;
+               struct sock *new_mptcp_sock;
+
+               subflow = mptcp_subflow_ctx(newsk);
+               new_mptcp_sock = subflow->conn;
+
+               /* is_mptcp should be false if subflow->conn is missing, see
+                * subflow_syn_recv_sock()
+                */
+               if (WARN_ON_ONCE(!new_mptcp_sock)) {
+                       tcp_sk(newsk)->is_mptcp = 0;
+                       goto tcpfallback;
+               }
+
+               newsk = new_mptcp_sock;
+               MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
+
+               newsk->sk_kern_sock = kern;
+               lock_sock(newsk);
+               __inet_accept(sock, newsock, newsk);
 
                set_bit(SOCK_CUSTOM_SOCKOPT, &newsock->flags);
+               msk = mptcp_sk(newsk);
                msk->in_accept_queue = 0;
 
                /* set ssk->sk_socket of accept()ed flows to mptcp socket.
@@ -3882,6 +3847,21 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                        if (unlikely(list_is_singular(&msk->conn_list)))
                                inet_sk_state_store(newsk, TCP_CLOSE);
                }
+       } else {
+               MPTCP_INC_STATS(sock_net(ssk),
+                               MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK);
+tcpfallback:
+               newsk->sk_kern_sock = kern;
+               lock_sock(newsk);
+               __inet_accept(sock, newsock, newsk);
+               /* we are being invoked after accepting a non-mp-capable
+                * flow: sk is a tcp_sk, not an mptcp one.
+                *
+                * Hand the socket over to tcp so all further socket ops
+                * bypass mptcp.
+                */
+               WRITE_ONCE(newsock->sk->sk_socket->ops,
+                          mptcp_fallback_tcp_ops(newsock->sk));
        }
        release_sock(newsk);