static bool chacha20_use_avx2;
 #endif
 
+/*
+ * Return how many CHACHA20_BLOCK_SIZE-sized blocks are covered by @len bytes
+ * when at most @maxblocks blocks can be handled in one go.
+ *
+ * @len is first capped at @maxblocks whole blocks; a trailing partial block
+ * then counts as one full block. chacha20_dosimd() adds the result to
+ * state[12] after a multi-block call on a tail that is not a whole multiple
+ * of the routine's width.
+ */
+static unsigned int chacha20_advance(unsigned int len, unsigned int maxblocks)
+{
+       unsigned int limit = maxblocks * CHACHA20_BLOCK_SIZE;
+       unsigned int used = min(len, limit);
+
+       /* Round up so that a partially used final block is still counted. */
+       return round_up(used, CHACHA20_BLOCK_SIZE) / CHACHA20_BLOCK_SIZE;
+}
+
 static void chacha20_dosimd(u32 *state, u8 *dst, const u8 *src,
                            unsigned int bytes)
 {
                        dst += CHACHA20_BLOCK_SIZE * 8;
                        state[12] += 8;
                }
+               if (bytes > CHACHA20_BLOCK_SIZE * 4) {
+                       chacha20_8block_xor_avx2(state, dst, src, bytes);
+                       state[12] += chacha20_advance(bytes, 8);
+                       return;
+               }
        }
 #endif
        while (bytes >= CHACHA20_BLOCK_SIZE * 4) {
                dst += CHACHA20_BLOCK_SIZE * 4;
                state[12] += 4;
        }
-       while (bytes >= CHACHA20_BLOCK_SIZE) {
-               chacha20_block_xor_ssse3(state, dst, src, bytes);
-               bytes -= CHACHA20_BLOCK_SIZE;
-               src += CHACHA20_BLOCK_SIZE;
-               dst += CHACHA20_BLOCK_SIZE;
-               state[12]++;
+       if (bytes > CHACHA20_BLOCK_SIZE) {
+               chacha20_4block_xor_ssse3(state, dst, src, bytes);
+               state[12] += chacha20_advance(bytes, 4);
+               return;
        }
        if (bytes) {
                chacha20_block_xor_ssse3(state, dst, src, bytes);
+               state[12]++;
        }
 }
 
 
        kernel_fpu_begin();
 
-       while (walk.nbytes >= CHACHA20_BLOCK_SIZE) {
-               chacha20_dosimd(state, walk.dst.virt.addr, walk.src.virt.addr,
-                               rounddown(walk.nbytes, CHACHA20_BLOCK_SIZE));
-               err = skcipher_walk_done(&walk,
-                                        walk.nbytes % CHACHA20_BLOCK_SIZE);
-       }
+       while (walk.nbytes > 0) {
+               unsigned int nbytes = walk.nbytes;
+
+               if (nbytes < walk.total)
+                       nbytes = round_down(nbytes, walk.stride);
 
-       if (walk.nbytes) {
                chacha20_dosimd(state, walk.dst.virt.addr, walk.src.virt.addr,
-                               walk.nbytes);
-               err = skcipher_walk_done(&walk, 0);
+                               nbytes);
+
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }
 
        kernel_fpu_end();