static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
                                        void __user *arg)
 {
-       struct io_uring_task *tctx = current->io_uring;
+       struct io_uring_task *tctx = NULL;
+       struct io_sq_data *sqd = NULL;
        __u32 new_count[2];
        int i, ret;
 
-       if (!tctx || !tctx->io_wq)
-               return -EINVAL;
        if (copy_from_user(new_count, arg, sizeof(new_count)))
                return -EFAULT;
        for (i = 0; i < ARRAY_SIZE(new_count); i++)
                if (new_count[i] > INT_MAX)
                        return -EINVAL;
 
+       if (ctx->flags & IORING_SETUP_SQPOLL) {
+               sqd = ctx->sq_data;
+               if (sqd) {
+                       mutex_lock(&sqd->lock);
+                       tctx = sqd->thread->io_uring;
+               }
+       } else {
+               tctx = current->io_uring;
+       }
+
+       ret = -EINVAL;
+       if (!tctx || !tctx->io_wq)
+               goto err;
+
        ret = io_wq_max_workers(tctx->io_wq, new_count);
        if (ret)
-               return ret;
+               goto err;
+
+       if (sqd)
+               mutex_unlock(&sqd->lock);
 
        if (copy_to_user(arg, new_count, sizeof(new_count)))
                return -EFAULT;
 
        return 0;
+err:
+       if (sqd)
+               mutex_unlock(&sqd->lock);
+       return ret;
 }
 
 static bool io_register_op_must_quiesce(int op)