iommu: Pass domain to remove_dev_pasid() op
authorYi Liu <yi.l.liu@intel.com>
Thu, 28 Mar 2024 12:29:58 +0000 (05:29 -0700)
committerJoerg Roedel <jroedel@suse.de>
Fri, 12 Apr 2024 10:13:01 +0000 (12:13 +0200)
Existing remove_dev_pasid() callbacks of the underlying iommu drivers
get the attached domain from the group->pasid_array. However, the domain
stored in group->pasid_array is not always correct in all scenarios.
A wrong domain may result in failure in remove_dev_pasid() callback.
To avoid such problems, it is more reliable to pass the domain to the
remove_dev_pasid() op.

Suggested-by: Jason Gunthorpe <jgg@nvidia.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Link: https://lore.kernel.org/r/20240328122958.83332-3-yi.l.liu@intel.com
Signed-off-by: Joerg Roedel <jroedel@suse.de>
drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
drivers/iommu/intel/iommu.c
drivers/iommu/iommu.c
include/linux/iommu.h

index 41f93c3ab160d0f4a6d3c47c061b23156c016aa7..2e1988c861f1545d2fa7dff55344deb07e658d88 100644 (file)
@@ -3053,14 +3053,9 @@ static int arm_smmu_def_domain_type(struct device *dev)
        return 0;
 }
 
-static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+                                     struct iommu_domain *domain)
 {
-       struct iommu_domain *domain;
-
-       domain = iommu_get_domain_for_dev_pasid(dev, pasid, IOMMU_DOMAIN_SVA);
-       if (WARN_ON(IS_ERR(domain)) || !domain)
-               return;
-
        arm_smmu_sva_remove_dev_pasid(domain, dev, pasid);
 }
 
index 50eb9aed47cc585e1307b3d0f47252b2edcdaeb0..45c75a8a0ef567178ff7cb8b26856c8726c1e84b 100644 (file)
@@ -4587,19 +4587,15 @@ static int intel_iommu_iotlb_sync_map(struct iommu_domain *domain,
        return 0;
 }
 
-static void intel_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
+static void intel_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+                                        struct iommu_domain *domain)
 {
        struct device_domain_info *info = dev_iommu_priv_get(dev);
+       struct dmar_domain *dmar_domain = to_dmar_domain(domain);
        struct dev_pasid_info *curr, *dev_pasid = NULL;
        struct intel_iommu *iommu = info->iommu;
-       struct dmar_domain *dmar_domain;
-       struct iommu_domain *domain;
        unsigned long flags;
 
-       domain = iommu_get_domain_for_dev_pasid(dev, pasid, 0);
-       if (WARN_ON_ONCE(!domain))
-               goto out_tear_down;
-
        /*
         * The SVA implementation needs to handle its own stuffs like the mm
         * notification. Before consolidating that code into iommu core, let
@@ -4610,7 +4606,6 @@ static void intel_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid)
                goto out_tear_down;
        }
 
-       dmar_domain = to_dmar_domain(domain);
        spin_lock_irqsave(&dmar_domain->lock, flags);
        list_for_each_entry(curr, &dmar_domain->dev_pasids, link_domain) {
                if (curr->dev == dev && curr->pasid == pasid) {
index 659a77f7bb833c2a78f7dc409fdb529e41c4f8e4..3183b0ed4cdb921611de6c12ad44ac055f3f24d3 100644 (file)
@@ -3335,20 +3335,21 @@ err_revert:
 
                if (device == last_gdev)
                        break;
-               ops->remove_dev_pasid(device->dev, pasid);
+               ops->remove_dev_pasid(device->dev, pasid, domain);
        }
        return ret;
 }
 
 static void __iommu_remove_group_pasid(struct iommu_group *group,
-                                      ioasid_t pasid)
+                                      ioasid_t pasid,
+                                      struct iommu_domain *domain)
 {
        struct group_device *device;
        const struct iommu_ops *ops;
 
        for_each_group_device(group, device) {
                ops = dev_iommu_ops(device->dev);
-               ops->remove_dev_pasid(device->dev, pasid);
+               ops->remove_dev_pasid(device->dev, pasid, domain);
        }
 }
 
@@ -3418,7 +3419,7 @@ void iommu_detach_device_pasid(struct iommu_domain *domain, struct device *dev,
        struct iommu_group *group = dev->iommu_group;
 
        mutex_lock(&group->mutex);
-       __iommu_remove_group_pasid(group, pasid);
+       __iommu_remove_group_pasid(group, pasid, domain);
        WARN_ON(xa_erase(&group->pasid_array, pasid) != domain);
        mutex_unlock(&group->mutex);
 }
index 2e925b5eba534c8b5335a8d73742d2a9e2ea38c2..40dd439307e8327314f6d6209804846092d0d5e5 100644 (file)
@@ -578,7 +578,8 @@ struct iommu_ops {
                              struct iommu_page_response *msg);
 
        int (*def_domain_type)(struct device *dev);
-       void (*remove_dev_pasid)(struct device *dev, ioasid_t pasid);
+       void (*remove_dev_pasid)(struct device *dev, ioasid_t pasid,
+                                struct iommu_domain *domain);
 
        const struct iommu_domain_ops *default_domain_ops;
        unsigned long pgsize_bitmap;