start -= unaccepted->phys_base;
        end -= unaccepted->phys_base;
 
+       /*
+        * load_unaligned_zeropad() can lead to unwanted loads across page
+        * boundaries. The unwanted loads are typically harmless. But, they
+        * might be made to totally unrelated or even unmapped memory.
+        * load_unaligned_zeropad() relies on exception fixup (#PF, #GP and now
+        * #VE) to recover from these unwanted loads.
+        *
+        * But, this approach does not work for unaccepted memory. For TDX, a
+        * load from unaccepted memory will not lead to a recoverable exception
+        * within the guest. The guest will exit to the VMM where the only
+        * recourse is to terminate the guest.
+        *
+        * There are two parts to fix this issue and comprehensively avoid
+        * access to unaccepted memory. Together these ensure that an extra
+        * "guard" page is accepted in addition to the memory that needs to be
+        * used:
+        *
+        * 1. Implicitly extend the range_contains_unaccepted_memory(start, end)
+        *    checks up to end+unit_size if 'end' is aligned on a unit_size
+        *    boundary.
+        *
+        * 2. Implicitly extend accept_memory(start, end) to end+unit_size if
+        *    'end' is aligned on a unit_size boundary. (immediately following
+        *    this comment)
+        */
+       if (!(end % unit_size))
+               end += unit_size;
+
        /* Make sure not to overrun the bitmap */
        if (end > unaccepted->size * unit_size * BITS_PER_BYTE)
                end = unaccepted->size * unit_size * BITS_PER_BYTE;
        start -= unaccepted->phys_base;
        end -= unaccepted->phys_base;
 
+       /*
+        * Also consider the unaccepted state of the *next* page. See fix #1 in
+        * the comment on load_unaligned_zeropad() in accept_memory().
+        */
+       if (!(end % unit_size))
+               end += unit_size;
+
        /* Make sure not to overrun the bitmap */
        if (end > unaccepted->size * unit_size * BITS_PER_BYTE)
                end = unaccepted->size * unit_size * BITS_PER_BYTE;