MODULE_LICENSE("GPL v2");
 
 /* defined in aes-modes.S */
-asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
-asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks);
 
-asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
-asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 iv[]);
 
-asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
+asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int blocks, u8 ctr[]);
 
-asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[],
-                               int rounds, int blocks, u8 const rk2[], u8 iv[],
+asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
+                               int rounds, int blocks, u32 const rk2[], u8 iv[],
                                int first);
-asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[],
-                               int rounds, int blocks, u8 const rk2[], u8 iv[],
+asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
+                               int rounds, int blocks, u32 const rk2[], u8 iv[],
                                int first);
 
 asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks);
+                               ctx->key_enc, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_dec, rounds, blocks);
+                               ctx->key_dec, rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv);
+                               ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_dec, rounds, blocks, walk.iv);
+                               ctx->key_dec, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
                kernel_neon_begin();
                aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key_enc, rounds, blocks, walk.iv);
+                               ctx->key_enc, rounds, blocks, walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
                blocks = -1;
 
                kernel_neon_begin();
-               aes_ctr_encrypt(tail, NULL, (u8 *)ctx->key_enc, rounds,
+               aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
                                blocks, walk.iv);
                kernel_neon_end();
                crypto_xor_cpy(tdst, tsrc, tail, nbytes);
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                kernel_neon_begin();
                aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key1.key_enc, rounds, blocks,
-                               (u8 *)ctx->key2.key_enc, walk.iv, first);
+                               ctx->key1.key_enc, rounds, blocks,
+                               ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
                kernel_neon_begin();
                aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-                               (u8 *)ctx->key1.key_dec, rounds, blocks,
-                               (u8 *)ctx->key2.key_enc, walk.iv, first);
+                               ctx->key1.key_dec, rounds, blocks,
+                               ctx->key2.key_enc, walk.iv, first);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
 {
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
        be128 *consts = (be128 *)ctx->consts;
-       u8 *rk = (u8 *)ctx->key.key_enc;
        int rounds = 6 + key_len / 4;
        int err;
 
 
        /* encrypt the zero vector */
        kernel_neon_begin();
-       aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, rk, rounds, 1);
+       aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
+                       rounds, 1);
        kernel_neon_end();
 
        cmac_gf128_mul_by_x(consts, consts);
        };
 
        struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
-       u8 *rk = (u8 *)ctx->key.key_enc;
        int rounds = 6 + key_len / 4;
        u8 key[AES_BLOCK_SIZE];
        int err;
                return err;
 
        kernel_neon_begin();
-       aes_ecb_encrypt(key, ks[0], rk, rounds, 1);
-       aes_ecb_encrypt(ctx->consts, ks[1], rk, rounds, 2);
+       aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
+       aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
        kernel_neon_end();
 
        return cbcmac_setkey(tfm, key, sizeof(key));