psp->hdcp_context.context.mem_context.shared_mem_size = PSP_HDCP_SHARED_MEM_SIZE;
        psp->hdcp_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
-       if (!psp->hdcp_context.context.initialized) {
+       if (!psp->hdcp_context.context.mem_context.shared_buf) {
                ret = psp_ta_init_shared_buf(psp, &psp->hdcp_context.context.mem_context);
                if (ret)
                        return ret;
        psp->dtm_context.context.mem_context.shared_mem_size = PSP_DTM_SHARED_MEM_SIZE;
        psp->dtm_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
-       if (!psp->dtm_context.context.initialized) {
+       if (!psp->dtm_context.context.mem_context.shared_buf) {
                ret = psp_ta_init_shared_buf(psp, &psp->dtm_context.context.mem_context);
                if (ret)
                        return ret;
        psp->rap_context.context.mem_context.shared_mem_size = PSP_RAP_SHARED_MEM_SIZE;
        psp->rap_context.context.ta_load_type = GFX_CMD_ID_LOAD_TA;
 
-       if (!psp->rap_context.context.initialized) {
+       if (!psp->rap_context.context.mem_context.shared_buf) {
                ret = psp_ta_init_shared_buf(psp, &psp->rap_context.context.mem_context);
                if (ret)
                        return ret;