static void mlx5_sf_dev_remove(struct auxiliary_device *adev)
 {
        struct mlx5_sf_dev *sf_dev = container_of(adev, struct mlx5_sf_dev, adev);
-       struct devlink *devlink = priv_to_devlink(sf_dev->mdev);
+       struct mlx5_core_dev *mdev = sf_dev->mdev;
+       struct devlink *devlink;
 
-       mlx5_drain_health_wq(sf_dev->mdev);
+       devlink = priv_to_devlink(mdev);
+       set_bit(MLX5_BREAK_FW_WAIT, &mdev->intf_state);
+       mlx5_drain_health_wq(mdev);
        devlink_unregister(devlink);
-       if (mlx5_dev_is_lightweight(sf_dev->mdev))
-               mlx5_uninit_one_light(sf_dev->mdev);
+       if (mlx5_dev_is_lightweight(mdev))
+               mlx5_uninit_one_light(mdev);
        else
-               mlx5_uninit_one(sf_dev->mdev);
-       iounmap(sf_dev->mdev->iseg);
-       mlx5_mdev_uninit(sf_dev->mdev);
+               mlx5_uninit_one(mdev);
+       iounmap(mdev->iseg);
+       mlx5_mdev_uninit(mdev);
        mlx5_devlink_free(devlink);
 }
 
 static void mlx5_sf_dev_shutdown(struct auxiliary_device *adev)
 {
        struct mlx5_sf_dev *sf_dev = container_of(adev, struct mlx5_sf_dev, adev);
+       struct mlx5_core_dev *mdev = sf_dev->mdev;
 
-       mlx5_unload_one(sf_dev->mdev, false);
+       set_bit(MLX5_BREAK_FW_WAIT, &mdev->intf_state);
+       mlx5_unload_one(mdev, false);
 }
 
 static const struct auxiliary_device_id mlx5_sf_dev_id_table[] = {