#include <linux/minmax.h>
 #include <linux/swab.h>
 
+/*
+ * Common helper for find_bit() function family
+ * @FETCH: The expression that fetches and pre-processes each word of bitmap(s)
+ * @MUNGE: The expression that post-processes a word containing found bit (may be empty)
+ * @size: The bitmap size in bits
+ */
+#define FIND_FIRST_BIT(FETCH, MUNGE, size)                                     \
+({                                                                             \
+       unsigned long idx, val, sz = (size);                                    \
+                                                                               \
+       for (idx = 0; idx * BITS_PER_LONG < sz; idx++) {                        \
+               val = (FETCH);                                                  \
+               if (val) {                                                      \
+                       sz = min(idx * BITS_PER_LONG + __ffs(MUNGE(val)), sz);  \
+                       break;                                                  \
+               }                                                               \
+       }                                                                       \
+                                                                               \
+       sz;                                                                     \
+})
+
 #if !defined(find_next_bit) || !defined(find_next_zero_bit) ||                 \
        !defined(find_next_bit_le) || !defined(find_next_zero_bit_le) ||        \
        !defined(find_next_and_bit)
  */
 unsigned long _find_first_bit(const unsigned long *addr, unsigned long size)
 {
-       unsigned long idx;
-
-       for (idx = 0; idx * BITS_PER_LONG < size; idx++) {
-               if (addr[idx])
-                       return min(idx * BITS_PER_LONG + __ffs(addr[idx]), size);
-       }
-
-       return size;
+       return FIND_FIRST_BIT(addr[idx], /* nop */, size);
 }
 EXPORT_SYMBOL(_find_first_bit);
 #endif
                                  const unsigned long *addr2,
                                  unsigned long size)
 {
-       unsigned long idx, val;
-
-       for (idx = 0; idx * BITS_PER_LONG < size; idx++) {
-               val = addr1[idx] & addr2[idx];
-               if (val)
-                       return min(idx * BITS_PER_LONG + __ffs(val), size);
-       }
-
-       return size;
+       return FIND_FIRST_BIT(addr1[idx] & addr2[idx], /* nop */, size);
 }
 EXPORT_SYMBOL(_find_first_and_bit);
 #endif
  */
 unsigned long _find_first_zero_bit(const unsigned long *addr, unsigned long size)
 {
-       unsigned long idx;
-
-       for (idx = 0; idx * BITS_PER_LONG < size; idx++) {
-               if (addr[idx] != ~0UL)
-                       return min(idx * BITS_PER_LONG + ffz(addr[idx]), size);
-       }
-
-       return size;
+       return FIND_FIRST_BIT(~addr[idx], /* nop */, size);
 }
 EXPORT_SYMBOL(_find_first_zero_bit);
 #endif