--- a/fs/io-wq.c
+++ b/fs/io-wq.c
 #include <linux/rculist_nulls.h>
 #include <linux/fs_struct.h>
 #include <linux/task_work.h>
+#include <linux/blk-cgroup.h>
 
 #include "io-wq.h"
 
 
        struct rcu_head rcu;
        struct mm_struct *mm;
+#ifdef CONFIG_BLK_CGROUP
+       struct cgroup_subsys_state *blkcg_css;
+#endif
        const struct cred *cur_creds;
        const struct cred *saved_creds;
        struct files_struct *restore_files;
                worker->mm = NULL;
        }
 
+#ifdef CONFIG_BLK_CGROUP
+       if (worker->blkcg_css) {
+               kthread_associate_blkcg(NULL);
+               worker->blkcg_css = NULL;
+       }
+#endif
+
        return dropped_lock;
 }
 
        work->flags |= IO_WQ_WORK_CANCEL;
 }
 
+static inline void io_wq_switch_blkcg(struct io_worker *worker,
+                                     struct io_wq_work *work)
+{
+#ifdef CONFIG_BLK_CGROUP
+       if (work->blkcg_css != worker->blkcg_css) {
+               kthread_associate_blkcg(work->blkcg_css);
+               worker->blkcg_css = work->blkcg_css;
+       }
+#endif
+}
+
 static void io_wq_switch_creds(struct io_worker *worker,
                               struct io_wq_work *work)
 {
        if (worker->cur_creds != work->creds)
                io_wq_switch_creds(worker, work);
        current->signal->rlim[RLIMIT_FSIZE].rlim_cur = work->fsize;
+       io_wq_switch_blkcg(worker, work);
 }
 
 static void io_assign_current_work(struct io_worker *worker,
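For context: io_wq_switch_blkcg() is called from io_impersonate_work(), the same per-work path that already switches creds and RLIMIT_FSIZE above, so every punted work item runs under the submitting task's block cgroup. The work->blkcg_css it reads is filled in on the io_uring side (see io_prep_async_work() below). A minimal sketch of the underlying kthread idiom, assuming CONFIG_BLK_CGROUP=y and a hypothetical worker_run_one() helper:

#include <linux/kthread.h>
#include <linux/blk-cgroup.h>

/*
 * Illustration only, not patch code. kthread_associate_blkcg(css)
 * makes bios submitted by this kthread get charged to @css (the block
 * layer's blkcg_css() consults kthread_blkcg() before falling back to
 * the task's own cgroup); passing NULL restores the default association.
 */
static void worker_run_one(struct cgroup_subsys_state *css)
{
	kthread_associate_blkcg(css);	/* charge I/O to the submitter */
	/* ... issue bios on behalf of the original task ... */
	kthread_associate_blkcg(NULL);	/* detach before going idle */
}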

--- a/fs/io_uring.c
+++ b/fs/io_uring.c
 #include <linux/task_work.h>
 #include <linux/pagemap.h>
 #include <linux/io_uring.h>
+#include <linux/blk-cgroup.h>
 
 #define CREATE_TRACE_POINTS
 #include <trace/events/io_uring.h>
        /* Only used for accounting purposes */
        struct mm_struct        *mm_account;
 
+#ifdef CONFIG_BLK_CGROUP
+       struct cgroup_subsys_state      *sqo_blkcg_css;
+#endif
+
        struct io_sq_data       *sq_data;       /* if using sq thread polling */
 
        struct wait_queue_head  sqo_sq_wait;
        unsigned                needs_fsize : 1;
        /* must always have async data allocated */
        unsigned                needs_async_data : 1;
+       /* needs blkcg context; may issue async I/O */
+       unsigned                needs_blkcg : 1;
        /* size of async data needed, if any */
        unsigned short          async_size;
 };
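The new needs_blkcg bit follows the existing per-opcode flag pattern (needs_mm, needs_fs, ...): it marks opcodes that may end up doing block I/O once punted to async context, and io_prep_async_work() below consults it when charging the request. A sketch of that table lookup, with io_op_wants_blkcg() as a hypothetical helper, not part of the patch:

/* Illustration only: per-opcode behavior is driven by the const table. */
static inline bool io_op_wants_blkcg(u8 opcode)
{
	return io_op_defs[opcode].needs_blkcg;
}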
                .pollin                 = 1,
                .buffer_select          = 1,
                .needs_async_data       = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_WRITEV] = {
                .pollout                = 1,
                .needs_fsize            = 1,
                .needs_async_data       = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_FSYNC] = {
                .needs_file             = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_READ_FIXED] = {
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_WRITE_FIXED] = {
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_fsize            = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_POLL_ADD] = {
        [IORING_OP_POLL_REMOVE] = {},
        [IORING_OP_SYNC_FILE_RANGE] = {
                .needs_file             = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_SENDMSG] = {
                .needs_mm               = 1,
                .needs_fs               = 1,
                .pollout                = 1,
                .needs_async_data       = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_msghdr),
        },
        [IORING_OP_RECVMSG] = {
                .pollin                 = 1,
                .buffer_select          = 1,
                .needs_async_data       = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_msghdr),
        },
        [IORING_OP_TIMEOUT] = {
        [IORING_OP_FALLOCATE] = {
                .needs_file             = 1,
                .needs_fsize            = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_OPENAT] = {
                .file_table             = 1,
                .needs_fs               = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_CLOSE] = {
                .needs_file             = 1,
                .needs_file_no_error    = 1,
                .file_table             = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_FILES_UPDATE] = {
                .needs_mm               = 1,
        [IORING_OP_STATX] = {
                .needs_mm               = 1,
                .needs_fs               = 1,
                .file_table             = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_READ] = {
                .needs_mm               = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_WRITE] = {
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
                .needs_fsize            = 1,
+               .needs_blkcg            = 1,
                .async_size             = sizeof(struct io_async_rw),
        },
        [IORING_OP_FADVISE] = {
                .needs_file             = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_MADVISE] = {
                .needs_mm               = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_SEND] = {
                .needs_mm               = 1,
                .needs_file             = 1,
                .unbound_nonreg_file    = 1,
                .pollout                = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_RECV] = {
                .needs_mm               = 1,
                .unbound_nonreg_file    = 1,
                .pollin                 = 1,
                .buffer_select          = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_OPENAT2] = {
                .file_table             = 1,
                .needs_fs               = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_EPOLL_CTL] = {
                .unbound_nonreg_file    = 1,
        [IORING_OP_SPLICE] = {
                .needs_file             = 1,
                .hash_reg_file          = 1,
                .unbound_nonreg_file    = 1,
+               .needs_blkcg            = 1,
        },
        [IORING_OP_PROVIDE_BUFFERS] = {},
        [IORING_OP_REMOVE_BUFFERS] = {},
        return __io_sq_thread_acquire_mm(ctx);
 }
 
+static void io_sq_thread_associate_blkcg(struct io_ring_ctx *ctx,
+                                        struct cgroup_subsys_state **cur_css)
+{
+#ifdef CONFIG_BLK_CGROUP
+       /* puts the old one when swapping */
+       if (*cur_css != ctx->sqo_blkcg_css) {
+               kthread_associate_blkcg(ctx->sqo_blkcg_css);
+               *cur_css = ctx->sqo_blkcg_css;
+       }
+#endif
+}
+
+static void io_sq_thread_unassociate_blkcg(void)
+{
+#ifdef CONFIG_BLK_CGROUP
+       kthread_associate_blkcg(NULL);
+#endif
+}
+
 static inline void req_set_fail_links(struct io_kiocb *req)
 {
        if ((req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) == REQ_F_LINK)
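A note on reference counting for the two helpers above: kthread_associate_blkcg() internally takes a reference on the css being associated and drops the reference on whatever was associated before, which is what the "puts the old one when swapping" comment refers to. The long-lived reference that keeps ctx->sqo_blkcg_css valid across swaps is the one taken at ring setup (the css_tryget_online() in io_uring_create() below), so neither helper needs its own css_get()/css_put(). A sketch of the intended pairing, using the names from this patch:

	struct cgroup_subsys_state *cur_css = NULL;

	/* per ring, each pass through the SQPOLL loop */
	io_sq_thread_associate_blkcg(ctx, &cur_css);
	/* ... __io_sq_thread() submits on behalf of ctx ... */

	/* once, when the thread exits */
	if (cur_css)
		io_sq_thread_unassociate_blkcg();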
                mmdrop(req->work.mm);
                req->work.mm = NULL;
        }
+#ifdef CONFIG_BLK_CGROUP
+       if (req->work.blkcg_css) {
+               css_put(req->work.blkcg_css);
+               req->work.blkcg_css = NULL;
+       }
+#endif
        if (req->work.creds) {
                put_cred(req->work.creds);
                req->work.creds = NULL;
                mmgrab(current->mm);
                req->work.mm = current->mm;
        }
+#ifdef CONFIG_BLK_CGROUP
+       if (!req->work.blkcg_css && def->needs_blkcg) {
+               rcu_read_lock();
+               req->work.blkcg_css = blkcg_css();
+               /*
+                * This should be rare, either the cgroup is dying or the task
+                * is moving cgroups. Just punt to root for the handful of ios.
+                */
+               if (!css_tryget_online(req->work.blkcg_css))
+                       req->work.blkcg_css = NULL;
+               rcu_read_unlock();
+       }
+#endif
        if (!req->work.creds)
                req->work.creds = get_current_cred();
        if (!req->work.fs && def->needs_fs) {
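The blkcg grab above is the standard tryget-under-RCU idiom: blkcg_css() returns the current task's css without taking a reference, which is only safe to inspect under rcu_read_lock(), and css_tryget_online() then upgrades it to a proper reference unless the cgroup is already dying. A self-contained sketch of the same idiom, with grab_current_blkcg() as a hypothetical name:

#include <linux/cgroup.h>
#include <linux/blk-cgroup.h>

/*
 * Illustration only, assuming CONFIG_BLK_CGROUP=y. Returns a referenced
 * css for the current task's blkcg, or NULL if the css is going offline
 * (callers then fall back to the root cgroup). A non-NULL result must
 * be released with css_put().
 */
static struct cgroup_subsys_state *grab_current_blkcg(void)
{
	struct cgroup_subsys_state *css;

	rcu_read_lock();
	css = blkcg_css();		/* unreferenced, RCU-protected */
	if (!css_tryget_online(css))	/* cgroup dying or task migrating */
		css = NULL;
	rcu_read_unlock();
	return css;
}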
 
 static int io_sq_thread(void *data)
 {
+       struct cgroup_subsys_state *cur_css = NULL;
        const struct cred *old_cred = NULL;
        struct io_sq_data *sqd = data;
        struct io_ring_ctx *ctx;
                                        revert_creds(old_cred);
                                old_cred = override_creds(ctx->creds);
                        }
+                       io_sq_thread_associate_blkcg(ctx, &cur_css);
 
                        ret |= __io_sq_thread(ctx, start_jiffies, cap_entries);
 
 
        io_run_task_work();
 
+       if (cur_css)
+               io_sq_thread_unassociate_blkcg();
        if (old_cred)
                revert_creds(old_cred);
 
                ctx->mm_account = NULL;
        }
 
+#ifdef CONFIG_BLK_CGROUP
+       if (ctx->sqo_blkcg_css)
+               css_put(ctx->sqo_blkcg_css);
+#endif
+
        io_sqe_files_unregister(ctx);
        io_eventfd_unregister(ctx);
        io_destroy_buffers(ctx);
        mmgrab(current->mm);
        ctx->mm_account = current->mm;
 
+#ifdef CONFIG_BLK_CGROUP
+       /*
+        * The sq thread will belong to the original cgroup it was inited in.
+        * If the cgroup goes offline (e.g. disabling the io controller), then
+        * issued bios will be associated with the closest cgroup later in the
+        * block layer.
+        */
+       rcu_read_lock();
+       ctx->sqo_blkcg_css = blkcg_css();
+       ret = css_tryget_online(ctx->sqo_blkcg_css);
+       rcu_read_unlock();
+       if (!ret) {
+               /* don't init against a dying cgroup, have the user try again */
+               ctx->sqo_blkcg_css = NULL;
+               ret = -ENODEV;
+               goto err;
+       }
+#endif
+
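This setup-time reference is the counterpart of the css_put() added to io_ring_ctx_free() earlier in the patch: the ring pins the creating task's blkcg css for its whole lifetime, and the SQPOLL thread only borrows it via kthread_associate_blkcg(). Failing with -ENODEV when css_tryget_online() does not succeed avoids initializing a ring against a cgroup that is already being torn down; the caller can simply retry once its cgroup migration has settled.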
        /*
         * Account memory _before_ installing the file descriptor. Once
         * the descriptor is installed, it can get closed at any time. Also