static inline void init_sched_state(struct intel_context *ce)
 {
-       /* Only should be called from guc_lrc_desc_pin() */
+       lockdep_assert_held(&ce->guc_state.lock);
        atomic_set(&ce->guc_sched_state_no_lock, 0);
        ce->guc_state.sched_state &= SCHED_STATE_BLOCKED_MASK;
 }
 
+__maybe_unused
+static bool sched_state_is_init(struct intel_context *ce)
+{
+       /*
+        * XXX: Kernel contexts can have SCHED_STATE_NO_LOCK_REGISTERED after
+        * suspend.
+        */
+       return !(atomic_read(&ce->guc_sched_state_no_lock) &
+                ~SCHED_STATE_NO_LOCK_REGISTERED) &&
+               !(ce->guc_state.sched_state &= ~SCHED_STATE_BLOCKED_MASK);
+}
+
 static inline bool
 context_wait_for_deregister_to_register(struct intel_context *ce)
 {
 static inline void
 set_context_wait_for_deregister_to_register(struct intel_context *ce)
 {
-       /* Only should be called from guc_lrc_desc_pin() without lock */
+       lockdep_assert_held(&ce->guc_state.lock);
        ce->guc_state.sched_state |=
                SCHED_STATE_WAIT_FOR_DEREGISTER_TO_REGISTER;
 }
        bool pending_disable, pending_enable, deregister, destroyed, banned;
 
        xa_for_each(&guc->context_lookup, index, ce) {
-               /* Flush context */
                spin_lock_irqsave(&ce->guc_state.lock, flags);
-               spin_unlock_irqrestore(&ce->guc_state.lock, flags);
 
                /*
                 * Once we are at this point submission_disabled() is guaranteed
                banned = context_banned(ce);
                init_sched_state(ce);
 
+               spin_unlock_irqrestore(&ce->guc_state.lock, flags);
+
                if (pending_enable || destroyed || deregister) {
                        decr_outstanding_submission_g2h(guc);
                        if (deregister)
        int ret = 0;
 
        GEM_BUG_ON(!engine->mask);
+       GEM_BUG_ON(!sched_state_is_init(ce));
 
        /*
         * Ensure LRC + CT vmas are is same region as write barrier is done
        desc->priority = ce->guc_prio;
        desc->context_flags = CONTEXT_REGISTRATION_FLAG_KMD;
        guc_context_policy_init(engine, desc);
-       init_sched_state(ce);
 
        /*
         * The context_lookup xarray is used to determine if the hardware