EXPORT_SYMBOL_GPL(devl_unlock);
 
 static struct devlink *
-devlinks_xa_find_get(unsigned long *indexp, xa_mark_t filter,
+devlinks_xa_find_get(struct net *net, unsigned long *indexp, xa_mark_t filter,
                     void * (*xa_find_fn)(struct xarray *, unsigned long *,
                                          unsigned long, xa_mark_t))
 {
        xa_find_fn = xa_find_after;
        if (!devlink_try_get(devlink))
                goto retry;
+       if (!net_eq(devlink_net(devlink), net)) {
+               devlink_put(devlink);
+               goto retry;
+       }
 unlock:
        rcu_read_unlock();
        return devlink;
 }
 
-static struct devlink *devlinks_xa_find_get_first(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_first(struct net *net,
+                                                 unsigned long *indexp,
                                                  xa_mark_t filter)
 {
-       return devlinks_xa_find_get(indexp, filter, xa_find);
+       return devlinks_xa_find_get(net, indexp, filter, xa_find);
 }
 
-static struct devlink *devlinks_xa_find_get_next(unsigned long *indexp,
+static struct devlink *devlinks_xa_find_get_next(struct net *net,
+                                                unsigned long *indexp,
                                                 xa_mark_t filter)
 {
-       return devlinks_xa_find_get(indexp, filter, xa_find_after);
+       return devlinks_xa_find_get(net, indexp, filter, xa_find_after);
 }
 
 /* Iterate over devlink pointers which were possible to get reference to.
  * devlink_put() needs to be called for each iterated devlink pointer
  * in loop body in order to release the reference.
  */
-#define devlinks_xa_for_each_get(index, devlink, filter)                       \
-       for (index = 0, devlink = devlinks_xa_find_get_first(&index, filter);   \
-            devlink; devlink = devlinks_xa_find_get_next(&index, filter))
+#define devlinks_xa_for_each_get(net, index, devlink, filter)                  \
+       for (index = 0,                                                         \
+            devlink = devlinks_xa_find_get_first(net, &index, filter);         \
+            devlink; devlink = devlinks_xa_find_get_next(net, &index, filter))
 
-#define devlinks_xa_for_each_registered_get(index, devlink)                    \
-       devlinks_xa_for_each_get(index, devlink, DEVLINK_REGISTERED)
+#define devlinks_xa_for_each_registered_get(net, index, devlink)               \
+       devlinks_xa_for_each_get(net, index, devlink, DEVLINK_REGISTERED)
 
 static struct devlink *devlink_get_from_attrs(struct net *net,
                                              struct nlattr **attrs)
        busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]);
        devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
 
-       devlinks_xa_for_each_registered_get(index, devlink) {
+       devlinks_xa_for_each_registered_get(net, index, devlink) {
                if (strcmp(devlink->dev->bus->name, busname) == 0 &&
-                   strcmp(dev_name(devlink->dev), devname) == 0 &&
-                   net_eq(devlink_net(devlink), net))
+                   strcmp(dev_name(devlink->dev), devname) == 0)
                        return devlink;
                devlink_put(devlink);
        }
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_rate, &devlink->rate_list, list) {
                        enum devlink_command cmd = DEVLINK_CMD_RATE_NEW;
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk))) {
-                       devlink_put(devlink);
-                       continue;
-               }
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (idx < start) {
                        idx++;
                        devlink_put(devlink);
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_port, &devlink->port_list, list) {
                        if (idx < start) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                mutex_lock(&devlink->linecards_lock);
                list_for_each_entry(linecard, &devlink->linecard_list, list) {
                        if (idx < start) {
                        idx++;
                }
                mutex_unlock(&devlink->linecards_lock);
-retry:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_sb, &devlink->sb_list, list) {
                        if (idx < start) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
                    !devlink->ops->sb_pool_get)
                        goto retry;
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-                   !devlink->ops->sb_port_pool_get)
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+               if (!devlink->ops->sb_port_pool_get)
                        goto retry;
 
                devl_lock(devlink);
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)) ||
-                   !devlink->ops->sb_tc_pool_bind_get)
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+               if (!devlink->ops->sb_tc_pool_bind_get)
                        goto retry;
 
                devl_lock(devlink);
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(param_item, &devlink->param_list, list) {
                        if (idx < start) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(devlink_port, &devlink->port_list, list) {
                        list_for_each_entry(param_item,
                        }
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
                                                               &idx, start);
-retry:
                devlink_put(devlink);
                if (err)
                        goto out;
        int err = 0;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                if (idx < start || !devlink->ops->info_get)
                        goto inc;
 
                }
 inc:
                idx++;
-retry:
                devlink_put(devlink);
        }
        mutex_unlock(&devlink_mutex);
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry_rep;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                mutex_lock(&devlink->reporters_lock);
                list_for_each_entry(reporter, &devlink->reporter_list,
                                    list) {
                        idx++;
                }
                mutex_unlock(&devlink->reporters_lock);
-retry_rep:
                devlink_put(devlink);
        }
 
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry_port;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(port, &devlink->port_list, list) {
                        mutex_lock(&port->reporters_lock);
                        mutex_unlock(&port->reporters_lock);
                }
                devl_unlock(devlink);
-retry_port:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(trap_item, &devlink->trap_list, list) {
                        if (idx < start) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(group_item, &devlink->trap_group_list,
                                    list) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
        int err;
 
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), sock_net(msg->sk)))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
                devl_lock(devlink);
                list_for_each_entry(policer_item, &devlink->trap_policer_list,
                                    list) {
                        idx++;
                }
                devl_unlock(devlink);
-retry:
                devlink_put(devlink);
        }
 out:
         * all devlink instances from this namespace into init_net.
         */
        mutex_lock(&devlink_mutex);
-       devlinks_xa_for_each_registered_get(index, devlink) {
-               if (!net_eq(devlink_net(devlink), net))
-                       goto retry;
-
+       devlinks_xa_for_each_registered_get(net, index, devlink) {
                WARN_ON(!(devlink->features & DEVLINK_F_RELOAD));
                err = devlink_reload(devlink, &init_net,
                                     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
                                     &actions_performed, NULL);
                if (err && err != -EOPNOTSUPP)
                        pr_warn("Failed to reload devlink instance into init_net\n");
-retry:
                devlink_put(devlink);
        }
        mutex_unlock(&devlink_mutex);