return psp_ta_invoke(psp, ta_cmd_id, psp->xgmi_context.session_id);
 }
 
-static int psp_xgmi_terminate(struct psp_context *psp)
+int psp_xgmi_terminate(struct psp_context *psp)
 {
        int ret;
 
        return 0;
 }
 
-static int psp_xgmi_initialize(struct psp_context *psp)
+int psp_xgmi_initialize(struct psp_context *psp)
 {
        struct ta_xgmi_shared_memory *xgmi_cmd;
        int ret;
                return ret;
        }
 
-       if (adev->gmc.xgmi.num_physical_nodes > 1) {
-               ret = psp_xgmi_initialize(psp);
-               /* Warning the XGMI seesion initialize failure
-                * Instead of stop driver initialization
-                */
-               if (ret)
-                       dev_err(psp->adev->dev,
-                               "XGMI: Failed to initialize XGMI session\n");
-       }
-
        if (psp->adev->psp.ta_fw) {
                ret = psp_ras_initialize(psp);
                if (ret)
        void *tmr_buf;
        void **pptr;
 
-       if (adev->gmc.xgmi.num_physical_nodes > 1 &&
-           psp->xgmi_context.initialized == 1)
-                psp_xgmi_terminate(psp);
-
        if (psp->adev->psp.ta_fw) {
                psp_ras_terminate(psp);
                psp_dtm_terminate(psp);
 
 int psp_update_vcn_sram(struct amdgpu_device *adev, int inst_idx,
                        uint64_t cmd_gpu_addr, int cmd_size);
 
+int psp_xgmi_initialize(struct psp_context *psp);
+int psp_xgmi_terminate(struct psp_context *psp);
 int psp_xgmi_invoke(struct psp_context *psp, uint32_t ta_cmd_id);
 
 int psp_ras_invoke(struct psp_context *psp, uint32_t ta_cmd_id);
 
                return 0;
 
        if (amdgpu_device_ip_get_ip_block(adev, AMD_IP_BLOCK_TYPE_PSP)) {
+               ret = psp_xgmi_initialize(&adev->psp);
+               if (ret) {
+                       dev_err(adev->dev,
+                               "XGMI: Failed to initialize xgmi session\n");
+                       return ret;
+               }
+
                ret = psp_xgmi_get_hive_id(&adev->psp, &adev->gmc.xgmi.hive_id);
                if (ret) {
                        dev_err(adev->dev,
        return ret;
 }
 
-void amdgpu_xgmi_remove_device(struct amdgpu_device *adev)
+int amdgpu_xgmi_remove_device(struct amdgpu_device *adev)
 {
        struct amdgpu_hive_info *hive;
 
        if (!adev->gmc.xgmi.supported)
-               return;
+               return -EINVAL;
 
        hive = amdgpu_get_xgmi_hive(adev, 1);
        if (!hive)
-               return;
+               return -EINVAL;
 
        if (!(hive->number_devices--)) {
                amdgpu_xgmi_sysfs_destroy(adev, hive);
                amdgpu_xgmi_sysfs_rem_dev_info(adev, hive);
                mutex_unlock(&hive->hive_lock);
        }
+
+       return psp_xgmi_terminate(&adev->psp);
 }
 
 int amdgpu_xgmi_ras_late_init(struct amdgpu_device *adev)
 
 struct amdgpu_hive_info *amdgpu_get_xgmi_hive(struct amdgpu_device *adev, int lock);
 int amdgpu_xgmi_update_topology(struct amdgpu_hive_info *hive, struct amdgpu_device *adev);
 int amdgpu_xgmi_add_device(struct amdgpu_device *adev);
-void amdgpu_xgmi_remove_device(struct amdgpu_device *adev);
+int amdgpu_xgmi_remove_device(struct amdgpu_device *adev);
 int amdgpu_xgmi_set_pstate(struct amdgpu_device *adev, int pstate);
 int amdgpu_xgmi_get_hops_count(struct amdgpu_device *adev,
                struct amdgpu_device *peer_adev);