Use READ/WRITE_ONCE() for IP local_port_range.
author David Laight <David.Laight@ACULAB.COM>
Wed, 6 Dec 2023 13:44:20 +0000 (13:44 +0000)
committer Jakub Kicinski <kuba@kernel.org>
Fri, 8 Dec 2023 18:44:42 +0000 (10:44 -0800)
Commit 227b60f5102cd added a seqlock to ensure that the low and high
port numbers were always updated together.
This is overkill because the two 16-bit port numbers can be held in
a u32 and read/written in a single instruction.

More recently 91d0b78c5177f added support for finer per-socket limits.
The user-supplied value is 'high << 16 | low' but they are held
separately and the socket options are protected by the socket lock.

Use a u32 containing 'high << 16 | low' for both the 'net' and 'sk'
fields and use READ_ONCE()/WRITE_ONCE() to ensure both values are
always updated together.

Change (the now trivial) inet_get_local_port_range() to a static inline
to optimise the calling code.
(In particular avoiding returning integers by reference.)

Signed-off-by: David Laight <david.laight@aculab.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Reviewed-by: David Ahern <dsahern@kernel.org>
Acked-by: Mat Martineau <martineau@kernel.org>
Reviewed-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Link: https://lore.kernel.org/r/4e505d4198e946a8be03fb1b4c3072b0@AcuMS.aculab.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
include/net/inet_sock.h
include/net/ip.h
include/net/netns/ipv4.h
net/ipv4/af_inet.c
net/ipv4/inet_connection_sock.c
net/ipv4/ip_sockglue.c
net/ipv4/sysctl_net_ipv4.c

index 74db6d97cae107c770746317cfc49cbcc43e89c4..aa86453f6b9ba367f772570a7b783bb098be6236 100644 (file)
@@ -234,10 +234,7 @@ struct inet_sock {
        int                     uc_index;
        int                     mc_index;
        __be32                  mc_addr;
-       struct {
-               __u16 lo;
-               __u16 hi;
-       }                       local_port_range;
+       u32                     local_port_range;       /* high << 16 | low */
 
        struct ip_mc_socklist __rcu     *mc_list;
        struct inet_cork_full   cork;
index 1fc4c8d69e333e81b6fae1840262df18c2c66e25..b31be912489af8b01cc0393a27ffc80b086feaa0 100644 (file)
@@ -349,7 +349,13 @@ static inline u64 snmp_fold_field64(void __percpu *mib, int offt, size_t syncp_o
        } \
 }
 
-void inet_get_local_port_range(const struct net *net, int *low, int *high);
+static inline void inet_get_local_port_range(const struct net *net, int *low, int *high)
+{
+       u32 range = READ_ONCE(net->ipv4.ip_local_ports.range);
+
+       *low = range & 0xffff;
+       *high = range >> 16;
+}
 void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high);
 
 #ifdef CONFIG_SYSCTL
index ea882964c71ee35048c4b8da81cb3226b79b5cec..c356c458b3409f966c3c2a1feddc0fa26b5f7088 100644 (file)
@@ -19,8 +19,7 @@ struct hlist_head;
 struct fib_table;
 struct sock;
 struct local_ports {
-       seqlock_t       lock;
-       int             range[2];
+       u32             range;  /* high << 16 | low */
        bool            warned;
 };
 
index fb81de10d3320c319b5c449e2ad4ee6da6c0df44..fbeacf04dbf3744e5888360e0b74bf6f70ff214f 100644 (file)
@@ -1847,9 +1847,7 @@ static __net_init int inet_init_net(struct net *net)
        /*
         * Set defaults for local port range
         */
-       seqlock_init(&net->ipv4.ip_local_ports.lock);
-       net->ipv4.ip_local_ports.range[0] =  32768;
-       net->ipv4.ip_local_ports.range[1] =  60999;
+       net->ipv4.ip_local_ports.range = 60999u << 16 | 32768u;
 
        seqlock_init(&net->ipv4.ping_group_range.lock);
        /*
index 394a498c28232213c1c3fb6f98af922d61a25418..70be0f6fe879ea671bf6686b04edf32bf5e0d4b6 100644 (file)
@@ -117,34 +117,25 @@ bool inet_rcv_saddr_any(const struct sock *sk)
        return !sk->sk_rcv_saddr;
 }
 
-void inet_get_local_port_range(const struct net *net, int *low, int *high)
-{
-       unsigned int seq;
-
-       do {
-               seq = read_seqbegin(&net->ipv4.ip_local_ports.lock);
-
-               *low = net->ipv4.ip_local_ports.range[0];
-               *high = net->ipv4.ip_local_ports.range[1];
-       } while (read_seqretry(&net->ipv4.ip_local_ports.lock, seq));
-}
-EXPORT_SYMBOL(inet_get_local_port_range);
-
 void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high)
 {
        const struct inet_sock *inet = inet_sk(sk);
        const struct net *net = sock_net(sk);
        int lo, hi, sk_lo, sk_hi;
+       u32 sk_range;
 
        inet_get_local_port_range(net, &lo, &hi);
 
-       sk_lo = inet->local_port_range.lo;
-       sk_hi = inet->local_port_range.hi;
+       sk_range = READ_ONCE(inet->local_port_range);
+       if (unlikely(sk_range)) {
+               sk_lo = sk_range & 0xffff;
+               sk_hi = sk_range >> 16;
 
-       if (unlikely(lo <= sk_lo && sk_lo <= hi))
-               lo = sk_lo;
-       if (unlikely(lo <= sk_hi && sk_hi <= hi))
-               hi = sk_hi;
+               if (lo <= sk_lo && sk_lo <= hi)
+                       lo = sk_lo;
+               if (lo <= sk_hi && sk_hi <= hi)
+                       hi = sk_hi;
+       }
 
        *low = lo;
        *high = hi;
index 2efc53526a382a5c59eed393796315af9fd9cd2c..d7d13940774e837a9f9baeb4f971126abc26d9fc 100644 (file)
@@ -1055,6 +1055,19 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
        case IP_TOS:    /* This sets both TOS and Precedence */
                ip_sock_set_tos(sk, val);
                return 0;
+       case IP_LOCAL_PORT_RANGE:
+       {
+               u16 lo = val;
+               u16 hi = val >> 16;
+
+               if (optlen != sizeof(u32))
+                       return -EINVAL;
+               if (lo != 0 && hi != 0 && lo > hi)
+                       return -EINVAL;
+
+               WRITE_ONCE(inet->local_port_range, val);
+               return 0;
+       }
        }
 
        err = 0;
@@ -1332,20 +1345,6 @@ int do_ip_setsockopt(struct sock *sk, int level, int optname,
                err = xfrm_user_policy(sk, optname, optval, optlen);
                break;
 
-       case IP_LOCAL_PORT_RANGE:
-       {
-               const __u16 lo = val;
-               const __u16 hi = val >> 16;
-
-               if (optlen != sizeof(__u32))
-                       goto e_inval;
-               if (lo != 0 && hi != 0 && lo > hi)
-                       goto e_inval;
-
-               inet->local_port_range.lo = lo;
-               inet->local_port_range.hi = hi;
-               break;
-       }
        default:
                err = -ENOPROTOOPT;
                break;
@@ -1692,6 +1691,9 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                        return -EFAULT;
                return 0;
        }
+       case IP_LOCAL_PORT_RANGE:
+               val = READ_ONCE(inet->local_port_range);
+               goto copyval;
        }
 
        if (needs_rtnl)
@@ -1721,9 +1723,6 @@ int do_ip_getsockopt(struct sock *sk, int level, int optname,
                else
                        err = ip_get_mcast_msfilter(sk, optval, optlen, len);
                goto out;
-       case IP_LOCAL_PORT_RANGE:
-               val = inet->local_port_range.hi << 16 | inet->local_port_range.lo;
-               break;
        case IP_PROTOCOL:
                val = inet_sk(sk)->inet_num;
                break;
index f63a545a73749714a747872457ebe097907e8ea8..7e4f16a7dcc17c02b13445ab8ecf44c3d88993ba 100644 (file)
@@ -50,26 +50,22 @@ static int tcp_plb_max_cong_thresh = 256;
 static int sysctl_tcp_low_latency __read_mostly;
 
 /* Update system visible IP port range */
-static void set_local_port_range(struct net *net, int range[2])
+static void set_local_port_range(struct net *net, unsigned int low, unsigned int high)
 {
-       bool same_parity = !((range[0] ^ range[1]) & 1);
+       bool same_parity = !((low ^ high) & 1);
 
-       write_seqlock_bh(&net->ipv4.ip_local_ports.lock);
        if (same_parity && !net->ipv4.ip_local_ports.warned) {
                net->ipv4.ip_local_ports.warned = true;
                pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n");
        }
-       net->ipv4.ip_local_ports.range[0] = range[0];
-       net->ipv4.ip_local_ports.range[1] = range[1];
-       write_sequnlock_bh(&net->ipv4.ip_local_ports.lock);
+       WRITE_ONCE(net->ipv4.ip_local_ports.range, high << 16 | low);
 }
 
 /* Validate changes from /proc interface. */
 static int ipv4_local_port_range(struct ctl_table *table, int write,
                                 void *buffer, size_t *lenp, loff_t *ppos)
 {
-       struct net *net =
-               container_of(table->data, struct net, ipv4.ip_local_ports.range);
+       struct net *net = table->data;
        int ret;
        int range[2];
        struct ctl_table tmp = {
@@ -93,7 +89,7 @@ static int ipv4_local_port_range(struct ctl_table *table, int write,
                    (range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock)))
                        ret = -EINVAL;
                else
-                       set_local_port_range(net, range);
+                       set_local_port_range(net, range[0], range[1]);
        }
 
        return ret;
@@ -733,8 +729,8 @@ static struct ctl_table ipv4_net_table[] = {
        },
        {
                .procname       = "ip_local_port_range",
-               .maxlen         = sizeof(init_net.ipv4.ip_local_ports.range),
-               .data           = &init_net.ipv4.ip_local_ports.range,
+               .maxlen         = 0,
+               .data           = &init_net,
                .mode           = 0644,
                .proc_handler   = ipv4_local_port_range,
        },