struct vfio_group {
        struct iommu_group      *iommu_group;
        struct list_head        next;
+       bool                    mdev_group;     /* An mdev group */
 };
 
 /*
        return ret;
 }
 
+static struct device *vfio_mdev_get_iommu_device(struct device *dev)
+{
+       struct device *(*fn)(struct device *dev);
+       struct device *iommu_device;
+
+       fn = symbol_get(mdev_get_iommu_device);
+       if (fn) {
+               iommu_device = fn(dev);
+               symbol_put(mdev_get_iommu_device);
+
+               return iommu_device;
+       }
+
+       return NULL;
+}
+
+static int vfio_mdev_attach_domain(struct device *dev, void *data)
+{
+       struct iommu_domain *domain = data;
+       struct device *iommu_device;
+
+       iommu_device = vfio_mdev_get_iommu_device(dev);
+       if (iommu_device) {
+               if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
+                       return iommu_aux_attach_device(domain, iommu_device);
+               else
+                       return iommu_attach_device(domain, iommu_device);
+       }
+
+       return -EINVAL;
+}
+
+static int vfio_mdev_detach_domain(struct device *dev, void *data)
+{
+       struct iommu_domain *domain = data;
+       struct device *iommu_device;
+
+       iommu_device = vfio_mdev_get_iommu_device(dev);
+       if (iommu_device) {
+               if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
+                       iommu_aux_detach_device(domain, iommu_device);
+               else
+                       iommu_detach_device(domain, iommu_device);
+       }
+
+       return 0;
+}
+
+static int vfio_iommu_attach_group(struct vfio_domain *domain,
+                                  struct vfio_group *group)
+{
+       if (group->mdev_group)
+               return iommu_group_for_each_dev(group->iommu_group,
+                                               domain->domain,
+                                               vfio_mdev_attach_domain);
+       else
+               return iommu_attach_group(domain->domain, group->iommu_group);
+}
+
+static void vfio_iommu_detach_group(struct vfio_domain *domain,
+                                   struct vfio_group *group)
+{
+       if (group->mdev_group)
+               iommu_group_for_each_dev(group->iommu_group, domain->domain,
+                                        vfio_mdev_detach_domain);
+       else
+               iommu_detach_group(domain->domain, group->iommu_group);
+}
+
 static int vfio_iommu_type1_attach_group(void *iommu_data,
                                         struct iommu_group *iommu_group)
 {
                        goto out_domain;
        }
 
-       ret = iommu_attach_group(domain->domain, iommu_group);
+       ret = vfio_iommu_attach_group(domain, group);
        if (ret)
                goto out_domain;
 
        list_for_each_entry(d, &iommu->domain_list, next) {
                if (d->domain->ops == domain->domain->ops &&
                    d->prot == domain->prot) {
-                       iommu_detach_group(domain->domain, iommu_group);
-                       if (!iommu_attach_group(d->domain, iommu_group)) {
+                       vfio_iommu_detach_group(domain, group);
+                       if (!vfio_iommu_attach_group(d, group)) {
                                list_add(&group->next, &d->group_list);
                                iommu_domain_free(domain->domain);
                                kfree(domain);
                                return 0;
                        }
 
-                       ret = iommu_attach_group(domain->domain, iommu_group);
+                       ret = vfio_iommu_attach_group(domain, group);
                        if (ret)
                                goto out_domain;
                }
        return 0;
 
 out_detach:
-       iommu_detach_group(domain->domain, iommu_group);
+       vfio_iommu_detach_group(domain, group);
 out_domain:
        iommu_domain_free(domain->domain);
 out_free:
                if (!group)
                        continue;
 
-               iommu_detach_group(domain->domain, iommu_group);
+               vfio_iommu_detach_group(domain, group);
                list_del(&group->next);
                kfree(group);
                /*
        list_for_each_entry_safe(group, group_tmp,
                                 &domain->group_list, next) {
                if (!external)
-                       iommu_detach_group(domain->domain, group->iommu_group);
+                       vfio_iommu_detach_group(domain, group);
                list_del(&group->next);
                kfree(group);
        }