#include "cmd.h"
 
-int mlx5vf_cmd_suspend_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod)
+static int mlx5vf_cmd_get_vhca_id(struct mlx5_core_dev *mdev, u16 function_id,
+                                 u16 *vhca_id);
+
+int mlx5vf_cmd_suspend_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
        u32 out[MLX5_ST_SZ_DW(suspend_vhca_out)] = {};
        u32 in[MLX5_ST_SZ_DW(suspend_vhca_in)] = {};
-       int ret;
 
-       if (!mdev)
+       lockdep_assert_held(&mvdev->state_mutex);
+       if (mvdev->mdev_detach)
                return -ENOTCONN;
 
        MLX5_SET(suspend_vhca_in, in, opcode, MLX5_CMD_OP_SUSPEND_VHCA);
-       MLX5_SET(suspend_vhca_in, in, vhca_id, vhca_id);
+       MLX5_SET(suspend_vhca_in, in, vhca_id, mvdev->vhca_id);
        MLX5_SET(suspend_vhca_in, in, op_mod, op_mod);
 
-       ret = mlx5_cmd_exec_inout(mdev, suspend_vhca, in, out);
-       mlx5_vf_put_core_dev(mdev);
-       return ret;
+       return mlx5_cmd_exec_inout(mvdev->mdev, suspend_vhca, in, out);
 }
 
-int mlx5vf_cmd_resume_vhca(struct pci_dev *pdev, u16 vhca_id, u16 op_mod)
+int mlx5vf_cmd_resume_vhca(struct mlx5vf_pci_core_device *mvdev, u16 op_mod)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
        u32 out[MLX5_ST_SZ_DW(resume_vhca_out)] = {};
        u32 in[MLX5_ST_SZ_DW(resume_vhca_in)] = {};
-       int ret;
 
-       if (!mdev)
+       lockdep_assert_held(&mvdev->state_mutex);
+       if (mvdev->mdev_detach)
                return -ENOTCONN;
 
        MLX5_SET(resume_vhca_in, in, opcode, MLX5_CMD_OP_RESUME_VHCA);
-       MLX5_SET(resume_vhca_in, in, vhca_id, vhca_id);
+       MLX5_SET(resume_vhca_in, in, vhca_id, mvdev->vhca_id);
        MLX5_SET(resume_vhca_in, in, op_mod, op_mod);
 
-       ret = mlx5_cmd_exec_inout(mdev, resume_vhca, in, out);
-       mlx5_vf_put_core_dev(mdev);
-       return ret;
+       return mlx5_cmd_exec_inout(mvdev->mdev, resume_vhca, in, out);
 }
 
-int mlx5vf_cmd_query_vhca_migration_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_query_vhca_migration_state(struct mlx5vf_pci_core_device *mvdev,
                                          size_t *state_size)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
        u32 out[MLX5_ST_SZ_DW(query_vhca_migration_state_out)] = {};
        u32 in[MLX5_ST_SZ_DW(query_vhca_migration_state_in)] = {};
        int ret;
 
-       if (!mdev)
+       lockdep_assert_held(&mvdev->state_mutex);
+       if (mvdev->mdev_detach)
                return -ENOTCONN;
 
        MLX5_SET(query_vhca_migration_state_in, in, opcode,
                 MLX5_CMD_OP_QUERY_VHCA_MIGRATION_STATE);
-       MLX5_SET(query_vhca_migration_state_in, in, vhca_id, vhca_id);
+       MLX5_SET(query_vhca_migration_state_in, in, vhca_id, mvdev->vhca_id);
        MLX5_SET(query_vhca_migration_state_in, in, op_mod, 0);
 
-       ret = mlx5_cmd_exec_inout(mdev, query_vhca_migration_state, in, out);
+       ret = mlx5_cmd_exec_inout(mvdev->mdev, query_vhca_migration_state, in,
+                                 out);
        if (ret)
-               goto end;
+               return ret;
 
        *state_size = MLX5_GET(query_vhca_migration_state_out, out,
                               required_umem_size);
-
-end:
-       mlx5_vf_put_core_dev(mdev);
-       return ret;
+       return 0;
 }
 
 static int mlx5fv_vf_event(struct notifier_block *nb,
        if (mvdev->vf_id < 0)
                goto end;
 
+       if (mlx5vf_cmd_get_vhca_id(mvdev->mdev, mvdev->vf_id + 1,
+                                  &mvdev->vhca_id))
+               goto end;
+
        mutex_init(&mvdev->state_mutex);
        spin_lock_init(&mvdev->reset_lock);
        mvdev->nb.notifier_call = mlx5fv_vf_event;
        mlx5_vf_put_core_dev(mvdev->mdev);
 }
 
-int mlx5vf_cmd_get_vhca_id(struct pci_dev *pdev, u16 function_id, u16 *vhca_id)
+static int mlx5vf_cmd_get_vhca_id(struct mlx5_core_dev *mdev, u16 function_id,
+                                 u16 *vhca_id)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
        u32 in[MLX5_ST_SZ_DW(query_hca_cap_in)] = {};
        int out_size;
        void *out;
        int ret;
 
-       if (!mdev)
-               return -ENOTCONN;
-
        out_size = MLX5_ST_SZ_BYTES(query_hca_cap_out);
        out = kzalloc(out_size, GFP_KERNEL);
-       if (!out) {
-               ret = -ENOMEM;
-               goto end;
-       }
+       if (!out)
+               return -ENOMEM;
 
        MLX5_SET(query_hca_cap_in, in, opcode, MLX5_CMD_OP_QUERY_HCA_CAP);
        MLX5_SET(query_hca_cap_in, in, other_function, 1);
 
 err_exec:
        kfree(out);
-end:
-       mlx5_vf_put_core_dev(mdev);
        return ret;
 }
 
        return err;
 }
 
-int mlx5vf_cmd_save_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_save_vhca_state(struct mlx5vf_pci_core_device *mvdev,
                               struct mlx5_vf_migration_file *migf)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
        u32 out[MLX5_ST_SZ_DW(save_vhca_state_out)] = {};
        u32 in[MLX5_ST_SZ_DW(save_vhca_state_in)] = {};
+       struct mlx5_core_dev *mdev;
        u32 pdn, mkey;
        int err;
 
-       if (!mdev)
+       lockdep_assert_held(&mvdev->state_mutex);
+       if (mvdev->mdev_detach)
                return -ENOTCONN;
 
+       mdev = mvdev->mdev;
        err = mlx5_core_alloc_pd(mdev, &pdn);
        if (err)
-               goto end;
+               return err;
 
        err = dma_map_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE,
                              0);
        MLX5_SET(save_vhca_state_in, in, opcode,
                 MLX5_CMD_OP_SAVE_VHCA_STATE);
        MLX5_SET(save_vhca_state_in, in, op_mod, 0);
-       MLX5_SET(save_vhca_state_in, in, vhca_id, vhca_id);
+       MLX5_SET(save_vhca_state_in, in, vhca_id, mvdev->vhca_id);
        MLX5_SET(save_vhca_state_in, in, mkey, mkey);
        MLX5_SET(save_vhca_state_in, in, size, migf->total_length);
 
        if (err)
                goto err_exec;
 
-       migf->total_length =
-               MLX5_GET(save_vhca_state_out, out, actual_image_size);
-
-       mlx5_core_destroy_mkey(mdev, mkey);
-       mlx5_core_dealloc_pd(mdev, pdn);
-       dma_unmap_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE, 0);
-       mlx5_vf_put_core_dev(mdev);
-
-       return 0;
-
+       migf->total_length = MLX5_GET(save_vhca_state_out, out,
+                                     actual_image_size);
 err_exec:
        mlx5_core_destroy_mkey(mdev, mkey);
 err_create_mkey:
        dma_unmap_sgtable(mdev->device, &migf->table.sgt, DMA_FROM_DEVICE, 0);
 err_dma_map:
        mlx5_core_dealloc_pd(mdev, pdn);
-end:
-       mlx5_vf_put_core_dev(mdev);
        return err;
 }
 
-int mlx5vf_cmd_load_vhca_state(struct pci_dev *pdev, u16 vhca_id,
+int mlx5vf_cmd_load_vhca_state(struct mlx5vf_pci_core_device *mvdev,
                               struct mlx5_vf_migration_file *migf)
 {
-       struct mlx5_core_dev *mdev = mlx5_vf_get_core_dev(pdev);
+       struct mlx5_core_dev *mdev;
        u32 out[MLX5_ST_SZ_DW(save_vhca_state_out)] = {};
        u32 in[MLX5_ST_SZ_DW(save_vhca_state_in)] = {};
        u32 pdn, mkey;
        int err;
 
-       if (!mdev)
+       lockdep_assert_held(&mvdev->state_mutex);
+       if (mvdev->mdev_detach)
                return -ENOTCONN;
 
        mutex_lock(&migf->lock);
                goto end;
        }
 
+       mdev = mvdev->mdev;
        err = mlx5_core_alloc_pd(mdev, &pdn);
        if (err)
                goto end;
        MLX5_SET(load_vhca_state_in, in, opcode,
                 MLX5_CMD_OP_LOAD_VHCA_STATE);
        MLX5_SET(load_vhca_state_in, in, op_mod, 0);
-       MLX5_SET(load_vhca_state_in, in, vhca_id, vhca_id);
+       MLX5_SET(load_vhca_state_in, in, vhca_id, mvdev->vhca_id);
        MLX5_SET(load_vhca_state_in, in, mkey, mkey);
        MLX5_SET(load_vhca_state_in, in, size, migf->total_length);
 
 err_reg:
        mlx5_core_dealloc_pd(mdev, pdn);
 end:
-       mlx5_vf_put_core_dev(mdev);
        mutex_unlock(&migf->lock);
        return err;
 }
 
        stream_open(migf->filp->f_inode, migf->filp);
        mutex_init(&migf->lock);
 
-       ret = mlx5vf_cmd_query_vhca_migration_state(
-               mvdev->core_device.pdev, mvdev->vhca_id, &migf->total_length);
+       ret = mlx5vf_cmd_query_vhca_migration_state(mvdev,
+                                                   &migf->total_length);
        if (ret)
                goto out_free;
 
        if (ret)
                goto out_free;
 
-       ret = mlx5vf_cmd_save_vhca_state(mvdev->core_device.pdev,
-                                        mvdev->vhca_id, migf);
+       ret = mlx5vf_cmd_save_vhca_state(mvdev, migf);
        if (ret)
                goto out_free;
        return migf;
        int ret;
 
        if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_STOP) {
-               ret = mlx5vf_cmd_suspend_vhca(
-                       mvdev->core_device.pdev, mvdev->vhca_id,
+               ret = mlx5vf_cmd_suspend_vhca(mvdev,
                        MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_RESPONDER);
                if (ret)
                        return ERR_PTR(ret);
        }
 
        if (cur == VFIO_DEVICE_STATE_STOP && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
-               ret = mlx5vf_cmd_resume_vhca(
-                       mvdev->core_device.pdev, mvdev->vhca_id,
+               ret = mlx5vf_cmd_resume_vhca(mvdev,
                        MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_RESPONDER);
                if (ret)
                        return ERR_PTR(ret);
        }
 
        if (cur == VFIO_DEVICE_STATE_RUNNING && new == VFIO_DEVICE_STATE_RUNNING_P2P) {
-               ret = mlx5vf_cmd_suspend_vhca(
-                       mvdev->core_device.pdev, mvdev->vhca_id,
+               ret = mlx5vf_cmd_suspend_vhca(mvdev,
                        MLX5_SUSPEND_VHCA_IN_OP_MOD_SUSPEND_INITIATOR);
                if (ret)
                        return ERR_PTR(ret);
        }
 
        if (cur == VFIO_DEVICE_STATE_RUNNING_P2P && new == VFIO_DEVICE_STATE_RUNNING) {
-               ret = mlx5vf_cmd_resume_vhca(
-                       mvdev->core_device.pdev, mvdev->vhca_id,
+               ret = mlx5vf_cmd_resume_vhca(mvdev,
                        MLX5_RESUME_VHCA_IN_OP_MOD_RESUME_INITIATOR);
                if (ret)
                        return ERR_PTR(ret);
        }
 
        if (cur == VFIO_DEVICE_STATE_RESUMING && new == VFIO_DEVICE_STATE_STOP) {
-               ret = mlx5vf_cmd_load_vhca_state(mvdev->core_device.pdev,
-                                                mvdev->vhca_id,
+               ret = mlx5vf_cmd_load_vhca_state(mvdev,
                                                 mvdev->resuming_migf);
                if (ret)
                        return ERR_PTR(ret);
        struct mlx5vf_pci_core_device *mvdev = container_of(
                core_vdev, struct mlx5vf_pci_core_device, core_device.vdev);
        struct vfio_pci_core_device *vdev = &mvdev->core_device;
-       int vf_id;
        int ret;
 
        ret = vfio_pci_core_enable(vdev);
        if (ret)
                return ret;
 
-       if (!mvdev->migrate_cap) {
-               vfio_pci_core_finish_enable(vdev);
-               return 0;
-       }
-
-       vf_id = pci_iov_vf_id(vdev->pdev);
-       if (vf_id < 0) {
-               ret = vf_id;
-               goto out_disable;
-       }
-
-       ret = mlx5vf_cmd_get_vhca_id(vdev->pdev, vf_id + 1, &mvdev->vhca_id);
-       if (ret)
-               goto out_disable;
-
-       mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
+       if (mvdev->migrate_cap)
+               mvdev->mig_state = VFIO_DEVICE_STATE_RUNNING;
        vfio_pci_core_finish_enable(vdev);
        return 0;
-out_disable:
-       vfio_pci_core_disable(vdev);
-       return ret;
 }
 
 static void mlx5vf_pci_close_device(struct vfio_device *core_vdev)