return ret;
 }
 
+static int __iommu_probe_device_helper(struct device *dev)
+{
+       const struct iommu_ops *ops = dev->bus->iommu_ops;
+       struct iommu_group *group;
+       int ret;
+
+       ret = __iommu_probe_device(dev, NULL);
+       if (ret)
+               goto err_out;
+
+       /*
+        * Try to allocate a default domain - needs support from the
+        * IOMMU driver. There are still some drivers which don't
+        * support default domains, so the return value is not yet
+        * checked.
+        */
+       iommu_alloc_default_domain(dev);
+
+       group = iommu_group_get(dev);
+       if (!group)
+               goto err_release;
+
+       if (group->default_domain)
+               ret = __iommu_attach_device(group->default_domain, dev);
+
+       iommu_group_put(group);
+
+       if (ret)
+               goto err_release;
+
+       if (ops->probe_finalize)
+               ops->probe_finalize(dev);
+
+       return 0;
+
+err_release:
+       iommu_release_device(dev);
+err_out:
+       return ret;
+
+}
+
 int iommu_probe_device(struct device *dev)
 {
        const struct iommu_ops *ops = dev->bus->iommu_ops;
        int ret;
 
        WARN_ON(dev->iommu_group);
+
        if (!ops)
                return -EINVAL;
 
                goto err_free_dev_param;
        }
 
-       if (ops->probe_device) {
-               struct iommu_group *group;
-
-               ret = __iommu_probe_device(dev, NULL);
-
-               /*
-                * Try to allocate a default domain - needs support from the
-                * IOMMU driver. There are still some drivers which don't
-                * support default domains, so the return value is not yet
-                * checked.
-                */
-               if (!ret)
-                       iommu_alloc_default_domain(dev);
-
-               group = iommu_group_get(dev);
-               if (group && group->default_domain) {
-                       ret = __iommu_attach_device(group->default_domain, dev);
-                       iommu_group_put(group);
-               }
-
-       } else {
-               ret = ops->add_device(dev);
-       }
+       if (ops->probe_device)
+               return __iommu_probe_device_helper(dev);
 
+       ret = ops->add_device(dev);
        if (ret)
                goto err_module_put;