#include <linux/kthread.h>
 #include <linux/sched/cputime.h>
 
-static inline bool rebalance_ptr_pred(struct bch_fs *c,
-                                     struct extent_ptr_decoded p,
-                                     struct bch_io_opts *io_opts)
+/*
+ * Check if an extent should be moved:
+ * returns -1 if it should not be moved, or
+ * device of pointer that should be moved, if known, or INT_MAX if unknown
+ */
+static int __bch2_rebalance_pred(struct bch_fs *c,
+                                struct bkey_s_c k,
+                                struct bch_io_opts *io_opts)
 {
-       if (io_opts->background_target &&
-           !bch2_dev_in_target(c, p.ptr.dev, io_opts->background_target) &&
-           !p.ptr.cached)
-               return true;
+       struct bkey_ptrs_c ptrs = bch2_bkey_ptrs_c(k);
+       const union bch_extent_entry *entry;
+       struct extent_ptr_decoded p;
+
+       if (io_opts->background_compression)
+               bkey_for_each_ptr_decode(k.k, ptrs, p, entry)
+                       if (!p.ptr.cached &&
+                           p.crc.compression_type !=
+                           bch2_compression_opt_to_type[io_opts->background_compression])
+                               return p.ptr.dev;
 
-       if (io_opts->background_compression &&
-           p.crc.compression_type !=
-           bch2_compression_opt_to_type[io_opts->background_compression])
-               return true;
+       if (io_opts->background_target)
+               bkey_for_each_ptr_decode(k.k, ptrs, p, entry)
+                       if (!p.ptr.cached &&
+                           !bch2_dev_in_target(c, p.ptr.dev, io_opts->background_target))
+                               return p.ptr.dev;
 
-       return false;
+       return -1;
 }
 
 void bch2_rebalance_add_key(struct bch_fs *c,
                            struct bkey_s_c k,
                            struct bch_io_opts *io_opts)
 {
-       struct bkey_ptrs_c ptrs = bch2_bkey_ptrs_c(k);
-       const union bch_extent_entry *entry;
-       struct extent_ptr_decoded p;
+       atomic64_t *counter;
+       int dev;
 
-       if (!io_opts->background_target &&
-           !io_opts->background_compression)
+       dev = __bch2_rebalance_pred(c, k, io_opts);
+       if (dev < 0)
                return;
 
-       bkey_for_each_ptr_decode(k.k, ptrs, p, entry)
-               if (rebalance_ptr_pred(c, p, io_opts)) {
-                       struct bch_dev *ca = bch_dev_bkey_exists(c, p.ptr.dev);
+       counter = dev < INT_MAX
+               ? &bch_dev_bkey_exists(c, dev)->rebalance_work
+               : &c->rebalance.work_unknown_dev;
 
-                       if (atomic64_add_return(p.crc.compressed_size,
-                                               &ca->rebalance_work) ==
-                           p.crc.compressed_size)
-                               rebalance_wakeup(c);
-               }
-}
-
-void bch2_rebalance_add_work(struct bch_fs *c, u64 sectors)
-{
-       if (atomic64_add_return(sectors, &c->rebalance.work_unknown_dev) ==
-           sectors)
+       if (atomic64_add_return(k.k->size, counter) == k.k->size)
                rebalance_wakeup(c);
 }
 
                                    struct bch_io_opts *io_opts,
                                    struct data_opts *data_opts)
 {
-       struct bkey_ptrs_c ptrs = bch2_bkey_ptrs_c(k);
-       const union bch_extent_entry *entry;
-       struct extent_ptr_decoded p;
-       unsigned nr_replicas = 0;
-
-       bkey_for_each_ptr_decode(k.k, ptrs, p, entry) {
-               nr_replicas += !p.ptr.cached;
-
-               if (rebalance_ptr_pred(c, p, io_opts))
-                       goto found;
+       if (__bch2_rebalance_pred(c, k, io_opts) >= 0) {
+               data_opts->target               = io_opts->background_target;
+               data_opts->btree_insert_flags   = 0;
+               return DATA_ADD_REPLICAS;
+       } else {
+               return DATA_SKIP;
        }
+}
 
-       if (nr_replicas < io_opts->data_replicas)
-               goto found;
-
-       return DATA_SKIP;
-found:
-       data_opts->target               = io_opts->background_target;
-       data_opts->btree_insert_flags   = 0;
-       return DATA_ADD_REPLICAS;
+void bch2_rebalance_add_work(struct bch_fs *c, u64 sectors)
+{
+       if (atomic64_add_return(sectors, &c->rebalance.work_unknown_dev) ==
+           sectors)
+               rebalance_wakeup(c);
 }
 
 struct rebalance_work {