#ifdef __KERNEL__
 
+#include <asm/alternative.h>
 #include <asm/ptrace.h>
+#include <asm/sysreg.h>
 
 /*
  * Aarch64 has flags for masking: Debug, Asynchronous (serror), Interrupts and
 /*
  * CPU interrupt mask handling.
  */
-static inline unsigned long arch_local_irq_save(void)
-{
-       unsigned long flags;
-       asm volatile(
-               "mrs    %0, daif                // arch_local_irq_save\n"
-               "msr    daifset, #2"
-               : "=r" (flags)
-               :
-               : "memory");
-       return flags;
-}
-
 static inline void arch_local_irq_enable(void)
 {
+	/*
+	 * Unmask IRQs. Patched at runtime via ALTERNATIVE: without
+	 * ARM64_HAS_IRQ_PRIO_MASKING this clears DAIF.I ("daifclr, #2");
+	 * with it, it writes GIC_PRIO_IRQON to ICC_PMR_EL1 followed by
+	 * "dsb sy". NOTE(review): confirm a full "dsb sy" (rather than a
+	 * lighter barrier) is the required synchronization after the PMR
+	 * write.
+	 */
-       asm volatile(
-               "msr    daifclr, #2             // arch_local_irq_enable"
-               :
+       asm volatile(ALTERNATIVE(
+               "msr    daifclr, #2             // arch_local_irq_enable\n"
+               "nop",
+               "msr_s  " __stringify(SYS_ICC_PMR_EL1) ",%0\n"
+               "dsb    sy",
+               ARM64_HAS_IRQ_PRIO_MASKING)
                :
+               : "r" (GIC_PRIO_IRQON)
                : "memory");
 }
 
 static inline void arch_local_irq_disable(void)
 {
+	/*
+	 * Mask IRQs. Patched via ALTERNATIVE: either set DAIF.I
+	 * ("daifset, #2") or, when ARM64_HAS_IRQ_PRIO_MASKING is present,
+	 * write GIC_PRIO_IRQOFF to ICC_PMR_EL1. Unlike the enable path,
+	 * no dsb is emitted here -- presumably raising the mask need only
+	 * be observed before the next interrupt can be taken; confirm.
+	 */
-       asm volatile(
-               "msr    daifset, #2             // arch_local_irq_disable"
-               :
+       asm volatile(ALTERNATIVE(
+               "msr    daifset, #2             // arch_local_irq_disable",
+               "msr_s  " __stringify(SYS_ICC_PMR_EL1) ", %0",
+               ARM64_HAS_IRQ_PRIO_MASKING)
                :
+               : "r" (GIC_PRIO_IRQOFF)
                : "memory");
 }
 
  */
 static inline unsigned long arch_local_save_flags(void)
 {
+       unsigned long daif_bits;
        unsigned long flags;
-       asm volatile(
-               "mrs    %0, daif                // arch_local_save_flags"
-               : "=r" (flags)
-               :
+
+       daif_bits = read_sysreg(daif);
+
+       /*
+        * The asm is logically equivalent to:
+        *
+        * if (system_uses_irq_prio_masking())
+        *      flags = (daif_bits & PSR_I_BIT) ?
+        *                      GIC_PRIO_IRQOFF :
+        *                      read_sysreg_s(SYS_ICC_PMR_EL1);
+        * else
+        *      flags = daif_bits;
+        */
+       asm volatile(ALTERNATIVE(
+                       "mov    %0, %1\n"
+                       "nop\n"
+                       "nop",
+                       "mrs_s  %0, " __stringify(SYS_ICC_PMR_EL1) "\n"
+                       "ands   %1, %1, " __stringify(PSR_I_BIT) "\n"
+                       "csel   %0, %0, %2, eq",
+                       ARM64_HAS_IRQ_PRIO_MASKING)
+               : "=&r" (flags), "+r" (daif_bits)
+               : "r" (GIC_PRIO_IRQOFF)
                : "memory");
+
+       return flags;
+}
+
+/*
+ * Save the current interrupt-enable state and mask IRQs.
+ * Returns the prior flags value (DAIF bits, or a PMR value when priority
+ * masking is in use -- see arch_local_save_flags()), suitable for
+ * arch_local_irq_restore().
+ */
+static inline unsigned long arch_local_irq_save(void)
+{
+       unsigned long flags;
+
+       flags = arch_local_save_flags();
+
+       /* Disable after reading, so the returned flags are the prior state. */
+       arch_local_irq_disable();
+
        return flags;
 }
 
  */
 static inline void arch_local_irq_restore(unsigned long flags)
 {
+	/*
+	 * Restore a value previously obtained from arch_local_save_flags()
+	 * or arch_local_irq_save(): write it back to DAIF, or -- under
+	 * ARM64_HAS_IRQ_PRIO_MASKING -- to ICC_PMR_EL1 followed by
+	 * "dsb sy".
+	 * NOTE(review): "+r" declares flags read-write, yet neither
+	 * alternative modifies it; a plain input constraint would appear
+	 * sufficient -- confirm before changing.
+	 */
-       asm volatile(
-               "msr    daif, %0                // arch_local_irq_restore"
-       :
-       : "r" (flags)
-       : "memory");
+       asm volatile(ALTERNATIVE(
+                       "msr    daif, %0\n"
+                       "nop",
+                       "msr_s  " __stringify(SYS_ICC_PMR_EL1) ", %0\n"
+                       "dsb    sy",
+                       ARM64_HAS_IRQ_PRIO_MASKING)
+               : "+r" (flags)
+               :
+               : "memory");
 }
 
 static inline int arch_irqs_disabled_flags(unsigned long flags)
 {
-       return flags & PSR_I_BIT;
+       int res;
+
+       asm volatile(ALTERNATIVE(
+                       "and    %w0, %w1, #" __stringify(PSR_I_BIT) "\n"
+                       "nop",
+                       "cmp    %w1, #" __stringify(GIC_PRIO_IRQOFF) "\n"
+                       "cset   %w0, ls",
+                       ARM64_HAS_IRQ_PRIO_MASKING)
+               : "=&r" (res)
+               : "r" ((int) flags)
+               : "memory");
+
+       return res;
 }
 #endif
 #endif