io_uring/sqpoll: manage task_work privately
authorJens Axboe <axboe@kernel.dk>
Fri, 2 Feb 2024 17:20:05 +0000 (10:20 -0700)
committerJens Axboe <axboe@kernel.dk>
Thu, 8 Feb 2024 20:27:06 +0000 (13:27 -0700)
Decouple from task_work running, and cap the number of entries we process
at the time. If we exceed that number, push remaining entries to a retry
list that we'll process first next time.

We cap the number of entries to process at 8, which is fairly random.
We just want to get enough per-ctx batching here, while not processing
endlessly.

Since we manually run PF_IO_WORKER related task_work anyway as the task
never exits to userspace, with this we no longer need to add an actual
task_work item to the per-process list.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
io_uring/io_uring.c
io_uring/io_uring.h
io_uring/sqpoll.c

index bfd2f0fff153cf48c4812aa916339a05bd7b4378..f1e3646c727c1e206bdbb556068f966cf9de4f9f 100644 (file)
@@ -1173,7 +1173,14 @@ static void ctx_flush_and_put(struct io_ring_ctx *ctx, struct io_tw_state *ts)
        percpu_ref_put(&ctx->refs);
 }
 
-static void handle_tw_list(struct llist_node *node, unsigned int *count)
+/*
+ * Run queued task_work, returning the number of entries processed in *count.
+ * If more entries than max_entries are available, stop processing once this
+ * is reached and return the rest of the list.
+ */
+struct llist_node *io_handle_tw_list(struct llist_node *node,
+                                    unsigned int *count,
+                                    unsigned int max_entries)
 {
        struct io_ring_ctx *ctx = NULL;
        struct io_tw_state ts = { };
@@ -1200,9 +1207,10 @@ static void handle_tw_list(struct llist_node *node, unsigned int *count)
                        ctx = NULL;
                        cond_resched();
                }
-       } while (node);
+       } while (node && *count < max_entries);
 
        ctx_flush_and_put(ctx, &ts);
+       return node;
 }
 
 /**
@@ -1247,27 +1255,41 @@ static __cold void io_fallback_tw(struct io_uring_task *tctx, bool sync)
        }
 }
 
-void tctx_task_work(struct callback_head *cb)
+struct llist_node *tctx_task_work_run(struct io_uring_task *tctx,
+                                     unsigned int max_entries,
+                                     unsigned int *count)
 {
-       struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
-                                                 task_work);
        struct llist_node *node;
-       unsigned int count = 0;
 
        if (unlikely(current->flags & PF_EXITING)) {
                io_fallback_tw(tctx, true);
-               return;
+               return NULL;
        }
 
        node = llist_del_all(&tctx->task_list);
-       if (node)
-               handle_tw_list(llist_reverse_order(node), &count);
+       if (node) {
+               node = llist_reverse_order(node);
+               node = io_handle_tw_list(node, count, max_entries);
+       }
 
        /* relaxed read is enough as only the task itself sets ->in_cancel */
        if (unlikely(atomic_read(&tctx->in_cancel)))
                io_uring_drop_tctx_refs(current);
 
-       trace_io_uring_task_work_run(tctx, count);
+       trace_io_uring_task_work_run(tctx, *count);
+       return node;
+}
+
+void tctx_task_work(struct callback_head *cb)
+{
+       struct io_uring_task *tctx;
+       struct llist_node *ret;
+       unsigned int count = 0;
+
+       tctx = container_of(cb, struct io_uring_task, task_work);
+       ret = tctx_task_work_run(tctx, UINT_MAX, &count);
+       /* can't happen */
+       WARN_ON_ONCE(ret);
 }
 
 static inline void io_req_local_work_add(struct io_kiocb *req, unsigned flags)
@@ -1350,6 +1372,10 @@ static void io_req_normal_work_add(struct io_kiocb *req)
        if (ctx->flags & IORING_SETUP_TASKRUN_FLAG)
                atomic_or(IORING_SQ_TASKRUN, &ctx->rings->sq_flags);
 
+       /* SQPOLL doesn't need the task_work added, it'll run it itself */
+       if (ctx->flags & IORING_SETUP_SQPOLL)
+               return;
+
        if (likely(!task_work_add(req->task, &tctx->task_work, ctx->notify_method)))
                return;
 
index 46795ee462dfbdd6f8eed870fe305e7110be6297..38af827887860a6763f08da1e111a13150069d07 100644 (file)
@@ -57,6 +57,8 @@ void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use);
 void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts);
 void io_req_task_queue_fail(struct io_kiocb *req, int ret);
 void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts);
+struct llist_node *io_handle_tw_list(struct llist_node *node, unsigned int *count, unsigned int max_entries);
+struct llist_node *tctx_task_work_run(struct io_uring_task *tctx, unsigned int max_entries, unsigned int *count);
 void tctx_task_work(struct callback_head *cb);
 __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 int io_uring_alloc_task_context(struct task_struct *task,
@@ -275,6 +277,8 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
 
 static inline int io_run_task_work(void)
 {
+       bool ret = false;
+
        /*
         * Always check-and-clear the task_work notification signal. With how
         * signaling works for task_work, we can find it set with nothing to
@@ -286,18 +290,26 @@ static inline int io_run_task_work(void)
         * PF_IO_WORKER never returns to userspace, so check here if we have
         * notify work that needs processing.
         */
-       if (current->flags & PF_IO_WORKER &&
-           test_thread_flag(TIF_NOTIFY_RESUME)) {
-               __set_current_state(TASK_RUNNING);
-               resume_user_mode_work(NULL);
+       if (current->flags & PF_IO_WORKER) {
+               if (test_thread_flag(TIF_NOTIFY_RESUME)) {
+                       __set_current_state(TASK_RUNNING);
+                       resume_user_mode_work(NULL);
+               }
+               if (current->io_uring) {
+                       unsigned int count = 0;
+
+                       tctx_task_work_run(current->io_uring, UINT_MAX, &count);
+                       if (count)
+                               ret = true;
+               }
        }
        if (task_work_pending(current)) {
                __set_current_state(TASK_RUNNING);
                task_work_run();
-               return 1;
+               ret = true;
        }
 
-       return 0;
+       return ret;
 }
 
 static inline bool io_task_work_pending(struct io_ring_ctx *ctx)
index 65b5dbe3c850ed564432c76f17e64739d430f2fe..28bf0e085d310899f3faedff0a6193f7fb8a0596 100644 (file)
@@ -18,6 +18,7 @@
 #include "sqpoll.h"
 
 #define IORING_SQPOLL_CAP_ENTRIES_VALUE 8
+#define IORING_TW_CAP_ENTRIES_VALUE    8
 
 enum {
        IO_SQ_THREAD_SHOULD_STOP = 0,
@@ -219,8 +220,31 @@ static bool io_sqd_handle_event(struct io_sq_data *sqd)
        return did_sig || test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
 }
 
+/*
+ * Run task_work, processing the retry_list first. The retry_list holds
+ * entries that we passed on in the previous run, if we had more task_work
+ * than we were asked to process. Newly queued task_work isn't run until the
+ * retry list has been fully processed.
+ */
+static unsigned int io_sq_tw(struct llist_node **retry_list, int max_entries)
+{
+       struct io_uring_task *tctx = current->io_uring;
+       unsigned int count = 0;
+
+       if (*retry_list) {
+               *retry_list = io_handle_tw_list(*retry_list, &count, max_entries);
+               if (count >= max_entries)
+                       return count;
+               max_entries -= count;
+       }
+
+       *retry_list = tctx_task_work_run(tctx, max_entries, &count);
+       return count;
+}
+
 static int io_sq_thread(void *data)
 {
+       struct llist_node *retry_list = NULL;
        struct io_sq_data *sqd = data;
        struct io_ring_ctx *ctx;
        unsigned long timeout = 0;
@@ -257,7 +281,7 @@ static int io_sq_thread(void *data)
                        if (!sqt_spin && (ret > 0 || !wq_list_empty(&ctx->iopoll_list)))
                                sqt_spin = true;
                }
-               if (io_run_task_work())
+               if (io_sq_tw(&retry_list, IORING_TW_CAP_ENTRIES_VALUE))
                        sqt_spin = true;
 
                if (sqt_spin || !time_after(jiffies, timeout)) {
@@ -312,6 +336,9 @@ static int io_sq_thread(void *data)
                timeout = jiffies + sqd->sq_thread_idle;
        }
 
+       if (retry_list)
+               io_sq_tw(&retry_list, UINT_MAX);
+
        io_uring_cancel_generic(true, sqd);
        sqd->thread = NULL;
        list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)