* along with this program.  If not, see <http://www.gnu.org/licenses/>.
  */
 
+#include <linux/irqflags.h>
+
 #include <asm/kvm_hyp.h>
 #include <asm/kvm_mmu.h>
 #include <asm/tlbflush.h>
 
-static void __hyp_text __tlb_switch_to_guest_vhe(struct kvm *kvm)
+static void __hyp_text __tlb_switch_to_guest_vhe(struct kvm *kvm,
+                                                unsigned long *flags)
 {
        u64 val;
 
+       local_irq_save(*flags);
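+       /*
+        * Interrupts stay masked until __tlb_switch_to_host_vhe()
+        * restores them, so that nothing else can run on this CPU
+        * while HCR_EL2.TGE is clear and the guest's VTTBR is loaded.
+        */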
+
        /*
         * With VHE enabled, we have HCR_EL2.{E2H,TGE} = {1,1}, and
         * most TLB operations target EL2/EL0. In order to affect the
        isb();
 }
 
-static void __hyp_text __tlb_switch_to_guest_nvhe(struct kvm *kvm)
+static void __hyp_text __tlb_switch_to_guest_nvhe(struct kvm *kvm,
+                                                 unsigned long *flags)
 {
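+       /*
+        * flags is unused here; it only keeps the prototype in line
+        * with the VHE variant so both can sit behind the same
+        * runtime-selected helper. The non-VHE path leaves the
+        * interrupt state untouched.
+        */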
        __load_guest_stage2(kvm);
        isb();
                            __tlb_switch_to_guest_vhe,
                            ARM64_HAS_VIRT_HOST_EXTN);
 
-static void __hyp_text __tlb_switch_to_host_vhe(struct kvm *kvm)
+static void __hyp_text __tlb_switch_to_host_vhe(struct kvm *kvm,
+                                               unsigned long flags)
 {
        /*
         * We're done with the TLB operation, let's restore the host's
         */
        write_sysreg(0, vttbr_el2);
        write_sysreg(HCR_HOST_VHE_FLAGS, hcr_el2);
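+       /*
+        * Make sure the restore of HCR_EL2/VTTBR_EL2 above has taken
+        * effect before interrupts can be taken again.
+        */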
+       isb();
+       local_irq_restore(flags);
 }
 
-static void __hyp_text __tlb_switch_to_host_nvhe(struct kvm *kvm)
+static void __hyp_text __tlb_switch_to_host_nvhe(struct kvm *kvm,
+                                                unsigned long flags)
 {
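+       /* Nothing to restore: the non-VHE path never saved flags. */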
        write_sysreg(0, vttbr_el2);
 }
 
 void __hyp_text __kvm_tlb_flush_vmid_ipa(struct kvm *kvm, phys_addr_t ipa)
 {
+       unsigned long flags;
+
        dsb(ishst);
 
        /* Switch to requested VMID */
        kvm = kern_hyp_va(kvm);
-       __tlb_switch_to_guest()(kvm);
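+       /*
+        * On VHE this masks interrupts and saves the previous state
+        * in flags; __tlb_switch_to_host() restores it below.
+        */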
+       __tlb_switch_to_guest()(kvm, &flags);
 
        /*
         * We could do so much better if we had the VA as well.
        if (!has_vhe() && icache_is_vpipt())
                __flush_icache_all();
 
-       __tlb_switch_to_host()(kvm);
+       __tlb_switch_to_host()(kvm, flags);
 }
 
 void __hyp_text __kvm_tlb_flush_vmid(struct kvm *kvm)
 {
+       unsigned long flags;
+
        dsb(ishst);
 
        /* Switch to requested VMID */
        kvm = kern_hyp_va(kvm);
-       __tlb_switch_to_guest()(kvm);
+       __tlb_switch_to_guest()(kvm, &flags);
 
        __tlbi(vmalls12e1is);
        dsb(ish);
        isb();
 
-       __tlb_switch_to_host()(kvm);
+       __tlb_switch_to_host()(kvm, flags);
 }
 
 void __hyp_text __kvm_tlb_flush_local_vmid(struct kvm_vcpu *vcpu)
 {
        struct kvm *kvm = kern_hyp_va(kern_hyp_va(vcpu)->kvm);
+       unsigned long flags;
 
        /* Switch to requested VMID */
-       __tlb_switch_to_guest()(kvm);
+       __tlb_switch_to_guest()(kvm, &flags);
 
        __tlbi(vmalle1);
        dsb(nsh);
        isb();
 
-       __tlb_switch_to_host()(kvm);
+       __tlb_switch_to_host()(kvm, flags);
 }
 
 void __hyp_text __kvm_flush_vm_context(void)