genetlink: introduce per-sock family private storage
authorJiri Pirko <jiri@nvidia.com>
Sat, 16 Dec 2023 12:29:57 +0000 (13:29 +0100)
committerPaolo Abeni <pabeni@redhat.com>
Tue, 19 Dec 2023 14:31:40 +0000 (15:31 +0100)
Introduce an xarray for each Generic netlink family to store per-socket
private data. Initialize this xarray only if the family uses per-socket privs.

Introduce genl_sk_priv_get() to get the socket priv pointer for a family,
allocating and initializing it in case it does not exist yet.
Introduce __genl_sk_priv_get() to obtain the socket priv pointer for a
family under RCU read lock, without allocating it.

Allow family to specify the priv size, init() and destroy() callbacks.

Signed-off-by: Jiri Pirko <jiri@nvidia.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
include/net/genetlink.h
net/netlink/genetlink.c

index c53244f204370442054501e7c9698b6b7224af8b..6bc37f392a9a6926017c354c9dca60426f1e3557 100644 (file)
@@ -51,6 +51,9 @@ struct genl_info;
  * @split_ops: the split do/dump form of operation definition
  * @n_split_ops: number of entries in @split_ops, not that with split do/dump
  *     ops the number of entries is not the same as number of commands
+ * @sock_priv_size: the size of per-socket private memory, 0 if not used
+ * @sock_priv_init: the per-socket private memory initializer (optional)
+ * @sock_priv_destroy: the per-socket private memory destructor (optional)
  *
  * Attribute policies (the combination of @policy and @maxattr fields)
  * can be attached at the family level or at the operation level.
@@ -84,11 +87,17 @@ struct genl_family {
        const struct genl_multicast_group *mcgrps;
        struct module           *module;
 
+       size_t                  sock_priv_size;
+       void                    (*sock_priv_init)(void *priv);
+       void                    (*sock_priv_destroy)(void *priv);
+
 /* private: internal use only */
        /* protocol family identifier */
        int                     id;
        /* starting number of multicast group IDs in this family */
        unsigned int            mcgrp_offset;
+       /* per-socket privs, indexed by socket pointer */
+       struct xarray           *sock_privs;
 };
 
 /**
@@ -298,6 +307,8 @@ static inline bool genl_info_is_ntf(const struct genl_info *info)
        return !info->nlhdr;
 }
 
+void *__genl_sk_priv_get(struct genl_family *family, struct sock *sk);
+void *genl_sk_priv_get(struct genl_family *family, struct sock *sk);
 int genl_register_family(struct genl_family *family);
 int genl_unregister_family(const struct genl_family *family);
 void genl_notify(const struct genl_family *family, struct sk_buff *skb,
index 9c7ffd10df2a72c00d626ab40ca95bb739425983..c0d15470a10b342578fcd82d25c19b0f1be03122 100644 (file)
@@ -631,6 +631,138 @@ static int genl_validate_ops(const struct genl_family *family)
        return 0;
 }
 
+static void *genl_sk_priv_alloc(struct genl_family *family)
+{
+       void *priv;
+
+       priv = kzalloc(family->sock_priv_size, GFP_KERNEL);
+       if (!priv)
+               return ERR_PTR(-ENOMEM);
+
+       if (family->sock_priv_init)
+               family->sock_priv_init(priv);
+
+       return priv;
+}
+
+static void genl_sk_priv_free(const struct genl_family *family, void *priv)
+{
+       if (family->sock_priv_destroy)
+               family->sock_priv_destroy(priv);
+       kfree(priv);
+}
+
+static int genl_sk_privs_alloc(struct genl_family *family)
+{
+       if (!family->sock_priv_size)
+               return 0;
+
+       family->sock_privs = kzalloc(sizeof(*family->sock_privs), GFP_KERNEL);
+       if (!family->sock_privs)
+               return -ENOMEM;
+       xa_init(family->sock_privs);
+       return 0;
+}
+
+static void genl_sk_privs_free(const struct genl_family *family)
+{
+       unsigned long id;
+       void *priv;
+
+       if (!family->sock_priv_size)
+               return;
+
+       xa_for_each(family->sock_privs, id, priv)
+               genl_sk_priv_free(family, priv);
+
+       xa_destroy(family->sock_privs);
+       kfree(family->sock_privs);
+}
+
+static void genl_sk_priv_free_by_sock(struct genl_family *family,
+                                     struct sock *sk)
+{
+       void *priv;
+
+       if (!family->sock_priv_size)
+               return;
+       priv = xa_erase(family->sock_privs, (unsigned long) sk);
+       if (!priv)
+               return;
+       genl_sk_priv_free(family, priv);
+}
+
+static void genl_release(struct sock *sk, unsigned long *groups)
+{
+       struct genl_family *family;
+       unsigned int id;
+
+       down_read(&cb_lock);
+
+       idr_for_each_entry(&genl_fam_idr, family, id)
+               genl_sk_priv_free_by_sock(family, sk);
+
+       up_read(&cb_lock);
+}
+
+/**
+ * __genl_sk_priv_get - Get family private pointer for socket, if it exists
+ *
+ * @family: Generic netlink family that uses per-socket privs
+ * @sk: socket whose priv is looked up
+ *
+ * Look up the private memory of a Generic netlink family for a given socket.
+ *
+ * Caller should make sure this is called in RCU read locked section.
+ *
+ * Return: valid pointer on success, NULL in case priv does not exist,
+ * ERR_PTR(-EINVAL) in case the family has no per-socket priv storage.
+ */
+void *__genl_sk_priv_get(struct genl_family *family, struct sock *sk)
+{
+       if (WARN_ON_ONCE(!family->sock_privs))
+               return ERR_PTR(-EINVAL);
+       return xa_load(family->sock_privs, (unsigned long) sk);
+}
+
+/**
+ * genl_sk_priv_get - Get family private pointer for socket
+ *
+ * @family: Generic netlink family that uses per-socket privs
+ * @sk: socket whose priv is looked up or created
+ *
+ * Look up the private memory of a Generic netlink family for a given socket.
+ * Allocate and insert it (with GFP_KERNEL) in case it was not already done.
+ *
+ * Return: valid pointer on success, otherwise negative error value
+ * encoded by ERR_PTR().
+ */
+void *genl_sk_priv_get(struct genl_family *family, struct sock *sk)
+{
+       void *priv, *old_priv;
+
+       priv = __genl_sk_priv_get(family, sk);
+       if (priv)
+               return priv;
+
+       /* No priv stored for this socket so far, create it. */
+
+       priv = genl_sk_priv_alloc(family);
+       if (IS_ERR(priv))
+               return ERR_CAST(priv);
+
+       old_priv = xa_cmpxchg(family->sock_privs, (unsigned long) sk, NULL,
+                             priv, GFP_KERNEL);
+       if (old_priv) {
+               genl_sk_priv_free(family, priv);
+               if (xa_is_err(old_priv))
+                       return ERR_PTR(xa_err(old_priv));
+               /* Lost a race, use the priv that was inserted first. */
+               return old_priv;
+       }
+       return priv;
+}
+
 /**
  * genl_register_family - register a generic netlink family
  * @family: generic netlink family
@@ -659,6 +791,10 @@ int genl_register_family(struct genl_family *family)
                goto errout_locked;
        }
 
+       err = genl_sk_privs_alloc(family);
+       if (err)
+               goto errout_locked;
+
        /*
         * Sadly, a few cases need to be special-cased
         * due to them having previously abused the API
@@ -679,7 +815,7 @@ int genl_register_family(struct genl_family *family)
                                      start, end + 1, GFP_KERNEL);
        if (family->id < 0) {
                err = family->id;
-               goto errout_locked;
+               goto errout_sk_privs_free;
        }
 
        err = genl_validate_assign_mc_groups(family);
@@ -698,6 +834,8 @@ int genl_register_family(struct genl_family *family)
 
 errout_remove:
        idr_remove(&genl_fam_idr, family->id);
+errout_sk_privs_free:
+       genl_sk_privs_free(family);
 errout_locked:
        genl_unlock_all();
        return err;
@@ -728,6 +866,9 @@ int genl_unregister_family(const struct genl_family *family)
        up_write(&cb_lock);
        wait_event(genl_sk_destructing_waitq,
                   atomic_read(&genl_sk_destructing_cnt) == 0);
+
+       genl_sk_privs_free(family);
+
        genl_unlock();
 
        genl_ctrl_event(CTRL_CMD_DELFAMILY, family, NULL, 0);
@@ -1708,6 +1849,7 @@ static int __net_init genl_pernet_init(struct net *net)
                .input          = genl_rcv,
                .flags          = NL_CFG_F_NONROOT_RECV,
                .bind           = genl_bind,
+               .release        = genl_release,
        };
 
        /* we'll bump the group number right afterwards */