 // SPDX-License-Identifier: GPL-2.0-only
 /*
- * aes-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
  *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
  */
 
 #include <asm/neon.h>
        struct crypto_aes_ctx *ctx = crypto_aead_ctx(aead);
        struct skcipher_walk walk;
        u8 __aligned(8) mac[AES_BLOCK_SIZE];
-       u8 buf[AES_BLOCK_SIZE];
+       u8 orig_iv[AES_BLOCK_SIZE];
        u32 len = req->cryptlen;
        int err;
 
                return err;
 
        /* preserve the original iv for the final round */
-       memcpy(buf, req->iv, AES_BLOCK_SIZE);
+       memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
        err = skcipher_walk_aead_encrypt(&walk, req, false);
        if (unlikely(err))
 
        do {
                u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+               const u8 *src = walk.src.virt.addr;
+               u8 *dst = walk.dst.virt.addr;
+               u8 buf[AES_BLOCK_SIZE];
 
                if (walk.nbytes == walk.total)
                        tail = 0;
 
-               ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  walk.nbytes - tail, ctx->key_enc,
-                                  num_rounds(ctx), mac, walk.iv);
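+               /* bounce short inputs via the end of a stack block so the
+                * asm core can safely operate on a whole AES block */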
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+                                          src, walk.nbytes);
+
+               ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
+                                  ctx->key_enc, num_rounds(ctx),
+                                  mac, walk.iv);
+
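+               /* copy the output of a short input back to the caller */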
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
                if (walk.nbytes == walk.total)
-                       ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+                       ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
 
                if (walk.nbytes) {
                        err = skcipher_walk_done(&walk, tail);
        unsigned int authsize = crypto_aead_authsize(aead);
        struct skcipher_walk walk;
        u8 __aligned(8) mac[AES_BLOCK_SIZE];
-       u8 buf[AES_BLOCK_SIZE];
+       u8 orig_iv[AES_BLOCK_SIZE];
        u32 len = req->cryptlen - authsize;
        int err;
 
                return err;
 
        /* preserve the original iv for the final round */
-       memcpy(buf, req->iv, AES_BLOCK_SIZE);
+       memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
 
        err = skcipher_walk_aead_decrypt(&walk, req, false);
        if (unlikely(err))
 
        do {
                u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+               const u8 *src = walk.src.virt.addr;
+               u8 *dst = walk.dst.virt.addr;
+               u8 buf[AES_BLOCK_SIZE];
 
                if (walk.nbytes == walk.total)
                        tail = 0;
 
-               ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                                  walk.nbytes - tail, ctx->key_enc,
-                                  num_rounds(ctx), mac, walk.iv);
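+               /* bounce short inputs via a stack block, as in encrypt */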
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+                                          src, walk.nbytes);
+
+               ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
+                                  ctx->key_enc, num_rounds(ctx),
+                                  mac, walk.iv);
+
+               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
                if (walk.nbytes == walk.total)
-                       ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+                       ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
 
                if (walk.nbytes) {
                        err = skcipher_walk_done(&walk, tail);
                return err;
 
        /* compare calculated auth tag with the stored one */
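+       /* orig_iv doubles as scratch space for the transmitted tag */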
-       scatterwalk_map_and_copy(buf, req->src,
+       scatterwalk_map_and_copy(orig_iv, req->src,
                                 req->assoclen + req->cryptlen - authsize,
                                 authsize, 0);
 
-       if (crypto_memneq(mac, buf, authsize))
+       if (crypto_memneq(mac, orig_iv, authsize))
                return -EBADMSG;
        return 0;
 }
 module_exit(aes_mod_exit);
 
 MODULE_DESCRIPTION("Synchronous AES in CCM mode using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
+MODULE_AUTHOR("Ard Biesheuvel <ardb@kernel.org>");
 MODULE_LICENSE("GPL v2");
 MODULE_ALIAS_CRYPTO("ccm(aes)");