#include "kvm_util.h"
 #include "processor.h"
 #include <linux/bitfield.h>
+#include <linux/sizes.h>
 
 #define DEFAULT_ARM64_GUEST_STACK_VADDR_MIN    0xac0000
 
        return (gva >> vm->page_shift) & mask;
 }
 
+static inline bool use_lpa2_pte_format(struct kvm_vm *vm)
+{
+       return (vm->page_size == SZ_4K || vm->page_size == SZ_16K) &&
+           (vm->pa_bits > 48 || vm->va_bits > 48);
+}
+
 static uint64_t addr_pte(struct kvm_vm *vm, uint64_t pa, uint64_t attrs)
 {
        uint64_t pte;
 
-       pte = pa & GENMASK(47, vm->page_shift);
-       if (vm->page_shift == 16)
-               pte |= FIELD_GET(GENMASK(51, 48), pa) << 12;
+       if (use_lpa2_pte_format(vm)) {
+               pte = pa & GENMASK(49, vm->page_shift);
+               pte |= FIELD_GET(GENMASK(51, 50), pa) << 8;
+               attrs &= ~GENMASK(9, 8);
+       } else {
+               pte = pa & GENMASK(47, vm->page_shift);
+               if (vm->page_shift == 16)
+                       pte |= FIELD_GET(GENMASK(51, 48), pa) << 12;
+       }
        pte |= attrs;
 
        return pte;
 {
        uint64_t pa;
 
-       pa = pte & GENMASK(47, vm->page_shift);
-       if (vm->page_shift == 16)
-               pa |= FIELD_GET(GENMASK(15, 12), pte) << 48;
+       if (use_lpa2_pte_format(vm)) {
+               pa = pte & GENMASK(49, vm->page_shift);
+               pa |= FIELD_GET(GENMASK(9, 8), pte) << 50;
+       } else {
+               pa = pte & GENMASK(47, vm->page_shift);
+               if (vm->page_shift == 16)
+                       pa |= FIELD_GET(GENMASK(15, 12), pte) << 48;
+       }
 
        return pa;
 }
 
        /* Configure base granule size */
        switch (vm->mode) {
-       case VM_MODE_P52V48_4K:
-               TEST_FAIL("AArch64 does not support 4K sized pages "
-                         "with 52-bit physical address ranges");
        case VM_MODE_PXXV48_4K:
                TEST_FAIL("AArch64 does not support 4K sized pages "
                          "with ANY-bit physical address ranges");
        case VM_MODE_P36V48_64K:
                tcr_el1 |= 1ul << 14; /* TG0 = 64KB */
                break;
+       case VM_MODE_P52V48_16K:
        case VM_MODE_P48V48_16K:
        case VM_MODE_P40V48_16K:
        case VM_MODE_P36V48_16K:
        case VM_MODE_P36V47_16K:
                tcr_el1 |= 2ul << 14; /* TG0 = 16KB */
                break;
+       case VM_MODE_P52V48_4K:
        case VM_MODE_P48V48_4K:
        case VM_MODE_P40V48_4K:
        case VM_MODE_P36V48_4K:
 
        /* Configure output size */
        switch (vm->mode) {
+       case VM_MODE_P52V48_4K:
+       case VM_MODE_P52V48_16K:
        case VM_MODE_P52V48_64K:
                tcr_el1 |= 6ul << 32; /* IPS = 52 bits */
                ttbr0_el1 |= FIELD_GET(GENMASK(51, 48), vm->pgd) << 2;
        /* TCR_EL1 |= IRGN0:WBWA | ORGN0:WBWA | SH0:Inner-Shareable */;
        tcr_el1 |= (1 << 8) | (1 << 10) | (3 << 12);
        tcr_el1 |= (64 - vm->va_bits) /* T0SZ */;
+       if (use_lpa2_pte_format(vm))
+               tcr_el1 |= (1ul << 59) /* DS */;
 
        vcpu_set_reg(vcpu, KVM_ARM64_SYS_REG(SYS_SCTLR_EL1), sctlr_el1);
        vcpu_set_reg(vcpu, KVM_ARM64_SYS_REG(SYS_TCR_EL1), tcr_el1);
 
 {
        static const char * const strings[] = {
                [VM_MODE_P52V48_4K]     = "PA-bits:52,  VA-bits:48,  4K pages",
+               [VM_MODE_P52V48_16K]    = "PA-bits:52,  VA-bits:48, 16K pages",
                [VM_MODE_P52V48_64K]    = "PA-bits:52,  VA-bits:48, 64K pages",
                [VM_MODE_P48V48_4K]     = "PA-bits:48,  VA-bits:48,  4K pages",
                [VM_MODE_P48V48_16K]    = "PA-bits:48,  VA-bits:48, 16K pages",
 
 const struct vm_guest_mode_params vm_guest_mode_params[] = {
        [VM_MODE_P52V48_4K]     = { 52, 48,  0x1000, 12 },
+       [VM_MODE_P52V48_16K]    = { 52, 48,  0x4000, 14 },
        [VM_MODE_P52V48_64K]    = { 52, 48, 0x10000, 16 },
        [VM_MODE_P48V48_4K]     = { 48, 48,  0x1000, 12 },
        [VM_MODE_P48V48_16K]    = { 48, 48,  0x4000, 14 },
        case VM_MODE_P36V48_64K:
                vm->pgtable_levels = 3;
                break;
+       case VM_MODE_P52V48_16K:
        case VM_MODE_P48V48_16K:
        case VM_MODE_P40V48_16K:
        case VM_MODE_P36V48_16K: