#include <linux/sched/clock.h>
 #include <linux/uuid.h>
 #include <linux/ras.h>
+#include <linux/task_work.h>
 
 #include <acpi/actbl1.h>
 #include <acpi/ghes.h>
                ghes_ack_error(ghes->generic_v2);
 }
 
-static void ghes_handle_memory_failure(struct acpi_hest_generic_data *gdata, int sev)
+/*
+ * Called as task_work before returning to user-space.
+ * Ensure any queued work has been done before we return to the context that
+ * triggered the notification.
+ */
+static void ghes_kick_task_work(struct callback_head *head)
+{
+       struct acpi_hest_generic_status *estatus;
+       struct ghes_estatus_node *estatus_node;
+       u32 node_len;
+
+       estatus_node = container_of(head, struct ghes_estatus_node, task_work);
+       if (IS_ENABLED(CONFIG_ACPI_APEI_MEMORY_FAILURE))
+               memory_failure_queue_kick(estatus_node->task_work_cpu);
+
+       estatus = GHES_ESTATUS_FROM_NODE(estatus_node);
+       node_len = GHES_ESTATUS_NODE_LEN(cper_estatus_len(estatus));
+       gen_pool_free(ghes_estatus_pool, (unsigned long)estatus_node, node_len);
+}
+
+static bool ghes_handle_memory_failure(struct acpi_hest_generic_data *gdata,
+                                      int sev)
 {
-#ifdef CONFIG_ACPI_APEI_MEMORY_FAILURE
        unsigned long pfn;
        int flags = -1;
        int sec_sev = ghes_severity(gdata->error_severity);
        struct cper_sec_mem_err *mem_err = acpi_hest_get_payload(gdata);
 
+       if (!IS_ENABLED(CONFIG_ACPI_APEI_MEMORY_FAILURE))
+               return false;
+
        if (!(mem_err->validation_bits & CPER_MEM_VALID_PA))
-               return;
+               return false;
 
        pfn = mem_err->physical_addr >> PAGE_SHIFT;
        if (!pfn_valid(pfn)) {
                pr_warn_ratelimited(FW_WARN GHES_PFX
                "Invalid address in generic error data: %#llx\n",
                mem_err->physical_addr);
-               return;
+               return false;
        }
 
        /* iff following two events can be handled properly by now */
        if (sev == GHES_SEV_RECOVERABLE && sec_sev == GHES_SEV_RECOVERABLE)
                flags = 0;
 
-       if (flags != -1)
+       if (flags != -1) {
                memory_failure_queue(pfn, flags);
-#endif
+               return true;
+       }
+
+       return false;
 }
 
 /*
 #endif
 }
 
-static void ghes_do_proc(struct ghes *ghes,
+static bool ghes_do_proc(struct ghes *ghes,
                         const struct acpi_hest_generic_status *estatus)
 {
        int sev, sec_sev;
        guid_t *sec_type;
        const guid_t *fru_id = &guid_null;
        char *fru_text = "";
+       bool queued = false;
 
        sev = ghes_severity(estatus->error_severity);
        apei_estatus_for_each_section(estatus, gdata) {
                        ghes_edac_report_mem_error(sev, mem_err);
 
                        arch_apei_report_mem_error(sev, mem_err);
-                       ghes_handle_memory_failure(gdata, sev);
+                       queued = ghes_handle_memory_failure(gdata, sev);
                }
                else if (guid_equal(sec_type, &CPER_SEC_PCIE)) {
                        ghes_handle_aer(gdata);
                                               gdata->error_data_length);
                }
        }
+
+       return queued;
 }
 
 static void __ghes_print_estatus(const char *pfx,
        struct ghes_estatus_node *estatus_node;
        struct acpi_hest_generic *generic;
        struct acpi_hest_generic_status *estatus;
+       bool task_work_pending;
        u32 len, node_len;
+       int ret;
 
        llnode = llist_del_all(&ghes_estatus_llist);
        /*
                estatus = GHES_ESTATUS_FROM_NODE(estatus_node);
                len = cper_estatus_len(estatus);
                node_len = GHES_ESTATUS_NODE_LEN(len);
-               ghes_do_proc(estatus_node->ghes, estatus);
+               task_work_pending = ghes_do_proc(estatus_node->ghes, estatus);
                if (!ghes_estatus_cached(estatus)) {
                        generic = estatus_node->generic;
                        if (ghes_print_estatus(NULL, generic, estatus))
                                ghes_estatus_cache_add(generic, estatus);
                }
-               gen_pool_free(ghes_estatus_pool, (unsigned long)estatus_node,
-                             node_len);
+
+               if (task_work_pending && current->mm != &init_mm) {
+                       estatus_node->task_work.func = ghes_kick_task_work;
+                       estatus_node->task_work_cpu = smp_processor_id();
+                       ret = task_work_add(current, &estatus_node->task_work,
+                                           true);
+                       if (ret)
+                               estatus_node->task_work.func = NULL;
+               }
+
+               if (!estatus_node->task_work.func)
+                       gen_pool_free(ghes_estatus_pool,
+                                     (unsigned long)estatus_node, node_len);
+
                llnode = next;
        }
 }
 
        estatus_node->ghes = ghes;
        estatus_node->generic = ghes->generic;
+       estatus_node->task_work.func = NULL;
        estatus = GHES_ESTATUS_FROM_NODE(estatus_node);
 
        if (__ghes_read_estatus(estatus, buf_paddr, fixmap_idx, len)) {