ENTRY(chacha20_block_xor_ssse3)
        # %rdi: Input state matrix, s
-       # %rsi: 1 data block output, o
-       # %rdx: 1 data block input, i
+       # %rsi: up to 1 data block output, o
+       # %rdx: up to 1 data block input, i
+       # %rcx: input/output length in bytes
 
        # This function encrypts one ChaCha20 block by loading the state matrix
        # in four SSE registers. It performs matrix operation on four words in
-       # parallel, but requireds shuffling to rearrange the words after each
+       # parallel, but requires shuffling to rearrange the words after each
        # round. 8/16-bit word rotation is done with the slightly better
        # performing SSSE3 byte shuffling, 7/12-bit word rotation uses
        # traditional shift+OR.
        movdqa          ROT8(%rip),%xmm4
        movdqa          ROT16(%rip),%xmm5
 
-       mov     $10,%ecx
+       mov             %rcx,%rax               # stash byte count; ecx is reused as round counter
+       mov             $10,%ecx                # 10 double rounds
 
.Ldoubleround:
 
        jnz             .Ldoubleround
 
        # o0 = i0 ^ (x0 + s0)
-       movdqu          0x00(%rdx),%xmm4
        paddd           %xmm8,%xmm0
+       cmp             $0x10,%rax              # fewer than 16 bytes remaining?
+       jl              .Lxorpart               # keystream word already in xmm0
+       movdqu          0x00(%rdx),%xmm4
        pxor            %xmm4,%xmm0
        movdqu          %xmm0,0x00(%rsi)
        # o1 = i1 ^ (x1 + s1)
-       movdqu          0x10(%rdx),%xmm5
        paddd           %xmm9,%xmm1
-       pxor            %xmm5,%xmm1
-       movdqu          %xmm1,0x10(%rsi)
+       movdqa          %xmm1,%xmm0             # keep keystream word in xmm0 for .Lxorpart
+       cmp             $0x20,%rax
+       jl              .Lxorpart
+       movdqu          0x10(%rdx),%xmm0
+       pxor            %xmm1,%xmm0
+       movdqu          %xmm0,0x10(%rsi)
        # o2 = i2 ^ (x2 + s2)
-       movdqu          0x20(%rdx),%xmm6
        paddd           %xmm10,%xmm2
-       pxor            %xmm6,%xmm2
-       movdqu          %xmm2,0x20(%rsi)
+       movdqa          %xmm2,%xmm0             # keep keystream word in xmm0 for .Lxorpart
+       cmp             $0x30,%rax
+       jl              .Lxorpart
+       movdqu          0x20(%rdx),%xmm0
+       pxor            %xmm2,%xmm0
+       movdqu          %xmm0,0x20(%rsi)
        # o3 = i3 ^ (x3 + s3)
-       movdqu          0x30(%rdx),%xmm7
        paddd           %xmm11,%xmm3
-       pxor            %xmm7,%xmm3
-       movdqu          %xmm3,0x30(%rsi)
-
+       movdqa          %xmm3,%xmm0             # keep keystream word in xmm0 for .Lxorpart
+       cmp             $0x40,%rax
+       jl              .Lxorpart
+       movdqu          0x30(%rdx),%xmm0
+       pxor            %xmm3,%xmm0
+       movdqu          %xmm0,0x30(%rsi)
+
+.Ldone:
        ret
+
+.Lxorpart:
+       # xor remaining bytes from partial register into output
+       mov             %rax,%r9
+       and             $0x0f,%r9               # r9 = len & 15 = bytes in the partial word
+       jz              .Ldone                  # nothing left to xor
+       and             $~0x0f,%rax             # rax = 16-byte-aligned offset of the partial word
+
+       mov             %rsi,%r11               # save output pointer; rsi is clobbered by rep movsb
+
+       lea             8(%rsp),%r10            # r10 = rsp + 8, used to restore rsp below
+       sub             $0x10,%rsp
+       and             $~31,%rsp               # carve out an aligned 16-byte scratch buffer
+
+       lea             (%rdx,%rax),%rsi
+       mov             %rsp,%rdi
+       mov             %r9,%rcx
+       rep movsb                               # copy r9 trailing input bytes into scratch
+
+       pxor            0x00(%rsp),%xmm0        # xor keystream word with the partial input
+       movdqa          %xmm0,0x00(%rsp)
+
+       mov             %rsp,%rsi
+       lea             (%r11,%rax),%rdi
+       mov             %r9,%rcx
+       rep movsb                               # copy r9 result bytes out to the destination
+
+       lea             -8(%r10),%rsp           # restore original rsp
+       jmp             .Ldone
+
 ENDPROC(chacha20_block_xor_ssse3)
 
 ENTRY(chacha20_4block_xor_ssse3)
 
 
 #define CHACHA20_STATE_ALIGN 16
 
-asmlinkage void chacha20_block_xor_ssse3(u32 *state, u8 *dst, const u8 *src);
+asmlinkage void chacha20_block_xor_ssse3(u32 *state, u8 *dst, const u8 *src,
+                                        unsigned int len);
 asmlinkage void chacha20_4block_xor_ssse3(u32 *state, u8 *dst, const u8 *src);
 #ifdef CONFIG_AS_AVX2
 asmlinkage void chacha20_8block_xor_avx2(u32 *state, u8 *dst, const u8 *src);
 static void chacha20_dosimd(u32 *state, u8 *dst, const u8 *src,
                            unsigned int bytes)
 {
-       u8 buf[CHACHA20_BLOCK_SIZE];
-
 #ifdef CONFIG_AS_AVX2
        if (chacha20_use_avx2) {
                while (bytes >= CHACHA20_BLOCK_SIZE * 8) {
                state[12] += 4;
        }
        while (bytes >= CHACHA20_BLOCK_SIZE) {
-               chacha20_block_xor_ssse3(state, dst, src);
+               chacha20_block_xor_ssse3(state, dst, src, bytes); /* bytes >= one full block here */
                bytes -= CHACHA20_BLOCK_SIZE;
                src += CHACHA20_BLOCK_SIZE;
                dst += CHACHA20_BLOCK_SIZE;
                state[12]++;
        }
        if (bytes) {
-               memcpy(buf, src, bytes);
-               chacha20_block_xor_ssse3(state, buf, buf);
-               memcpy(dst, buf, bytes);
+               chacha20_block_xor_ssse3(state, dst, src, bytes); /* asm xors just the tail; no bounce buffer needed */
        }
 }