From 17437f311490d873a5157f65a84317d16270fd38 Mon Sep 17 00:00:00 2001
From: Jens Axboe <axboe@kernel.dk>
Date: Wed, 25 May 2022 09:13:39 -0600
Subject: [PATCH] io_uring: move SQPOLL related handling into its own file

Signed-off-by: Jens Axboe <axboe@kernel.dk>
---
 io_uring/Makefile   |   3 +-
 io_uring/io_uring.c | 467 +-------------------------------------------
 io_uring/io_uring.h |  34 ++++
 io_uring/sqpoll.c   | 426 ++++++++++++++++++++++++++++++++++++++++
 io_uring/sqpoll.h   |  29 +++
 5 files changed, 497 insertions(+), 462 deletions(-)
 create mode 100644 io_uring/sqpoll.c
 create mode 100644 io_uring/sqpoll.h

diff --git a/io_uring/Makefile b/io_uring/Makefile
index 6ae4e45a15dbf..c59a9ca74262a 100644
--- a/io_uring/Makefile
+++ b/io_uring/Makefile
@@ -5,5 +5,6 @@
 obj-$(CONFIG_IO_URING)		+= io_uring.o xattr.o nop.o fs.o splice.o \
 					sync.o advise.o filetable.o \
 					openclose.o uring_cmd.o epoll.o \
-					statx.o net.o msg_ring.o timeout.o
+					statx.o net.o msg_ring.o timeout.o \
+					sqpoll.o
 obj-$(CONFIG_IO_WQ)		+= io-wq.o
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 3fc59a22d54e8..17c555aa03bc9 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -92,6 +92,7 @@
 #include "io_uring_types.h"
 #include "io_uring.h"
 #include "refs.h"
+#include "sqpoll.h"
 
 #include "xattr.h"
 #include "nop.h"
@@ -109,7 +110,6 @@
 
 #define IORING_MAX_ENTRIES	32768
 #define IORING_MAX_CQ_ENTRIES	(2 * IORING_MAX_ENTRIES)
-#define IORING_SQPOLL_CAP_ENTRIES_VALUE 8
 
 /* only define max */
 #define IORING_MAX_FIXED_FILES	(1U << 20)
@@ -214,31 +214,6 @@ struct io_buffer {
 	__u16 bgid;
 };
 
-enum {
-	IO_SQ_THREAD_SHOULD_STOP = 0,
-	IO_SQ_THREAD_SHOULD_PARK,
-};
-
-struct io_sq_data {
-	refcount_t		refs;
-	atomic_t		park_pending;
-	struct mutex		lock;
-
-	/* ctx's that are using this sqd */
-	struct list_head	ctx_list;
-
-	struct task_struct	*thread;
-	struct wait_queue_head	wait;
-
-	unsigned		sq_thread_idle;
-	int			sq_cpu;
-	pid_t			task_pid;
-	pid_t			task_tgid;
-
-	unsigned long		state;
-	struct completion	exited;
-};
-
 #define IO_COMPL_BATCH			32
 #define IO_REQ_CACHE_SIZE		32
 #define IO_REQ_ALLOC_BATCH		8
@@ -402,7 +377,6 @@ static void io_uring_del_tctx_node(unsigned long index);
 static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 					 struct task_struct *task,
 					 bool cancel_all);
-static void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 
 static void io_dismantle_req(struct io_kiocb *req);
 static int __io_register_rsrc_update(struct io_ring_ctx *ctx, unsigned type,
@@ -1079,13 +1053,6 @@ static void __io_commit_cqring_flush(struct io_ring_ctx *ctx)
 		io_eventfd_signal(ctx);
 }
 
-static inline bool io_sqring_full(struct io_ring_ctx *ctx)
-{
-	struct io_rings *r = ctx->rings;
-
-	return READ_ONCE(r->sq.tail) - ctx->cached_sq_head == ctx->sq_entries;
-}
-
 static inline unsigned int __io_cqring_events(struct io_ring_ctx *ctx)
 {
 	return ctx->cached_cq_tail - READ_ONCE(ctx->rings->cq.head);
@@ -1974,28 +1941,7 @@ static unsigned io_cqring_events(struct io_ring_ctx *ctx)
 	return __io_cqring_events(ctx);
 }
 
-static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
-{
-	struct io_rings *rings = ctx->rings;
-
-	/* make sure SQ entry isn't read before tail */
-	return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
-}
-
-static inline bool io_run_task_work(void)
-{
-	if (test_thread_flag(TIF_NOTIFY_SIGNAL) || task_work_pending(current)) {
-		__set_current_state(TASK_RUNNING);
-		clear_notify_signal();
-		if (task_work_pending(current))
-			task_work_run();
-		return true;
-	}
-
-	return false;
-}
-
-static int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
+int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin)
 {
 	struct io_wq_work_node *pos, *start, *prev;
 	unsigned int poll_flags = BLK_POLL_NOSLEEP;
@@ -5297,7 +5243,7 @@ static const struct io_uring_sqe *io_get_sqe(struct io_ring_ctx *ctx)
 	return NULL;
 }
 
-static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
+int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
 	__must_hold(&ctx->uring_lock)
 {
 	unsigned int entries = io_sqring_entries(ctx);
@@ -5349,173 +5295,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr)
 	return ret;
 }
 
-static inline bool io_sqd_events_pending(struct io_sq_data *sqd)
-{
-	return READ_ONCE(sqd->state);
-}
-
-static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
-{
-	unsigned int to_submit;
-	int ret = 0;
-
-	to_submit = io_sqring_entries(ctx);
-	/* if we're handling multiple rings, cap submit size for fairness */
-	if (cap_entries && to_submit > IORING_SQPOLL_CAP_ENTRIES_VALUE)
-		to_submit = IORING_SQPOLL_CAP_ENTRIES_VALUE;
-
-	if (!wq_list_empty(&ctx->iopoll_list) || to_submit) {
-		const struct cred *creds = NULL;
-
-		if (ctx->sq_creds != current_cred())
-			creds = override_creds(ctx->sq_creds);
-
-		mutex_lock(&ctx->uring_lock);
-		if (!wq_list_empty(&ctx->iopoll_list))
-			io_do_iopoll(ctx, true);
-
-		/*
-		 * Don't submit if refs are dying, good for io_uring_register(),
-		 * but also it is relied upon by io_ring_exit_work()
-		 */
-		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
-		    !(ctx->flags & IORING_SETUP_R_DISABLED))
-			ret = io_submit_sqes(ctx, to_submit);
-		mutex_unlock(&ctx->uring_lock);
-
-		if (to_submit && wq_has_sleeper(&ctx->sqo_sq_wait))
-			wake_up(&ctx->sqo_sq_wait);
-		if (creds)
-			revert_creds(creds);
-	}
-
-	return ret;
-}
-
-static __cold void io_sqd_update_thread_idle(struct io_sq_data *sqd)
-{
-	struct io_ring_ctx *ctx;
-	unsigned sq_thread_idle = 0;
-
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-		sq_thread_idle = max(sq_thread_idle, ctx->sq_thread_idle);
-	sqd->sq_thread_idle = sq_thread_idle;
-}
-
-static bool io_sqd_handle_event(struct io_sq_data *sqd)
-{
-	bool did_sig = false;
-	struct ksignal ksig;
-
-	if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state) ||
-	    signal_pending(current)) {
-		mutex_unlock(&sqd->lock);
-		if (signal_pending(current))
-			did_sig = get_signal(&ksig);
-		cond_resched();
-		mutex_lock(&sqd->lock);
-	}
-	return did_sig || test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-}
-
-static int io_sq_thread(void *data)
-{
-	struct io_sq_data *sqd = data;
-	struct io_ring_ctx *ctx;
-	unsigned long timeout = 0;
-	char buf[TASK_COMM_LEN];
-	DEFINE_WAIT(wait);
-
-	snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
-	set_task_comm(current, buf);
-
-	if (sqd->sq_cpu != -1)
-		set_cpus_allowed_ptr(current, cpumask_of(sqd->sq_cpu));
-	else
-		set_cpus_allowed_ptr(current, cpu_online_mask);
-	current->flags |= PF_NO_SETAFFINITY;
-
-	audit_alloc_kernel(current);
-
-	mutex_lock(&sqd->lock);
-	while (1) {
-		bool cap_entries, sqt_spin = false;
-
-		if (io_sqd_events_pending(sqd) || signal_pending(current)) {
-			if (io_sqd_handle_event(sqd))
-				break;
-			timeout = jiffies + sqd->sq_thread_idle;
-		}
-
-		cap_entries = !list_is_singular(&sqd->ctx_list);
-		list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-			int ret = __io_sq_thread(ctx, cap_entries);
-
-			if (!sqt_spin && (ret > 0 || !wq_list_empty(&ctx->iopoll_list)))
-				sqt_spin = true;
-		}
-		if (io_run_task_work())
-			sqt_spin = true;
-
-		if (sqt_spin || !time_after(jiffies, timeout)) {
-			cond_resched();
-			if (sqt_spin)
-				timeout = jiffies + sqd->sq_thread_idle;
-			continue;
-		}
-
-		prepare_to_wait(&sqd->wait, &wait, TASK_INTERRUPTIBLE);
-		if (!io_sqd_events_pending(sqd) && !task_work_pending(current)) {
-			bool needs_sched = true;
-
-			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
-				atomic_or(IORING_SQ_NEED_WAKEUP,
-						&ctx->rings->sq_flags);
-				if ((ctx->flags & IORING_SETUP_IOPOLL) &&
-				    !wq_list_empty(&ctx->iopoll_list)) {
-					needs_sched = false;
-					break;
-				}
-
-				/*
-				 * Ensure the store of the wakeup flag is not
-				 * reordered with the load of the SQ tail
-				 */
-				smp_mb__after_atomic();
-
-				if (io_sqring_entries(ctx)) {
-					needs_sched = false;
-					break;
-				}
-			}
-
-			if (needs_sched) {
-				mutex_unlock(&sqd->lock);
-				schedule();
-				mutex_lock(&sqd->lock);
-			}
-			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-				atomic_andnot(IORING_SQ_NEED_WAKEUP,
-						&ctx->rings->sq_flags);
-		}
-
-		finish_wait(&sqd->wait, &wait);
-		timeout = jiffies + sqd->sq_thread_idle;
-	}
-
-	io_uring_cancel_generic(true, sqd);
-	sqd->thread = NULL;
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-		atomic_or(IORING_SQ_NEED_WAKEUP, &ctx->rings->sq_flags);
-	io_run_task_work();
-	mutex_unlock(&sqd->lock);
-
-	audit_free(current);
-
-	complete(&sqd->exited);
-	do_exit(0);
-}
-
 struct io_wait_queue {
 	struct wait_queue_entry wq;
 	struct io_ring_ctx *ctx;
@@ -5934,131 +5713,6 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 	return ret;
 }
 
-static void io_sq_thread_unpark(struct io_sq_data *sqd)
-	__releases(&sqd->lock)
-{
-	WARN_ON_ONCE(sqd->thread == current);
-
-	/*
-	 * Do the dance but not conditional clear_bit() because it'd race with
-	 * other threads incrementing park_pending and setting the bit.
-	 */
-	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	if (atomic_dec_return(&sqd->park_pending))
-		set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	mutex_unlock(&sqd->lock);
-}
-
-static void io_sq_thread_park(struct io_sq_data *sqd)
-	__acquires(&sqd->lock)
-{
-	WARN_ON_ONCE(sqd->thread == current);
-
-	atomic_inc(&sqd->park_pending);
-	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	mutex_lock(&sqd->lock);
-	if (sqd->thread)
-		wake_up_process(sqd->thread);
-}
-
-static void io_sq_thread_stop(struct io_sq_data *sqd)
-{
-	WARN_ON_ONCE(sqd->thread == current);
-	WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));
-
-	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-	mutex_lock(&sqd->lock);
-	if (sqd->thread)
-		wake_up_process(sqd->thread);
-	mutex_unlock(&sqd->lock);
-	wait_for_completion(&sqd->exited);
-}
-
-static void io_put_sq_data(struct io_sq_data *sqd)
-{
-	if (refcount_dec_and_test(&sqd->refs)) {
-		WARN_ON_ONCE(atomic_read(&sqd->park_pending));
-
-		io_sq_thread_stop(sqd);
-		kfree(sqd);
-	}
-}
-
-static void io_sq_thread_finish(struct io_ring_ctx *ctx)
-{
-	struct io_sq_data *sqd = ctx->sq_data;
-
-	if (sqd) {
-		io_sq_thread_park(sqd);
-		list_del_init(&ctx->sqd_list);
-		io_sqd_update_thread_idle(sqd);
-		io_sq_thread_unpark(sqd);
-
-		io_put_sq_data(sqd);
-		ctx->sq_data = NULL;
-	}
-}
-
-static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
-{
-	struct io_ring_ctx *ctx_attach;
-	struct io_sq_data *sqd;
-	struct fd f;
-
-	f = fdget(p->wq_fd);
-	if (!f.file)
-		return ERR_PTR(-ENXIO);
-	if (f.file->f_op != &io_uring_fops) {
-		fdput(f);
-		return ERR_PTR(-EINVAL);
-	}
-
-	ctx_attach = f.file->private_data;
-	sqd = ctx_attach->sq_data;
-	if (!sqd) {
-		fdput(f);
-		return ERR_PTR(-EINVAL);
-	}
-	if (sqd->task_tgid != current->tgid) {
-		fdput(f);
-		return ERR_PTR(-EPERM);
-	}
-
-	refcount_inc(&sqd->refs);
-	fdput(f);
-	return sqd;
-}
-
-static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
-					 bool *attached)
-{
-	struct io_sq_data *sqd;
-
-	*attached = false;
-	if (p->flags & IORING_SETUP_ATTACH_WQ) {
-		sqd = io_attach_sq_data(p);
-		if (!IS_ERR(sqd)) {
-			*attached = true;
-			return sqd;
-		}
-		/* fall through for EPERM case, setup new sqd/task */
-		if (PTR_ERR(sqd) != -EPERM)
-			return sqd;
-	}
-
-	sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
-	if (!sqd)
-		return ERR_PTR(-ENOMEM);
-
-	atomic_set(&sqd->park_pending, 0);
-	refcount_set(&sqd->refs, 1);
-	INIT_LIST_HEAD(&sqd->ctx_list);
-	mutex_init(&sqd->lock);
-	init_waitqueue_head(&sqd->wait);
-	init_completion(&sqd->exited);
-	return sqd;
-}
-
 /*
  * Ensure the UNIX gc is aware of our file set, so we are certain that
  * the io_uring can be safely unregistered on process exit, even if we have
@@ -6495,8 +6149,8 @@ static struct io_wq *io_init_wq_offload(struct io_ring_ctx *ctx,
 	return io_wq_create(concurrency, &data);
 }
 
-static __cold int io_uring_alloc_task_context(struct task_struct *task,
-					      struct io_ring_ctx *ctx)
+__cold int io_uring_alloc_task_context(struct task_struct *task,
+				       struct io_ring_ctx *ctx)
 {
 	struct io_uring_task *tctx;
 	int ret;
@@ -6554,96 +6208,6 @@ void __io_uring_free(struct task_struct *tsk)
 	tsk->io_uring = NULL;
 }
 
-static __cold int io_sq_offload_create(struct io_ring_ctx *ctx,
-				       struct io_uring_params *p)
-{
-	int ret;
-
-	/* Retain compatibility with failing for an invalid attach attempt */
-	if ((ctx->flags & (IORING_SETUP_ATTACH_WQ | IORING_SETUP_SQPOLL)) ==
-				IORING_SETUP_ATTACH_WQ) {
-		struct fd f;
-
-		f = fdget(p->wq_fd);
-		if (!f.file)
-			return -ENXIO;
-		if (f.file->f_op != &io_uring_fops) {
-			fdput(f);
-			return -EINVAL;
-		}
-		fdput(f);
-	}
-	if (ctx->flags & IORING_SETUP_SQPOLL) {
-		struct task_struct *tsk;
-		struct io_sq_data *sqd;
-		bool attached;
-
-		ret = security_uring_sqpoll();
-		if (ret)
-			return ret;
-
-		sqd = io_get_sq_data(p, &attached);
-		if (IS_ERR(sqd)) {
-			ret = PTR_ERR(sqd);
-			goto err;
-		}
-
-		ctx->sq_creds = get_current_cred();
-		ctx->sq_data = sqd;
-		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
-		if (!ctx->sq_thread_idle)
-			ctx->sq_thread_idle = HZ;
-
-		io_sq_thread_park(sqd);
-		list_add(&ctx->sqd_list, &sqd->ctx_list);
-		io_sqd_update_thread_idle(sqd);
-		/* don't attach to a dying SQPOLL thread, would be racy */
-		ret = (attached && !sqd->thread) ? -ENXIO : 0;
-		io_sq_thread_unpark(sqd);
-
-		if (ret < 0)
-			goto err;
-		if (attached)
-			return 0;
-
-		if (p->flags & IORING_SETUP_SQ_AFF) {
-			int cpu = p->sq_thread_cpu;
-
-			ret = -EINVAL;
-			if (cpu >= nr_cpu_ids || !cpu_online(cpu))
-				goto err_sqpoll;
-			sqd->sq_cpu = cpu;
-		} else {
-			sqd->sq_cpu = -1;
-		}
-
-		sqd->task_pid = current->pid;
-		sqd->task_tgid = current->tgid;
-		tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
-		if (IS_ERR(tsk)) {
-			ret = PTR_ERR(tsk);
-			goto err_sqpoll;
-		}
-
-		sqd->thread = tsk;
-		ret = io_uring_alloc_task_context(tsk, ctx);
-		wake_up_new_task(tsk);
-		if (ret)
-			goto err;
-	} else if (p->flags & IORING_SETUP_SQ_AFF) {
-		/* Can't have SQ_AFF without SQPOLL */
-		ret = -EINVAL;
-		goto err;
-	}
-
-	return 0;
-err_sqpoll:
-	complete(&ctx->sq_data->exited);
-err:
-	io_sq_thread_finish(ctx);
-	return ret;
-}
-
 static inline void __io_unaccount_mem(struct user_struct *user,
 				      unsigned long nr_pages)
 {
@@ -7755,8 +7319,7 @@ static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
  * Find any io_uring ctx that this task has registered or done IO on, and cancel
  * requests. @sqd should be not-null IFF it's an SQPOLL thread cancellation.
  */
-static __cold void io_uring_cancel_generic(bool cancel_all,
-					   struct io_sq_data *sqd)
+__cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
 {
 	struct io_uring_task *tctx = current->io_uring;
 	struct io_ring_ctx *ctx;
@@ -8034,24 +7597,6 @@ static unsigned long io_uring_nommu_get_unmapped_area(struct file *file,
 
 #endif /* !CONFIG_MMU */
 
-static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
-{
-	DEFINE_WAIT(wait);
-
-	do {
-		if (!io_sqring_full(ctx))
-			break;
-		prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
-
-		if (!io_sqring_full(ctx))
-			break;
-		schedule();
-	} while (!signal_pending(current));
-
-	finish_wait(&ctx->sqo_sq_wait, &wait);
-	return 0;
-}
-
 static int io_validate_ext_arg(unsigned flags, const void __user *argp, size_t argsz)
 {
 	if (flags & IORING_ENTER_EXT_ARG) {
diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h
index e285e12ccbdbb..1da8e66507a35 100644
--- a/io_uring/io_uring.h
+++ b/io_uring/io_uring.h
@@ -64,6 +64,34 @@ static inline void io_commit_cqring(struct io_ring_ctx *ctx)
 	smp_store_release(&ctx->rings->cq.tail, ctx->cached_cq_tail);
 }
 
+static inline bool io_sqring_full(struct io_ring_ctx *ctx)
+{
+	struct io_rings *r = ctx->rings;
+
+	return READ_ONCE(r->sq.tail) - ctx->cached_sq_head == ctx->sq_entries;
+}
+
+static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
+{
+	struct io_rings *rings = ctx->rings;
+
+	/* make sure SQ entry isn't read before tail */
+	return smp_load_acquire(&rings->sq.tail) - ctx->cached_sq_head;
+}
+
+static inline bool io_run_task_work(void)
+{
+	if (test_thread_flag(TIF_NOTIFY_SIGNAL) || task_work_pending(current)) {
+		__set_current_state(TASK_RUNNING);
+		clear_notify_signal();
+		if (task_work_pending(current))
+			task_work_run();
+		return true;
+	}
+
+	return false;
+}
+
 void __io_req_complete(struct io_kiocb *req, unsigned issue_flags);
 void io_req_complete_post(struct io_kiocb *req);
 void __io_req_complete_post(struct io_kiocb *req);
@@ -101,6 +129,12 @@ void io_req_tw_post_queue(struct io_kiocb *req, s32 res, u32 cflags);
 void io_req_task_complete(struct io_kiocb *req, bool *locked);
 void io_req_task_queue_fail(struct io_kiocb *req, int ret);
 int io_try_cancel(struct io_kiocb *req, struct io_cancel_data *cd);
+__cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
+int io_uring_alloc_task_context(struct task_struct *task,
+				struct io_ring_ctx *ctx);
+
+int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr);
+int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin);
 
 void io_free_req(struct io_kiocb *req);
 void io_queue_next(struct io_kiocb *req);
diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c
new file mode 100644
index 0000000000000..149d5c976f146
--- /dev/null
+++ b/io_uring/sqpoll.c
@@ -0,0 +1,426 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Contains the core associated with submission side polling of the SQ
+ * ring, offloading submissions from the application to a kernel thread.
+ */
+#include <linux/kernel.h>
+#include <linux/errno.h>
+#include <linux/file.h>
+#include <linux/mm.h>
+#include <linux/slab.h>
+#include <linux/audit.h>
+#include <linux/security.h>
+#include <linux/io_uring.h>
+
+#include <uapi/linux/io_uring.h>
+
+#include "io_uring_types.h"
+#include "io_uring.h"
+#include "sqpoll.h"
+
+#define IORING_SQPOLL_CAP_ENTRIES_VALUE 8
+
+enum {
+	IO_SQ_THREAD_SHOULD_STOP = 0,
+	IO_SQ_THREAD_SHOULD_PARK,
+};
+
+void io_sq_thread_unpark(struct io_sq_data *sqd)
+	__releases(&sqd->lock)
+{
+	WARN_ON_ONCE(sqd->thread == current);
+
+	/*
+	 * Do the dance but not conditional clear_bit() because it'd race with
+	 * other threads incrementing park_pending and setting the bit.
+	 */
+	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	if (atomic_dec_return(&sqd->park_pending))
+		set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	mutex_unlock(&sqd->lock);
+}
+
+void io_sq_thread_park(struct io_sq_data *sqd)
+	__acquires(&sqd->lock)
+{
+	WARN_ON_ONCE(sqd->thread == current);
+
+	atomic_inc(&sqd->park_pending);
+	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
+	mutex_lock(&sqd->lock);
+	if (sqd->thread)
+		wake_up_process(sqd->thread);
+}
+
+void io_sq_thread_stop(struct io_sq_data *sqd)
+{
+	WARN_ON_ONCE(sqd->thread == current);
+	WARN_ON_ONCE(test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state));
+
+	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+	mutex_lock(&sqd->lock);
+	if (sqd->thread)
+		wake_up_process(sqd->thread);
+	mutex_unlock(&sqd->lock);
+	wait_for_completion(&sqd->exited);
+}
+
+void io_put_sq_data(struct io_sq_data *sqd)
+{
+	if (refcount_dec_and_test(&sqd->refs)) {
+		WARN_ON_ONCE(atomic_read(&sqd->park_pending));
+
+		io_sq_thread_stop(sqd);
+		kfree(sqd);
+	}
+}
+
+static __cold void io_sqd_update_thread_idle(struct io_sq_data *sqd)
+{
+	struct io_ring_ctx *ctx;
+	unsigned sq_thread_idle = 0;
+
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+		sq_thread_idle = max(sq_thread_idle, ctx->sq_thread_idle);
+	sqd->sq_thread_idle = sq_thread_idle;
+}
+
+void io_sq_thread_finish(struct io_ring_ctx *ctx)
+{
+	struct io_sq_data *sqd = ctx->sq_data;
+
+	if (sqd) {
+		io_sq_thread_park(sqd);
+		list_del_init(&ctx->sqd_list);
+		io_sqd_update_thread_idle(sqd);
+		io_sq_thread_unpark(sqd);
+
+		io_put_sq_data(sqd);
+		ctx->sq_data = NULL;
+	}
+}
+
+static struct io_sq_data *io_attach_sq_data(struct io_uring_params *p)
+{
+	struct io_ring_ctx *ctx_attach;
+	struct io_sq_data *sqd;
+	struct fd f;
+
+	f = fdget(p->wq_fd);
+	if (!f.file)
+		return ERR_PTR(-ENXIO);
+	if (!io_is_uring_fops(f.file)) {
+		fdput(f);
+		return ERR_PTR(-EINVAL);
+	}
+
+	ctx_attach = f.file->private_data;
+	sqd = ctx_attach->sq_data;
+	if (!sqd) {
+		fdput(f);
+		return ERR_PTR(-EINVAL);
+	}
+	if (sqd->task_tgid != current->tgid) {
+		fdput(f);
+		return ERR_PTR(-EPERM);
+	}
+
+	refcount_inc(&sqd->refs);
+	fdput(f);
+	return sqd;
+}
+
+static struct io_sq_data *io_get_sq_data(struct io_uring_params *p,
+					 bool *attached)
+{
+	struct io_sq_data *sqd;
+
+	*attached = false;
+	if (p->flags & IORING_SETUP_ATTACH_WQ) {
+		sqd = io_attach_sq_data(p);
+		if (!IS_ERR(sqd)) {
+			*attached = true;
+			return sqd;
+		}
+		/* fall through for EPERM case, setup new sqd/task */
+		if (PTR_ERR(sqd) != -EPERM)
+			return sqd;
+	}
+
+	sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);
+	if (!sqd)
+		return ERR_PTR(-ENOMEM);
+
+	atomic_set(&sqd->park_pending, 0);
+	refcount_set(&sqd->refs, 1);
+	INIT_LIST_HEAD(&sqd->ctx_list);
+	mutex_init(&sqd->lock);
+	init_waitqueue_head(&sqd->wait);
+	init_completion(&sqd->exited);
+	return sqd;
+}
+
+static inline bool io_sqd_events_pending(struct io_sq_data *sqd)
+{
+	return READ_ONCE(sqd->state);
+}
+
+static int __io_sq_thread(struct io_ring_ctx *ctx, bool cap_entries)
+{
+	unsigned int to_submit;
+	int ret = 0;
+
+	to_submit = io_sqring_entries(ctx);
+	/* if we're handling multiple rings, cap submit size for fairness */
+	if (cap_entries && to_submit > IORING_SQPOLL_CAP_ENTRIES_VALUE)
+		to_submit = IORING_SQPOLL_CAP_ENTRIES_VALUE;
+
+	if (!wq_list_empty(&ctx->iopoll_list) || to_submit) {
+		const struct cred *creds = NULL;
+
+		if (ctx->sq_creds != current_cred())
+			creds = override_creds(ctx->sq_creds);
+
+		mutex_lock(&ctx->uring_lock);
+		if (!wq_list_empty(&ctx->iopoll_list))
+			io_do_iopoll(ctx, true);
+
+		/*
+		 * Don't submit if refs are dying, good for io_uring_register(),
+		 * but also it is relied upon by io_ring_exit_work()
+		 */
+		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)) &&
+		    !(ctx->flags & IORING_SETUP_R_DISABLED))
+			ret = io_submit_sqes(ctx, to_submit);
+		mutex_unlock(&ctx->uring_lock);
+
+		if (to_submit && wq_has_sleeper(&ctx->sqo_sq_wait))
+			wake_up(&ctx->sqo_sq_wait);
+		if (creds)
+			revert_creds(creds);
+	}
+
+	return ret;
+}
+
+static bool io_sqd_handle_event(struct io_sq_data *sqd)
+{
+	bool did_sig = false;
+	struct ksignal ksig;
+
+	if (test_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state) ||
+	    signal_pending(current)) {
+		mutex_unlock(&sqd->lock);
+		if (signal_pending(current))
+			did_sig = get_signal(&ksig);
+		cond_resched();
+		mutex_lock(&sqd->lock);
+	}
+	return did_sig || test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
+}
+
+static int io_sq_thread(void *data)
+{
+	struct io_sq_data *sqd = data;
+	struct io_ring_ctx *ctx;
+	unsigned long timeout = 0;
+	char buf[TASK_COMM_LEN];
+	DEFINE_WAIT(wait);
+
+	snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
+	set_task_comm(current, buf);
+
+	if (sqd->sq_cpu != -1)
+		set_cpus_allowed_ptr(current, cpumask_of(sqd->sq_cpu));
+	else
+		set_cpus_allowed_ptr(current, cpu_online_mask);
+	current->flags |= PF_NO_SETAFFINITY;
+
+	audit_alloc_kernel(current);
+
+	mutex_lock(&sqd->lock);
+	while (1) {
+		bool cap_entries, sqt_spin = false;
+
+		if (io_sqd_events_pending(sqd) || signal_pending(current)) {
+			if (io_sqd_handle_event(sqd))
+				break;
+			timeout = jiffies + sqd->sq_thread_idle;
+		}
+
+		cap_entries = !list_is_singular(&sqd->ctx_list);
+		list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+			int ret = __io_sq_thread(ctx, cap_entries);
+
+			if (!sqt_spin && (ret > 0 || !wq_list_empty(&ctx->iopoll_list)))
+				sqt_spin = true;
+		}
+		if (io_run_task_work())
+			sqt_spin = true;
+
+		if (sqt_spin || !time_after(jiffies, timeout)) {
+			cond_resched();
+			if (sqt_spin)
+				timeout = jiffies + sqd->sq_thread_idle;
+			continue;
+		}
+
+		prepare_to_wait(&sqd->wait, &wait, TASK_INTERRUPTIBLE);
+		if (!io_sqd_events_pending(sqd) && !task_work_pending(current)) {
+			bool needs_sched = true;
+
+			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+				atomic_or(IORING_SQ_NEED_WAKEUP,
+						&ctx->rings->sq_flags);
+				if ((ctx->flags & IORING_SETUP_IOPOLL) &&
+				    !wq_list_empty(&ctx->iopoll_list)) {
+					needs_sched = false;
+					break;
+				}
+
+				/*
+				 * Ensure the store of the wakeup flag is not
+				 * reordered with the load of the SQ tail
+				 */
+				smp_mb__after_atomic();
+
+				if (io_sqring_entries(ctx)) {
+					needs_sched = false;
+					break;
+				}
+			}
+
+			if (needs_sched) {
+				mutex_unlock(&sqd->lock);
+				schedule();
+				mutex_lock(&sqd->lock);
+			}
+			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+				atomic_andnot(IORING_SQ_NEED_WAKEUP,
+						&ctx->rings->sq_flags);
+		}
+
+		finish_wait(&sqd->wait, &wait);
+		timeout = jiffies + sqd->sq_thread_idle;
+	}
+
+	io_uring_cancel_generic(true, sqd);
+	sqd->thread = NULL;
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+		atomic_or(IORING_SQ_NEED_WAKEUP, &ctx->rings->sq_flags);
+	io_run_task_work();
+	mutex_unlock(&sqd->lock);
+
+	audit_free(current);
+
+	complete(&sqd->exited);
+	do_exit(0);
+}
+
+int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
+{
+	DEFINE_WAIT(wait);
+
+	do {
+		if (!io_sqring_full(ctx))
+			break;
+		prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
+
+		if (!io_sqring_full(ctx))
+			break;
+		schedule();
+	} while (!signal_pending(current));
+
+	finish_wait(&ctx->sqo_sq_wait, &wait);
+	return 0;
+}
+
+__cold int io_sq_offload_create(struct io_ring_ctx *ctx,
+				struct io_uring_params *p)
+{
+	int ret;
+
+	/* Retain compatibility with failing for an invalid attach attempt */
+	if ((ctx->flags & (IORING_SETUP_ATTACH_WQ | IORING_SETUP_SQPOLL)) ==
+				IORING_SETUP_ATTACH_WQ) {
+		struct fd f;
+
+		f = fdget(p->wq_fd);
+		if (!f.file)
+			return -ENXIO;
+		if (!io_is_uring_fops(f.file)) {
+			fdput(f);
+			return -EINVAL;
+		}
+		fdput(f);
+	}
+	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		struct task_struct *tsk;
+		struct io_sq_data *sqd;
+		bool attached;
+
+		ret = security_uring_sqpoll();
+		if (ret)
+			return ret;
+
+		sqd = io_get_sq_data(p, &attached);
+		if (IS_ERR(sqd)) {
+			ret = PTR_ERR(sqd);
+			goto err;
+		}
+
+		ctx->sq_creds = get_current_cred();
+		ctx->sq_data = sqd;
+		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
+		if (!ctx->sq_thread_idle)
+			ctx->sq_thread_idle = HZ;
+
+		io_sq_thread_park(sqd);
+		list_add(&ctx->sqd_list, &sqd->ctx_list);
+		io_sqd_update_thread_idle(sqd);
+		/* don't attach to a dying SQPOLL thread, would be racy */
+		ret = (attached && !sqd->thread) ? -ENXIO : 0;
+		io_sq_thread_unpark(sqd);
+
+		if (ret < 0)
+			goto err;
+		if (attached)
+			return 0;
+
+		if (p->flags & IORING_SETUP_SQ_AFF) {
+			int cpu = p->sq_thread_cpu;
+
+			ret = -EINVAL;
+			if (cpu >= nr_cpu_ids || !cpu_online(cpu))
+				goto err_sqpoll;
+			sqd->sq_cpu = cpu;
+		} else {
+			sqd->sq_cpu = -1;
+		}
+
+		sqd->task_pid = current->pid;
+		sqd->task_tgid = current->tgid;
+		tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
+		if (IS_ERR(tsk)) {
+			ret = PTR_ERR(tsk);
+			goto err_sqpoll;
+		}
+
+		sqd->thread = tsk;
+		ret = io_uring_alloc_task_context(tsk, ctx);
+		wake_up_new_task(tsk);
+		if (ret)
+			goto err;
+	} else if (p->flags & IORING_SETUP_SQ_AFF) {
+		/* Can't have SQ_AFF without SQPOLL */
+		ret = -EINVAL;
+		goto err;
+	}
+
+	return 0;
+err_sqpoll:
+	complete(&ctx->sq_data->exited);
+err:
+	io_sq_thread_finish(ctx);
+	return ret;
+}
diff --git a/io_uring/sqpoll.h b/io_uring/sqpoll.h
new file mode 100644
index 0000000000000..0c3fbcd1f583f
--- /dev/null
+++ b/io_uring/sqpoll.h
@@ -0,0 +1,29 @@
+// SPDX-License-Identifier: GPL-2.0
+
+struct io_sq_data {
+	refcount_t		refs;
+	atomic_t		park_pending;
+	struct mutex		lock;
+
+	/* ctx's that are using this sqd */
+	struct list_head	ctx_list;
+
+	struct task_struct	*thread;
+	struct wait_queue_head	wait;
+
+	unsigned		sq_thread_idle;
+	int			sq_cpu;
+	pid_t			task_pid;
+	pid_t			task_tgid;
+
+	unsigned long		state;
+	struct completion	exited;
+};
+
+int io_sq_offload_create(struct io_ring_ctx *ctx, struct io_uring_params *p);
+void io_sq_thread_finish(struct io_ring_ctx *ctx);
+void io_sq_thread_stop(struct io_sq_data *sqd);
+void io_sq_thread_park(struct io_sq_data *sqd);
+void io_sq_thread_unpark(struct io_sq_data *sqd);
+void io_put_sq_data(struct io_sq_data *sqd);
+int io_sqpoll_wait_sq(struct io_ring_ctx *ctx);
-- 
2.30.2