{
        u32 *state, state_buf[16 + 2] __aligned(8);
        struct skcipher_walk walk;
+       int next_yield = 4096; /* bytes until next FPU yield */
        int err;
 
        BUILD_BUG_ON(CHACHA_STATE_ALIGN != 16);
        while (walk.nbytes > 0) {
                unsigned int nbytes = walk.nbytes;
 
-               if (nbytes < walk.total)
+               if (nbytes < walk.total) {
                        nbytes = round_down(nbytes, walk.stride);
+                       next_yield -= nbytes;
+               }
 
                chacha_dosimd(state, walk.dst.virt.addr, walk.src.virt.addr,
                              nbytes, ctx->nrounds);
 
+               if (next_yield <= 0) {
+                       /* temporarily allow preemption */
+                       kernel_fpu_end();
+                       kernel_fpu_begin();
+                       next_yield = 4096;
+               }
+
                err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
        }