return ret;
 }
 
+static int mlx5fv_vf_event(struct notifier_block *nb,
+                          unsigned long event, void *data)
+{
+       struct mlx5vf_pci_core_device *mvdev =
+               container_of(nb, struct mlx5vf_pci_core_device, nb);
+
+       mutex_lock(&mvdev->state_mutex);
+       switch (event) {
+       case MLX5_PF_NOTIFY_ENABLE_VF:
+               mvdev->mdev_detach = false;
+               break;
+       case MLX5_PF_NOTIFY_DISABLE_VF:
+               mvdev->mdev_detach = true;
+               break;
+       default:
+               break;
+       }
+       mlx5vf_state_mutex_unlock(mvdev);
+       return 0;
+}
+
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+       if (!mvdev->migrate_cap)
+               return;
+
+       mlx5_sriov_blocking_notifier_unregister(mvdev->mdev, mvdev->vf_id,
+                                               &mvdev->nb);
+}
+
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev)
+{
+       struct pci_dev *pdev = mvdev->core_device.pdev;
+       int ret;
+
+       if (!pdev->is_virtfn)
+               return;
+
+       mvdev->mdev = mlx5_vf_get_core_dev(pdev);
+       if (!mvdev->mdev)
+               return;
+
+       if (!MLX5_CAP_GEN(mvdev->mdev, migration))
+               goto end;
+
+       mvdev->vf_id = pci_iov_vf_id(pdev);
+       if (mvdev->vf_id < 0)
+               goto end;
+
+       mutex_init(&mvdev->state_mutex);
+       spin_lock_init(&mvdev->reset_lock);
+       mvdev->nb.notifier_call = mlx5fv_vf_event;
+       ret = mlx5_sriov_blocking_notifier_register(mvdev->mdev, mvdev->vf_id,
+                                                   &mvdev->nb);
+       if (ret)
+               goto end;
+
+       mvdev->migrate_cap = 1;
+       mvdev->core_device.vdev.migration_flags =
+               VFIO_MIGRATION_STOP_COPY |
+               VFIO_MIGRATION_P2P;
+
+end:
+       mlx5_vf_put_core_dev(mvdev->mdev);
+}
+
 int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
 {
        struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
 
 #define MLX5_VFIO_CMD_H
 
 #include <linux/kernel.h>
+#include <linux/vfio_pci_core.h>
 #include <linux/mlx5/driver.h>
 
 struct mlx5_vf_migration_file {
        unsigned long last_offset;
 };
 
+struct mlx5vf_pci_core_device {
+       struct vfio_pci_core_device core_device;
+       int vf_id;
+       u16 vhca_id;
+       u8 migrate_cap:1;
+       u8 deferred_reset:1;
+       u8 mdev_detach:1;
+       /* protect migration state */
+       struct mutex state_mutex;
+       enum vfio_device_mig_state mig_state;
+       /* protect the reset_done flow */
+       spinlock_t reset_lock;
+       struct mlx5_vf_migration_file *resuming_migf;
+       struct mlx5_vf_migration_file *saving_migf;
+       struct notifier_block nb;
+       struct mlx5_core_dev *mdev;
+};
+
 int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
 int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod);
 int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
                                          size_t *state_size);
 int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id);
+void mlx5vf_cmd_set_migratable(struct mlx5vf_pci_core_device *mvdev);
+void mlx5vf_cmd_remove_migratable(struct mlx5vf_pci_core_device *mvdev);
 int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
                               struct mlx5_vf_migration_file *migf);
 int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
                               struct mlx5_vf_migration_file *migf);
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev);
 #endif /* MLX5_VFIO_CMD_H */
 
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
 #include <linux/sched/mm.h>
-#include <linux/vfio_pci_core.h>
 #include <linux/anon_inodes.h>
 
 #include "cmd.h"
 /* Arbitrary to prevent userspace from consuming endless memory */
 #define MAX_MIGRATION_SIZE (512*1024*1024)
 
-struct mlx5vf_pci_core_device {
-       struct vfio_pci_core_device core_device;
-       u16 vhca_id;
-       u8 migrate_cap:1;
-       u8 deferred_reset:1;
-       /* protect migration state */
-       struct mutex state_mutex;
-       enum vfio_device_mig_state mig_state;
-       /* protect the reset_done flow */
-       spinlock_t reset_lock;
-       struct mlx5_vf_migration_file *resuming_migf;
-       struct mlx5_vf_migration_file *saving_migf;
-};
-
 static struct page *
 mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
                          unsigned long offset)
  * This function is called in all state_mutex unlock cases to
  * handle a 'deferred_reset' if exists.
  */
-static void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
+void mlx5vf_state_mutex_unlock(struct mlx5vf_pci_core_device *mvdev)
 {
 again:
        spin_lock(&mvdev->reset_lock);
        if (!mvdev)
                return -ENOMEM;
        vfio_pci_core_init_device(&mvdev->core_device, pdev, &mlx5vf_pci_ops);
-
-       if (pdev->is_virtfn) {
-               struct mlx5_core_dev *mdev =
-                       mlx5_vf_get_core_dev(pdev);
-
-               if (mdev) {
-                       if (MLX5_CAP_GEN(mdev, migration)) {
-                               mvdev->migrate_cap = 1;
-                               mvdev->core_device.vdev.migration_flags =
-                                       VFIO_MIGRATION_STOP_COPY |
-                                       VFIO_MIGRATION_P2P;
-                               mutex_init(&mvdev->state_mutex);
-                               spin_lock_init(&mvdev->reset_lock);
-                       }
-                       mlx5_vf_put_core_dev(mdev);
-               }
-       }
-
+       mlx5vf_cmd_set_migratable(mvdev);
        ret = vfio_pci_core_register_device(&mvdev->core_device);
        if (ret)
                goto out_free;
        return 0;
 
 out_free:
+       mlx5vf_cmd_remove_migratable(mvdev);
        vfio_pci_core_uninit_device(&mvdev->core_device);
        kfree(mvdev);
        return ret;
        struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
 
        vfio_pci_core_unregister_device(&mvdev->core_device);
+       mlx5vf_cmd_remove_migratable(mvdev);
        vfio_pci_core_uninit_device(&mvdev->core_device);
        kfree(mvdev);
 }