aese    \vb\().16b, v4.16b
        .endm
 
-       /*
-        * void ce_aes_ccm_final(u8 mac[], u8 const ctr[], u8 const rk[],
-        *                       u32 rounds);
-        */
-SYM_FUNC_START(ce_aes_ccm_final)
-       ld1     {v0.16b}, [x0]                  /* load mac */
-       ld1     {v1.16b}, [x1]                  /* load 1st ctriv */
-
-       aes_encrypt     v0, v1, w3
-
-       /* final round key cancels out */
-       eor     v0.16b, v0.16b, v1.16b          /* en-/decrypt the mac */
-       st1     {v0.16b}, [x0]                  /* store result */
-       ret
-SYM_FUNC_END(ce_aes_ccm_final)
-
        .macro  aes_ccm_do_crypt,enc
        load_round_keys x3, w4, x10
 
-       cbz     x2, 5f
-       ldr     x8, [x6, #8]                    /* load lower ctr */
        ld1     {v0.16b}, [x5]                  /* load mac */
+       cbz     x2, ce_aes_ccm_final            /* x2 == 0: no data, go finalize the mac in v0 */
+       ldr     x8, [x6, #8]                    /* load lower ctr */
 CPU_LE(        rev     x8, x8                  )       /* keep swabbed ctr in reg */
 0:     /* outer loop */
        ld1     {v1.8b}, [x6]                   /* load upper ctr */
        st1     {v6.16b}, [x0], #16             /* write output block */
        bne     0b
 CPU_LE(        rev     x8, x8                  )
-       st1     {v0.16b}, [x5]                  /* store mac */
        str     x8, [x6, #8]                    /* store lsb end of ctr (BE) */
-5:     ret
+       cbnz    x7, ce_aes_ccm_final            /* x7 != NULL: en-/decrypt the mac before storing it */
+       st1     {v0.16b}, [x5]                  /* x7 == NULL: store intermediate mac as is */
+       ret
        .endm
 
 SYM_FUNC_START_LOCAL(ce_aes_ccm_crypt_tail)
        tbl     v2.16b, {v2.16b}, v9.16b        /* copy plaintext to start of v2 */
        eor     v0.16b, v0.16b, v2.16b          /* fold plaintext into mac */
 
-       st1     {v0.16b}, [x5]                  /* store mac */
        st1     {v7.16b}, [x0]                  /* store output block */
+       cbz     x7, 0f                          /* x7 == NULL: skip finalization, just store the mac */
+
+SYM_INNER_LABEL(ce_aes_ccm_final, SYM_L_LOCAL)
+       ld1     {v1.16b}, [x7]                  /* load 1st ctriv (block pointed to by x7) */
+
+       aes_encrypt     v0, v1, w4
+
+       /* final round key cancels out */
+       eor     v0.16b, v0.16b, v1.16b          /* en-/decrypt the mac */
+0:     st1     {v0.16b}, [x5]                  /* store mac; the 0f branch skips the encryption above */
        ret
 SYM_FUNC_END(ce_aes_ccm_crypt_tail)
 
        /*
         * void ce_aes_ccm_encrypt(u8 out[], u8 const in[], u32 cbytes,
         *                         u8 const rk[], u32 rounds, u8 mac[],
-        *                         u8 ctr[]);
+        *                         u8 ctr[], u8 const final_iv[]);
         * void ce_aes_ccm_decrypt(u8 out[], u8 const in[], u32 cbytes,
         *                         u8 const rk[], u32 rounds, u8 mac[],
-        *                         u8 ctr[]);
+        *                         u8 ctr[], u8 const final_iv[]);
         */
 SYM_FUNC_START(ce_aes_ccm_encrypt)
        movi    v22.16b, #255
 
 
 asmlinkage void ce_aes_ccm_encrypt(u8 out[], u8 const in[], u32 cbytes,
                                   u32 const rk[], u32 rounds, u8 mac[],
-                                  u8 ctr[]);
+                                  u8 ctr[], u8 const final_iv[]);
 
 asmlinkage void ce_aes_ccm_decrypt(u8 out[], u8 const in[], u32 cbytes,
                                   u32 const rk[], u32 rounds, u8 mac[],
-                                  u8 ctr[]);
-
-asmlinkage void ce_aes_ccm_final(u8 mac[], u8 const ctr[], u32 const rk[],
-                                u32 rounds);
+                                  u8 ctr[], u8 const final_iv[]);
 
 static int ccm_setkey(struct crypto_aead *tfm, const u8 *in_key,
                      unsigned int key_len)
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                u8 buf[AES_BLOCK_SIZE];
+               u8 *final_iv = NULL;
 
-               if (walk.nbytes == walk.total)
+               if (walk.nbytes == walk.total) {
                        tail = 0;
+                       final_iv = orig_iv;
+               }
 
                if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
                        src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
 
                ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
                                   ctx->key_enc, num_rounds(ctx),
-                                  mac, walk.iv);
+                                  mac, walk.iv, final_iv);
 
                if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
-               if (walk.nbytes == walk.total)
-                       ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
-
                if (walk.nbytes) {
                        err = skcipher_walk_done(&walk, tail);
                }
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;
                u8 buf[AES_BLOCK_SIZE];
+               u8 *final_iv = NULL;
 
-               if (walk.nbytes == walk.total)
+               if (walk.nbytes == walk.total) {
                        tail = 0;
+                       final_iv = orig_iv;
+               }
 
                if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
                        src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
 
                ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
                                   ctx->key_enc, num_rounds(ctx),
-                                  mac, walk.iv);
+                                  mac, walk.iv, final_iv);
 
                if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
                        memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
-               if (walk.nbytes == walk.total)
-                       ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
-
                if (walk.nbytes) {
                        err = skcipher_walk_done(&walk, tail);
                }