}
 #endif
 
-static void set_gcr_el1_excl(u64 excl)
+static void mte_update_sctlr_user(struct task_struct *task)
 {
-       current->thread.mte_ctrl = excl;
+       unsigned long sctlr = task->thread.sctlr_user;
+       unsigned long pref = MTE_CTRL_TCF_ASYNC;
+       unsigned long mte_ctrl = task->thread.mte_ctrl;
+       unsigned long resolved_mte_tcf = (mte_ctrl & pref) ? pref : mte_ctrl;
 
-       /*
-        * SYS_GCR_EL1 will be set to current->thread.gcr_user_excl value
-        * by mte_set_user_gcr() in kernel_exit,
-        */
+       sctlr &= ~SCTLR_EL1_TCF0_MASK;
+       if (resolved_mte_tcf & MTE_CTRL_TCF_ASYNC)
+               sctlr |= SCTLR_EL1_TCF0_ASYNC;
+       else if (resolved_mte_tcf & MTE_CTRL_TCF_SYNC)
+               sctlr |= SCTLR_EL1_TCF0_SYNC;
+       task->thread.sctlr_user = sctlr;
 }
 
 void mte_thread_init_user(void)
        dsb(ish);
        write_sysreg_s(0, SYS_TFSRE0_EL1);
        clear_thread_flag(TIF_MTE_ASYNC_FAULT);
-       /* disable tag checking */
-       set_task_sctlr_el1((current->thread.sctlr_user & ~SCTLR_EL1_TCF0_MASK) |
-                          SCTLR_EL1_TCF0_NONE);
-       /* reset tag generation mask */
-       set_gcr_el1_excl(SYS_GCR_EL1_EXCL_MASK);
+       /* disable tag checking and reset tag generation mask */
+       current->thread.mte_ctrl = MTE_CTRL_GCR_USER_EXCL_MASK;
+       mte_update_sctlr_user(current);
+       set_task_sctlr_el1(current->thread.sctlr_user);
 }
 
 void mte_thread_switch(struct task_struct *next)
 {
+       mte_update_sctlr_user(next);
+
        /*
         * Check if an async tag exception occurred at EL1.
         *
 
 long set_mte_ctrl(struct task_struct *task, unsigned long arg)
 {
-       u64 sctlr = task->thread.sctlr_user & ~SCTLR_EL1_TCF0_MASK;
        u64 mte_ctrl = (~((arg & PR_MTE_TAG_MASK) >> PR_MTE_TAG_SHIFT) &
                        SYS_GCR_EL1_EXCL_MASK) << MTE_CTRL_GCR_USER_EXCL_SHIFT;
 
        if (!system_supports_mte())
                return 0;
 
-       switch (arg & PR_MTE_TCF_MASK) {
-       case PR_MTE_TCF_NONE:
-               sctlr |= SCTLR_EL1_TCF0_NONE;
-               break;
-       case PR_MTE_TCF_SYNC:
-               sctlr |= SCTLR_EL1_TCF0_SYNC;
-               break;
-       case PR_MTE_TCF_ASYNC:
-               sctlr |= SCTLR_EL1_TCF0_ASYNC;
-               break;
-       default:
-               return -EINVAL;
-       }
+       if (arg & PR_MTE_TCF_ASYNC)
+               mte_ctrl |= MTE_CTRL_TCF_ASYNC;
+       if (arg & PR_MTE_TCF_SYNC)
+               mte_ctrl |= MTE_CTRL_TCF_SYNC;
 
-       if (task != current) {
-               task->thread.sctlr_user = sctlr;
-               task->thread.mte_ctrl = mte_ctrl;
-       } else {
-               set_task_sctlr_el1(sctlr);
-               set_gcr_el1_excl(mte_ctrl);
+       task->thread.mte_ctrl = mte_ctrl;
+       if (task == current) {
+               mte_update_sctlr_user(task);
+               set_task_sctlr_el1(task->thread.sctlr_user);
        }
 
        return 0;
                return 0;
 
        ret = incl << PR_MTE_TAG_SHIFT;
-
-       switch (task->thread.sctlr_user & SCTLR_EL1_TCF0_MASK) {
-       case SCTLR_EL1_TCF0_NONE:
-               ret |= PR_MTE_TCF_NONE;
-               break;
-       case SCTLR_EL1_TCF0_SYNC:
-               ret |= PR_MTE_TCF_SYNC;
-               break;
-       case SCTLR_EL1_TCF0_ASYNC:
+       if (mte_ctrl & MTE_CTRL_TCF_ASYNC)
                ret |= PR_MTE_TCF_ASYNC;
-               break;
-       }
+       if (mte_ctrl & MTE_CTRL_TCF_SYNC)
+               ret |= PR_MTE_TCF_SYNC;
 
        return ret;
 }
 
 #define PR_GET_TAGGED_ADDR_CTRL                56
 # define PR_TAGGED_ADDR_ENABLE         (1UL << 0)
 /* MTE tag check fault modes */
-# define PR_MTE_TCF_SHIFT              1
-# define PR_MTE_TCF_NONE               (0UL << PR_MTE_TCF_SHIFT)
-# define PR_MTE_TCF_SYNC               (1UL << PR_MTE_TCF_SHIFT)
-# define PR_MTE_TCF_ASYNC              (2UL << PR_MTE_TCF_SHIFT)
-# define PR_MTE_TCF_MASK               (3UL << PR_MTE_TCF_SHIFT)
+# define PR_MTE_TCF_NONE               0
+# define PR_MTE_TCF_SYNC               (1UL << 1)
+# define PR_MTE_TCF_ASYNC              (1UL << 2)
+# define PR_MTE_TCF_MASK               (PR_MTE_TCF_SYNC | PR_MTE_TCF_ASYNC)
 /* MTE tag inclusion mask */
 # define PR_MTE_TAG_SHIFT              3
 # define PR_MTE_TAG_MASK               (0xffffUL << PR_MTE_TAG_SHIFT)
+/* Unused; kept only for source compatibility */
+# define PR_MTE_TCF_SHIFT              1
 
 /* Control reclaim behavior when allocating memory */
 #define PR_SET_IO_FLUSHER              57