vfio: Set device->group in helper function
authorYi Liu <yi.l.liu@intel.com>
Fri, 25 Nov 2022 05:22:27 +0000 (21:22 -0800)
committerJason Gunthorpe <jgg@nvidia.com>
Mon, 5 Dec 2022 12:56:01 +0000 (08:56 -0400)
This avoids referencing device->group in __vfio_register_dev().

Link: https://lore.kernel.org/r/20221201145535.589687-5-yi.l.liu@intel.com
Reviewed-by: Jason Gunthorpe <jgg@nvidia.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Reviewed-by: Alex Williamson <alex.williamson@redhat.com>
Tested-by: Lixiao Yang <lixiao.yang@intel.com>
Tested-by: Yu He <yu.he@intel.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
drivers/vfio/vfio_main.c

index a5122fa4bf4d1afa7ac8f821b8b87ba88d2617b6..7e42ee0ee1bce02bd0e7542fba577e784dde0a13 100644 (file)
@@ -528,18 +528,29 @@ static void vfio_device_group_unregister(struct vfio_device *device)
        mutex_unlock(&device->group->device_lock);
 }
 
-static int __vfio_register_dev(struct vfio_device *device,
-               struct vfio_group *group)
+static int vfio_device_set_group(struct vfio_device *device,
+                                enum vfio_group_type type)
 {
-       int ret;
+       struct vfio_group *group;
+
+       if (type == VFIO_IOMMU)
+               group = vfio_group_find_or_alloc(device->dev);
+       else
+               group = vfio_noiommu_group_alloc(device->dev, type);
 
-       /*
-        * In all cases group is the output of one of the group allocation
-        * functions and we have group->drivers incremented for us.
-        */
        if (IS_ERR(group))
                return PTR_ERR(group);
 
+       /* Our reference on group is moved to the device */
+       device->group = group;
+       return 0;
+}
+
+static int __vfio_register_dev(struct vfio_device *device,
+                              enum vfio_group_type type)
+{
+       int ret;
+
        if (WARN_ON(device->ops->bind_iommufd &&
                    (!device->ops->unbind_iommufd ||
                     !device->ops->attach_ioas)))
@@ -552,12 +563,13 @@ static int __vfio_register_dev(struct vfio_device *device,
        if (!device->dev_set)
                vfio_assign_device_set(device, device);
 
-       /* Our reference on group is moved to the device */
-       device->group = group;
-
        ret = dev_set_name(&device->device, "vfio%d", device->index);
        if (ret)
-               goto err_out;
+               return ret;
+
+       ret = vfio_device_set_group(device, type);
+       if (ret)
+               return ret;
 
        ret = device_add(&device->device);
        if (ret)
@@ -576,8 +588,7 @@ err_out:
 
 int vfio_register_group_dev(struct vfio_device *device)
 {
-       return __vfio_register_dev(device,
-               vfio_group_find_or_alloc(device->dev));
+       return __vfio_register_dev(device, VFIO_IOMMU);
 }
 EXPORT_SYMBOL_GPL(vfio_register_group_dev);
 
@@ -587,8 +598,7 @@ EXPORT_SYMBOL_GPL(vfio_register_group_dev);
  */
 int vfio_register_emulated_iommu_dev(struct vfio_device *device)
 {
-       return __vfio_register_dev(device,
-               vfio_noiommu_group_alloc(device->dev, VFIO_EMULATED_IOMMU));
+       return __vfio_register_dev(device, VFIO_EMULATED_IOMMU);
 }
 EXPORT_SYMBOL_GPL(vfio_register_emulated_iommu_dev);
 
@@ -658,6 +668,7 @@ void vfio_unregister_group_dev(struct vfio_device *device)
        /* Balances device_add in register path */
        device_del(&device->device);
 
+       /* Balances vfio_device_set_group in register path */
        vfio_device_remove_group(device);
 }
 EXPORT_SYMBOL_GPL(vfio_unregister_group_dev);