return !ret ? NULL : ret + LLIST_NODE_SZ;
 }
 
-/* Most of the logic is taken from setup_kmalloc_cache_index_table() */
 static __init int bpf_mem_cache_adjust_size(void)
 {
-       unsigned int size, index;
+       unsigned int size;
 
-       /* Normally KMALLOC_MIN_SIZE is 8-bytes, but it can be
-        * up-to 256-bytes.
+       /* Adjusting the indexes in size_index() according to the object_size
+        * of underlying slab cache, so bpf_mem_alloc() will select a
+        * bpf_mem_cache with unit_size equal to the object_size of
+        * the underlying slab cache.
+        *
+        * The maximal value of KMALLOC_MIN_SIZE and __kmalloc_minalign() is
+        * 256-bytes, so only do adjustment for [8-bytes, 192-bytes].
         */
-       size = KMALLOC_MIN_SIZE;
-       if (size <= 192)
-               index = size_index[(size - 1) / 8];
-       else
-               index = fls(size - 1) - 1;
-       for (size = 8; size < KMALLOC_MIN_SIZE && size <= 192; size += 8)
-               size_index[(size - 1) / 8] = index;
+       for (size = 192; size >= 8; size -= 8) {
+               unsigned int kmalloc_size, index;
 
-       /* The minimal alignment is 64-bytes, so disable 96-bytes cache and
-        * use 128-bytes cache instead.
-        */
-       if (KMALLOC_MIN_SIZE >= 64) {
-               index = size_index[(128 - 1) / 8];
-               for (size = 64 + 8; size <= 96; size += 8)
-                       size_index[(size - 1) / 8] = index;
-       }
+               kmalloc_size = kmalloc_size_roundup(size);
+               if (kmalloc_size == size)
+                       continue;
 
-       /* The minimal alignment is 128-bytes, so disable 192-bytes cache and
-        * use 256-bytes cache instead.
-        */
-       if (KMALLOC_MIN_SIZE >= 128) {
-               index = fls(256 - 1) - 1;
-               for (size = 128 + 8; size <= 192; size += 8)
+               if (kmalloc_size <= 192)
+                       index = size_index[(kmalloc_size - 1) / 8];
+               else
+                       index = fls(kmalloc_size - 1) - 1;
+               /* Only overwrite if necessary */
+               if (size_index[(size - 1) / 8] != index)
                        size_index[(size - 1) / 8] = index;
        }