 #include <linux/processor.h>
 #include <linux/preempt.h>
 #include <linux/string.h>
+#include <linux/sched.h>
 #include <asm/sigcontext.h>
 #include <asm/fpu-types.h>
 #include <asm/fpu-insn.h>
 #define KERNEL_VXR             (KERNEL_VXR_LOW    | KERNEL_VXR_HIGH)
 #define KERNEL_FPR             (KERNEL_FPC        | KERNEL_VXR_LOW)
 
-/*
- * Note the functions below must be called with preemption disabled.
- * Do not enable preemption before calling __kernel_fpu_end() to prevent
- * an corruption of an existing kernel FPU state.
- *
- * Prefer using the kernel_fpu_begin()/kernel_fpu_end() pair of functions.
- */
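+/*
+ * Note the functions below no longer require preemption to be disabled:
+ * the register contents in use by a kernel FPU section are tracked per
+ * thread via kfpu_flags and saved/restored across context switches by
+ * __switch_to().
+ *
+ * Prefer using the kernel_fpu_begin()/kernel_fpu_end() pair of functions.
+ */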
 void __kernel_fpu_begin(struct kernel_fpu *state, int flags);
 void __kernel_fpu_end(struct kernel_fpu *state, int flags);
 
 static inline void kernel_fpu_begin(struct kernel_fpu *state, int flags)
 {
-       preempt_disable();
-       state->mask = S390_lowcore.fpu_flags;
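+       /* Snapshot the in-use mask; kernel_fpu_end() restores it */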
+       state->mask = READ_ONCE(current->thread.kfpu_flags);
        if (!test_thread_flag(TIF_FPU)) {
                /* Save user space FPU state and register contents */
                save_user_fpu_regs();
                /* Save FPU/vector register in-use by the kernel */
                __kernel_fpu_begin(state, flags);
        }
-       S390_lowcore.fpu_flags |= flags;
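+       /* Mark the requested register sets as in-use by this thread */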
+       __atomic_or(flags, &current->thread.kfpu_flags);
 }
 
 static inline void kernel_fpu_end(struct kernel_fpu *state, int flags)
 {
-       S390_lowcore.fpu_flags = state->mask;
+       WRITE_ONCE(current->thread.kfpu_flags, state->mask);
        if (state->mask & flags) {
                /* Restore FPU/vector register in-use by the kernel */
                __kernel_fpu_end(state, flags);
        }
-       preempt_enable();
+}
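+
+/*
+ * Usage sketch (hypothetical caller; the flags depend on which
+ * registers the section actually uses):
+ *
+ *     struct kernel_fpu state;
+ *
+ *     kernel_fpu_begin(&state, KERNEL_VXR);
+ *     ... code that uses vector registers ...
+ *     kernel_fpu_end(&state, KERNEL_VXR);
+ */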
+
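+/*
+ * Save the FPU/vector registers the previous task has in use within a
+ * kernel FPU section; called from __switch_to().
+ */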
+static inline void save_kernel_fpu_regs(struct thread_struct *thread)
+{
+       struct fpu *state = &thread->kfpu;
+
+       if (!thread->kfpu_flags)
+               return;
+       fpu_stfpc(&state->fpc);
+       if (likely(cpu_has_vx()))
+               save_vx_regs(state->vxrs);
+       else
+               save_fp_regs(state->fprs);
+}
+
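+/*
+ * Restore the FPU/vector registers the next task had in use within a
+ * kernel FPU section; called from __switch_to().
+ */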
+static inline void restore_kernel_fpu_regs(struct thread_struct *thread)
+{
+       struct fpu *state = &thread->kfpu;
+
+       if (!thread->kfpu_flags)
+               return;
+       fpu_lfpc(&state->fpc);
+       if (likely(cpu_has_vx()))
+               load_vx_regs(state->vxrs);
+       else
+               load_fp_regs(state->fprs);
 }
 
 static inline void convert_vx_to_fp(freg_t *fprs, __vector128 *vxrs)
 
        unsigned int gmap_write_flag;           /* gmap fault write indication */
        unsigned int gmap_int_code;             /* int code of last gmap fault */
        unsigned int gmap_pfault;               /* signal of a pending guest pfault */
+       int kfpu_flags;                         /* kernel fpu flags */
 
        /* Per-thread information related to debugging */
        struct per_regs per_user;               /* User specified PER registers */
        struct gs_cb *gs_bc_cb;                 /* Broadcast guarded storage cb */
        struct pgm_tdb trap_tdb;                /* Transaction abort diagnose block */
        struct fpu ufpu;                        /* User FP and VX register save area */
+       struct fpu kfpu;                        /* Kernel FP and VX register save area */
 };
 
 /* Flag to disable transactions. */
 
        *dst = *src;
        dst->thread.ufpu.regs = dst->thread.ufpu.fprs;
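+       /* A new task starts with no kernel FPU section in progress */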
+       dst->thread.kfpu_flags = 0;
 
        /*
         * Don't transfer over the runtime instrumentation or the guarded
 struct task_struct *__switch_to(struct task_struct *prev, struct task_struct *next)
 {
        save_user_fpu_regs();
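+       /* Save kernel FPU state of the task being switched out */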
+       save_kernel_fpu_regs(&prev->thread);
        save_access_regs(&prev->thread.acrs[0]);
        save_ri_cb(prev->thread.ri_cb);
        save_gs_cb(prev->thread.gs_cb);
        update_cr_regs(next);
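+       /* Restore kernel FPU state of the task being switched in */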
+       restore_kernel_fpu_regs(&next->thread);
        restore_access_regs(&next->thread.acrs[0]);
        restore_ri_cb(next->thread.ri_cb, prev->thread.ri_cb);
        restore_gs_cb(next->thread.gs_cb);