return NULL;
 }
 
-/*
- * Charge the slab page belonging to the non-root kmem_cache.
- * Can be called for non-root kmem_caches only.
- */
-static __always_inline int memcg_charge_slab(struct page *page,
-                                            gfp_t gfp, int order,
-                                            struct kmem_cache *s)
-{
-       int nr_pages = 1 << order;
-       struct mem_cgroup *memcg;
-       struct lruvec *lruvec;
-       int ret;
-
-       rcu_read_lock();
-       memcg = READ_ONCE(s->memcg_params.memcg);
-       while (memcg && !css_tryget_online(&memcg->css))
-               memcg = parent_mem_cgroup(memcg);
-       rcu_read_unlock();
-
-       if (unlikely(!memcg || mem_cgroup_is_root(memcg))) {
-               mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
-                                   nr_pages << PAGE_SHIFT);
-               percpu_ref_get_many(&s->memcg_params.refcnt, nr_pages);
-               return 0;
-       }
-
-       ret = memcg_kmem_charge(memcg, gfp, nr_pages);
-       if (ret)
-               goto out;
-
-       lruvec = mem_cgroup_lruvec(memcg, page_pgdat(page));
-       mod_lruvec_state(lruvec, cache_vmstat_idx(s), nr_pages << PAGE_SHIFT);
-
-       percpu_ref_get_many(&s->memcg_params.refcnt, nr_pages);
-out:
-       css_put(&memcg->css);
-       return ret;
-}
-
-/*
- * Uncharge a slab page belonging to a non-root kmem_cache.
- * Can be called for non-root kmem_caches only.
- */
-static __always_inline void memcg_uncharge_slab(struct page *page, int order,
-                                               struct kmem_cache *s)
-{
-       int nr_pages = 1 << order;
-       struct mem_cgroup *memcg;
-       struct lruvec *lruvec;
-
-       rcu_read_lock();
-       memcg = READ_ONCE(s->memcg_params.memcg);
-       if (likely(!mem_cgroup_is_root(memcg))) {
-               lruvec = mem_cgroup_lruvec(memcg, page_pgdat(page));
-               mod_lruvec_state(lruvec, cache_vmstat_idx(s),
-                                -(nr_pages << PAGE_SHIFT));
-               memcg_kmem_uncharge(memcg, nr_pages);
-       } else {
-               mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
-                                   -(nr_pages << PAGE_SHIFT));
-       }
-       rcu_read_unlock();
-
-       percpu_ref_put_many(&s->memcg_params.refcnt, nr_pages);
-}
-
 static inline int memcg_alloc_page_obj_cgroups(struct page *page,
                                               struct kmem_cache *s, gfp_t gfp)
 {
        page->obj_cgroups = NULL;
 }
 
+static inline size_t obj_full_size(struct kmem_cache *s)
+{
+       /*
+        * For each accounted object there is an extra space which is used
+        * to store obj_cgroup membership. Charge it too.
+        */
+       return s->size + sizeof(struct obj_cgroup *);
+}
+
+static inline struct kmem_cache *memcg_slab_pre_alloc_hook(struct kmem_cache *s,
+                                               struct obj_cgroup **objcgp,
+                                               size_t objects, gfp_t flags)
+{
+       struct kmem_cache *cachep;
+
+       cachep = memcg_kmem_get_cache(s, objcgp);
+       if (is_root_cache(cachep))
+               return s;
+
+       if (obj_cgroup_charge(*objcgp, flags, objects * obj_full_size(s))) {
+               obj_cgroup_put(*objcgp);
+               memcg_kmem_put_cache(cachep);
+               cachep = NULL;
+       }
+
+       return cachep;
+}
+
+static inline void mod_objcg_state(struct obj_cgroup *objcg,
+                                  struct pglist_data *pgdat,
+                                  int idx, int nr)
+{
+       struct mem_cgroup *memcg;
+       struct lruvec *lruvec;
+
+       rcu_read_lock();
+       memcg = obj_cgroup_memcg(objcg);
+       lruvec = mem_cgroup_lruvec(memcg, pgdat);
+       mod_memcg_lruvec_state(lruvec, idx, nr);
+       rcu_read_unlock();
+}
+
 static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
                                              struct obj_cgroup *objcg,
                                              size_t size, void **p)
                        off = obj_to_index(s, page, p[i]);
                        obj_cgroup_get(objcg);
                        page_obj_cgroups(page)[off] = objcg;
+                       mod_objcg_state(objcg, page_pgdat(page),
+                                       cache_vmstat_idx(s), obj_full_size(s));
+               } else {
+                       obj_cgroup_uncharge(objcg, obj_full_size(s));
                }
        }
        obj_cgroup_put(objcg);
        off = obj_to_index(s, page, p);
        objcg = page_obj_cgroups(page)[off];
        page_obj_cgroups(page)[off] = NULL;
+
+       obj_cgroup_uncharge(objcg, obj_full_size(s));
+       mod_objcg_state(objcg, page_pgdat(page), cache_vmstat_idx(s),
+                       -obj_full_size(s));
+
        obj_cgroup_put(objcg);
 }
 
        return NULL;
 }
 
-static inline int memcg_charge_slab(struct page *page, gfp_t gfp, int order,
-                                   struct kmem_cache *s)
-{
-       return 0;
-}
-
-static inline void memcg_uncharge_slab(struct page *page, int order,
-                                      struct kmem_cache *s)
-{
-}
-
 static inline int memcg_alloc_page_obj_cgroups(struct page *page,
                                               struct kmem_cache *s, gfp_t gfp)
 {
 {
 }
 
+static inline struct kmem_cache *memcg_slab_pre_alloc_hook(struct kmem_cache *s,
+                                               struct obj_cgroup **objcgp,
+                                               size_t objects, gfp_t flags)
+{
+       return NULL;
+}
+
 static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
                                              struct obj_cgroup *objcg,
                                              size_t size, void **p)
                                            gfp_t gfp, int order,
                                            struct kmem_cache *s)
 {
-       int ret;
-
-       if (is_root_cache(s)) {
-               mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
-                                   PAGE_SIZE << order);
-               return 0;
-       }
+#ifdef CONFIG_MEMCG_KMEM
+       if (memcg_kmem_enabled() && !is_root_cache(s)) {
+               int ret;
 
-       ret = memcg_alloc_page_obj_cgroups(page, s, gfp);
-       if (ret)
-               return ret;
+               ret = memcg_alloc_page_obj_cgroups(page, s, gfp);
+               if (ret)
+                       return ret;
 
-       return memcg_charge_slab(page, gfp, order, s);
+               percpu_ref_get_many(&s->memcg_params.refcnt, 1 << order);
+       }
+#endif
+       mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
+                           PAGE_SIZE << order);
+       return 0;
 }
 
 static __always_inline void uncharge_slab_page(struct page *page, int order,
                                               struct kmem_cache *s)
 {
-       if (is_root_cache(s)) {
-               mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
-                                   -(PAGE_SIZE << order));
-               return;
+#ifdef CONFIG_MEMCG_KMEM
+       if (memcg_kmem_enabled() && !is_root_cache(s)) {
+               memcg_free_page_obj_cgroups(page);
+               percpu_ref_put_many(&s->memcg_params.refcnt, 1 << order);
        }
-
-       memcg_free_page_obj_cgroups(page);
-       memcg_uncharge_slab(page, order, s);
+#endif
+       mod_node_page_state(page_pgdat(page), cache_vmstat_idx(s),
+                           -(PAGE_SIZE << order));
 }
 
 static inline struct kmem_cache *cache_from_obj(struct kmem_cache *s, void *x)
 
        if (memcg_kmem_enabled() &&
            ((flags & __GFP_ACCOUNT) || (s->flags & SLAB_ACCOUNT)))
-               return memcg_kmem_get_cache(s, objcgp);
+               return memcg_slab_pre_alloc_hook(s, objcgp, size, flags);
 
        return s;
 }