PUD_TYPE_TABLE)
 #endif
 
+extern pgd_t init_pg_dir[PTRS_PER_PGD];
+extern pgd_t init_pg_end[];
+extern pgd_t swapper_pg_dir[PTRS_PER_PGD];
+extern pgd_t idmap_pg_dir[PTRS_PER_PGD];
+extern pgd_t tramp_pg_dir[PTRS_PER_PGD];
+
+extern void set_swapper_pgd(pgd_t *pgdp, pgd_t pgd);
+
+static inline bool in_swapper_pgdir(void *addr)
+{
+       return ((unsigned long)addr & PAGE_MASK) ==
+               ((unsigned long)swapper_pg_dir & PAGE_MASK);
+}
+
 static inline void set_pmd(pmd_t *pmdp, pmd_t pmd)
 {
+       if (__is_defined(__PAGETABLE_PMD_FOLDED) && in_swapper_pgdir(pmdp)) {
+               set_swapper_pgd((pgd_t *)pmdp, __pgd(pmd_val(pmd)));
+               return;
+       }
+
        WRITE_ONCE(*pmdp, pmd);
 
        if (pmd_valid(pmd))
 
 static inline void set_pud(pud_t *pudp, pud_t pud)
 {
+       if (__is_defined(__PAGETABLE_PUD_FOLDED) && in_swapper_pgdir(pudp)) {
+               set_swapper_pgd((pgd_t *)pudp, __pgd(pud_val(pud)));
+               return;
+       }
+
        WRITE_ONCE(*pudp, pud);
 
        if (pud_valid(pud))
 
 static inline void set_pgd(pgd_t *pgdp, pgd_t pgd)
 {
+       if (in_swapper_pgdir(pgdp)) {
+               set_swapper_pgd(pgdp, pgd);
+               return;
+       }
+
        WRITE_ONCE(*pgdp, pgd);
        dsb(ishst);
 }
 }
 #endif
 
-extern pgd_t init_pg_dir[PTRS_PER_PGD];
-extern pgd_t init_pg_end[];
-extern pgd_t swapper_pg_dir[PTRS_PER_PGD];
-extern pgd_t idmap_pg_dir[PTRS_PER_PGD];
-extern pgd_t tramp_pg_dir[PTRS_PER_PGD];
-
 /*
  * Encode and decode a swap entry:
  *     bits 0-1:       present (must be zero)
 
 static pmd_t bm_pmd[PTRS_PER_PMD] __page_aligned_bss __maybe_unused;
 static pud_t bm_pud[PTRS_PER_PUD] __page_aligned_bss __maybe_unused;
 
+static DEFINE_SPINLOCK(swapper_pgdir_lock);
+
+void set_swapper_pgd(pgd_t *pgdp, pgd_t pgd)
+{
+       pgd_t *fixmap_pgdp;
+
+       spin_lock(&swapper_pgdir_lock);
+       fixmap_pgdp = pgd_set_fixmap(__pa(pgdp));
+       WRITE_ONCE(*fixmap_pgdp, pgd);
+       /*
+        * We need dsb(ishst) here to ensure the page-table-walker sees
+        * our new entry before set_p?d() returns. The fixmap's
+        * flush_tlb_kernel_range() via clear_fixmap() does this for us.
+        */
+       pgd_clear_fixmap();
+       spin_unlock(&swapper_pgdir_lock);
+}
+
 pgprot_t phys_mem_access_prot(struct file *file, unsigned long pfn,
                              unsigned long size, pgprot_t vma_prot)
 {
  */
 void __init paging_init(void)
 {
-       map_kernel(swapper_pg_dir);
-       map_mem(swapper_pg_dir);
+       pgd_t *pgdp = pgd_set_fixmap(__pa_symbol(swapper_pg_dir));
+
+       map_kernel(pgdp);
+       map_mem(pgdp);
+
+       pgd_clear_fixmap();
 
        cpu_replace_ttbr1(lm_alias(swapper_pg_dir));
        init_mm.pgd = swapper_pg_dir;