return s->size + sizeof(struct obj_cgroup *);
 }
 
-/*
- * Returns false if the allocation should fail.
- */
-static bool __memcg_slab_pre_alloc_hook(struct kmem_cache *s,
-                                       struct list_lru *lru,
-                                       struct obj_cgroup **objcgp,
-                                       size_t objects, gfp_t flags)
+static bool __memcg_slab_post_alloc_hook(struct kmem_cache *s,
+                                        struct list_lru *lru,
+                                        gfp_t flags, size_t size,
+                                        void **p)
 {
+       struct obj_cgroup *objcg;
+       struct slab *slab;
+       unsigned long off;
+       size_t i;
+
        /*
         * The obtained objcg pointer is safe to use within the current scope,
         * defined by current task or set_active_memcg() pair.
         * obj_cgroup_get() is used to get a permanent reference.
         */
-       struct obj_cgroup *objcg = current_obj_cgroup();
+       objcg = current_obj_cgroup();
        if (!objcg)
                return true;
 
+       /*
+        * slab_alloc_node() avoids the NULL check, so we might be called with a
+        * single NULL object. kmem_cache_alloc_bulk() aborts if it can't fill
+        * the whole requested size.
+        * Return success in either case, as there is nothing to free back.
+        */
+       if (unlikely(*p == NULL))
+               return true;
+
+       flags &= gfp_allowed_mask;
+
        if (lru) {
                int ret;
                struct mem_cgroup *memcg;
                        return false;
        }
 
-       if (obj_cgroup_charge(objcg, flags, objects * obj_full_size(s)))
+       if (obj_cgroup_charge(objcg, flags, size * obj_full_size(s)))
                return false;
 
-       *objcgp = objcg;
+       for (i = 0; i < size; i++) {
+               slab = virt_to_slab(p[i]);
+
+               if (!slab_obj_exts(slab) &&
+                   alloc_slab_obj_exts(slab, s, flags, false)) {
+                       obj_cgroup_uncharge(objcg, obj_full_size(s));
+                       continue;
+               }
+
+               off = obj_to_index(s, slab, p[i]);
+               obj_cgroup_get(objcg);
+               slab_obj_exts(slab)[off].objcg = objcg;
+               mod_objcg_state(objcg, slab_pgdat(slab),
+                               cache_vmstat_idx(s), obj_full_size(s));
+       }
+
        return true;
 }
 
-/*
- * Returns false if the allocation should fail.
- */
+static void memcg_alloc_abort_single(struct kmem_cache *s, void *object);
+
 static __fastpath_inline
-bool memcg_slab_pre_alloc_hook(struct kmem_cache *s, struct list_lru *lru,
-                              struct obj_cgroup **objcgp, size_t objects,
-                              gfp_t flags)
+bool memcg_slab_post_alloc_hook(struct kmem_cache *s, struct list_lru *lru,
+                               gfp_t flags, size_t size, void **p)
 {
-       if (!memcg_kmem_online())
+       if (likely(!memcg_kmem_online()))
                return true;
 
        if (likely(!(flags & __GFP_ACCOUNT) && !(s->flags & SLAB_ACCOUNT)))
                return true;
 
-       return likely(__memcg_slab_pre_alloc_hook(s, lru, objcgp, objects,
-                                                 flags));
-}
-
-static void __memcg_slab_post_alloc_hook(struct kmem_cache *s,
-                                        struct obj_cgroup *objcg,
-                                        gfp_t flags, size_t size,
-                                        void **p)
-{
-       struct slab *slab;
-       unsigned long off;
-       size_t i;
-
-       flags &= gfp_allowed_mask;
-
-       for (i = 0; i < size; i++) {
-               if (likely(p[i])) {
-                       slab = virt_to_slab(p[i]);
-
-                       if (!slab_obj_exts(slab) &&
-                           alloc_slab_obj_exts(slab, s, flags, false)) {
-                               obj_cgroup_uncharge(objcg, obj_full_size(s));
-                               continue;
-                       }
+       if (likely(__memcg_slab_post_alloc_hook(s, lru, flags, size, p)))
+               return true;
 
-                       off = obj_to_index(s, slab, p[i]);
-                       obj_cgroup_get(objcg);
-                       slab_obj_exts(slab)[off].objcg = objcg;
-                       mod_objcg_state(objcg, slab_pgdat(slab),
-                                       cache_vmstat_idx(s), obj_full_size(s));
-               } else {
-                       obj_cgroup_uncharge(objcg, obj_full_size(s));
-               }
+       if (likely(size == 1)) {
+               memcg_alloc_abort_single(s, *p);
+               *p = NULL;
+       } else {
+               kmem_cache_free_bulk(s, size, p);
        }
-}
-
-static __fastpath_inline
-void memcg_slab_post_alloc_hook(struct kmem_cache *s, struct obj_cgroup *objcg,
-                               gfp_t flags, size_t size, void **p)
-{
-       if (likely(!memcg_kmem_online() || !objcg))
-               return;
 
-       return __memcg_slab_post_alloc_hook(s, objcg, flags, size, p);
+       return false;
 }
 
 static void __memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
 
        __memcg_slab_free_hook(s, slab, p, objects, obj_exts);
 }
-
-static inline
-void memcg_slab_alloc_error_hook(struct kmem_cache *s, int objects,
-                          struct obj_cgroup *objcg)
-{
-       if (objcg)
-               obj_cgroup_uncharge(objcg, objects * obj_full_size(s));
-}
 #else /* CONFIG_MEMCG_KMEM */
-static inline bool memcg_slab_pre_alloc_hook(struct kmem_cache *s,
-                                            struct list_lru *lru,
-                                            struct obj_cgroup **objcgp,
-                                            size_t objects, gfp_t flags)
-{
-       return true;
-}
-
-static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
-                                             struct obj_cgroup *objcg,
+static inline bool memcg_slab_post_alloc_hook(struct kmem_cache *s,
+                                             struct list_lru *lru,
                                              gfp_t flags, size_t size,
                                              void **p)
 {
+       return true;
 }
 
 static inline void memcg_slab_free_hook(struct kmem_cache *s, struct slab *slab,
                                        void **p, int objects)
 {
 }
-
-static inline
-void memcg_slab_alloc_error_hook(struct kmem_cache *s, int objects,
-                                struct obj_cgroup *objcg)
-{
-}
 #endif /* CONFIG_MEMCG_KMEM */
 
 /*
 ALLOW_ERROR_INJECTION(should_failslab, ERRNO);
 
 static __fastpath_inline
-struct kmem_cache *slab_pre_alloc_hook(struct kmem_cache *s,
-                                      struct list_lru *lru,
-                                      struct obj_cgroup **objcgp,
-                                      size_t size, gfp_t flags)
+struct kmem_cache *slab_pre_alloc_hook(struct kmem_cache *s, gfp_t flags)
 {
        flags &= gfp_allowed_mask;
 
        if (unlikely(should_failslab(s, flags)))
                return NULL;
 
-       if (unlikely(!memcg_slab_pre_alloc_hook(s, lru, objcgp, size, flags)))
-               return NULL;
-
        return s;
 }
 
 static __fastpath_inline
-void slab_post_alloc_hook(struct kmem_cache *s,        struct obj_cgroup *objcg,
+bool slab_post_alloc_hook(struct kmem_cache *s, struct list_lru *lru,
                          gfp_t flags, size_t size, void **p, bool init,
                          unsigned int orig_size)
 {
                }
        }
 
-       memcg_slab_post_alloc_hook(s, objcg, flags, size, p);
+       return memcg_slab_post_alloc_hook(s, lru, flags, size, p);
 }
 
 /*
                gfp_t gfpflags, int node, unsigned long addr, size_t orig_size)
 {
        void *object;
-       struct obj_cgroup *objcg = NULL;
        bool init = false;
 
-       s = slab_pre_alloc_hook(s, lru, &objcg, 1, gfpflags);
+       s = slab_pre_alloc_hook(s, gfpflags);
        if (unlikely(!s))
                return NULL;
 
        /*
         * When init equals 'true', like for kzalloc() family, only
         * @orig_size bytes might be zeroed instead of s->object_size
+        * If memcg_slab_post_alloc_hook() fails, the allocation is aborted
+        * and object is set to NULL.
         */
-       slab_post_alloc_hook(s, objcg, gfpflags, 1, &object, init, orig_size);
+       slab_post_alloc_hook(s, lru, gfpflags, 1, &object, init, orig_size);
 
        return object;
 }
                do_slab_free(s, slab, object, object, 1, addr);
 }
 
+#ifdef CONFIG_MEMCG_KMEM
+/* Do not inline the rare memcg charging failure path into the allocation path */
+static noinline
+void memcg_alloc_abort_single(struct kmem_cache *s, void *object)
+{
+       if (likely(slab_free_hook(s, object, slab_want_init_on_free(s))))
+               do_slab_free(s, virt_to_slab(object), object, object, 1, _RET_IP_);
+}
+#endif
+
 static __fastpath_inline
 void slab_free_bulk(struct kmem_cache *s, struct slab *slab, void *head,
                    void *tail, void **p, int cnt, unsigned long addr)
                                 void **p)
 {
        int i;
-       struct obj_cgroup *objcg = NULL;
 
        if (!size)
                return 0;
 
-       /* memcg and kmem_cache debug support */
-       s = slab_pre_alloc_hook(s, NULL, &objcg, size, flags);
+       s = slab_pre_alloc_hook(s, flags);
        if (unlikely(!s))
                return 0;
 
        i = __kmem_cache_alloc_bulk(s, flags, size, p);
+       if (unlikely(i == 0))
+               return 0;
 
        /*
         * memcg and kmem_cache debug support and memory initialization.
         * Done outside of the IRQ disabled fastpath loop.
         */
-       if (likely(i != 0)) {
-               slab_post_alloc_hook(s, objcg, flags, size, p,
-                       slab_want_init_on_alloc(flags, s), s->object_size);
-       } else {
-               memcg_slab_alloc_error_hook(s, size, objcg);
+       if (unlikely(!slab_post_alloc_hook(s, NULL, flags, size, p,
+                   slab_want_init_on_alloc(flags, s), s->object_size))) {
+               return 0;
        }
-
        return i;
 }
 EXPORT_SYMBOL(kmem_cache_alloc_bulk_noprof);
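
For readers following the new control flow: with memcg charging moved to the
post-alloc hook, a bulk caller no longer needs a separate error path for a
charging failure, because slab_post_alloc_hook() frees the objects itself and
kmem_cache_alloc_bulk_noprof() simply returns 0. A minimal usage sketch, not
part of the patch, assuming a hypothetical cache and helper:

/*
 * Illustrative sketch only. After this change, kmem_cache_alloc_bulk()
 * returns 0 both when the underlying allocation fails and when the memcg
 * charge in the post-alloc hook fails; in the latter case the objects have
 * already been freed back, so the caller has nothing to roll back.
 * "my_cache" and "fill_batch" are hypothetical.
 */
static int fill_batch(struct kmem_cache *my_cache, void **objs, size_t nr)
{
	if (!kmem_cache_alloc_bulk(my_cache, GFP_KERNEL, nr, objs))
		return -ENOMEM;	/* nothing allocated, or charge rolled back */

	return 0;		/* all nr objects allocated and charged */
}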