#include <linux/prefetch.h>
 #include <linux/sort.h>
 
+static int __must_check
+bch2_trans_update_by_path(struct btree_trans *, struct btree_path *,
+                         struct bkey_i *, enum btree_update_flags);
+
 static inline int btree_insert_entry_cmp(const struct btree_insert_entry *l,
                                         const struct btree_insert_entry *r)
 {
        return   cmp_int(l->btree_id,   r->btree_id) ?:
+                cmp_int(l->cached,     r->cached) ?:
                 -cmp_int(l->level,     r->level) ?:
                 bpos_cmp(l->k->k.p,    r->k->k.p);
 }
 
        i->k->k.needs_whiteout = false;
 
-       did_work = !i->cached
-               ? btree_insert_key_leaf(trans, i)
-               : bch2_btree_insert_key_cached(trans, i->path, i->k);
+       if (!i->cached)
+               did_work = btree_insert_key_leaf(trans, i);
+       else if (!i->key_cache_already_flushed)
+               did_work = bch2_btree_insert_key_cached(trans, i->path, i->k);
+       else {
+               bch2_btree_key_cache_drop(trans, i->path);
+               did_work = false;
+       }
        if (!did_work)
                return;
 
                        goto out_reset;
        }
 
-#ifdef CONFIG_BCACHEFS_DEBUG
-       /*
-        * if BTREE_TRIGGER_NORUN is set, it means we're probably being called
-        * from the key cache flush code:
-        */
-       trans_for_each_update(trans, i)
-               if (!i->cached &&
-                   !(i->flags & BTREE_TRIGGER_NORUN))
-                       bch2_btree_key_cache_verify_clean(trans,
-                                       i->btree_id, i->k->k.p);
-#endif
-
        ret = bch2_trans_commit_run_triggers(trans);
        if (ret)
                goto out;
        return ret;
 }
 
-int __must_check bch2_trans_update_by_path(struct btree_trans *trans, struct btree_path *path,
-                                  struct bkey_i *k, enum btree_update_flags flags)
+static int __must_check
+bch2_trans_update_by_path_trace(struct btree_trans *trans, struct btree_path *path,
+                               struct bkey_i *k, enum btree_update_flags flags,
+                               unsigned long ip)
 {
        struct bch_fs *c = trans->c;
        struct btree_insert_entry *i, n;
+       int ret = 0;
 
        BUG_ON(!path->should_be_locked);
 
                .cached         = path->cached,
                .path           = path,
                .k              = k,
-               .ip_allocated   = _RET_IP_,
+               .ip_allocated   = ip,
        };
 
 #ifdef CONFIG_BCACHEFS_DEBUG
            !btree_insert_entry_cmp(&n, i)) {
                BUG_ON(i->insert_trigger_run || i->overwrite_trigger_run);
 
-               /*
-                * This is a hack to ensure that inode creates update the btree,
-                * not the key cache, which helps with cache coherency issues in
-                * other areas:
-                */
-               if (n.cached && !i->cached) {
-                       i->k = n.k;
-                       i->flags = n.flags;
-                       return 0;
-               }
-
                bch2_path_put(trans, i->path, true);
                i->flags        = n.flags;
                i->cached       = n.cached;
                }
        }
 
-       __btree_path_get(n.path, true);
-       return 0;
+       __btree_path_get(i->path, true);
+
+       /*
+        * If a key is present in the key cache, it must also exist in the
+        * btree - this is necessary for cache coherency. When iterating over
+        * a btree that's cached in the key cache, the btree iter code checks
+        * the key cache - but the key has to exist in the btree for that to
+        * work:
+        */
+       if (path->cached &&
+           bkey_deleted(&i->old_k) &&
+           !(flags & BTREE_UPDATE_NO_KEY_CACHE_COHERENCY)) {
+               struct btree_path *btree_path;
+
+               i->key_cache_already_flushed = true;
+               i->flags |= BTREE_TRIGGER_NORUN;
+
+               btree_path = bch2_path_get(trans, path->btree_id, path->pos,
+                                          1, 0, BTREE_ITER_INTENT);
+
+               ret = bch2_btree_path_traverse(trans, btree_path, 0);
+               if (ret)
+                       goto err;
+
+               btree_path->should_be_locked = true;
+               ret = bch2_trans_update_by_path_trace(trans, btree_path, k, flags, ip);
+err:
+               bch2_path_put(trans, btree_path, true);
+       }
+
+       return ret;
+}
+
+static int __must_check
+bch2_trans_update_by_path(struct btree_trans *trans, struct btree_path *path,
+                         struct bkey_i *k, enum btree_update_flags flags)
+{
+       return bch2_trans_update_by_path_trace(trans, path, k, flags, _RET_IP_);
 }
 
 int __must_check bch2_trans_update(struct btree_trans *trans, struct btree_iter *iter,
                                   struct bkey_i *k, enum btree_update_flags flags)
 {
+       struct btree_path *path = iter->update_path ?: iter->path;
+       struct bkey_cached *ck;
+       int ret;
+
        if (iter->flags & BTREE_ITER_IS_EXTENTS)
                return bch2_trans_update_extent(trans, iter, k, flags);
 
        if (bkey_deleted(&k->k) &&
+           !(flags & BTREE_UPDATE_KEY_CACHE_RECLAIM) &&
            (iter->flags & BTREE_ITER_FILTER_SNAPSHOTS)) {
-               int ret = need_whiteout_for_snapshot(trans, iter->btree_id, k->k.p);
+               ret = need_whiteout_for_snapshot(trans, iter->btree_id, k->k.p);
                if (unlikely(ret < 0))
                        return ret;
 
                        k->k.type = KEY_TYPE_whiteout;
        }
 
-       return bch2_trans_update_by_path(trans, iter->update_path ?: iter->path,
-                                        k, flags);
+       /*
+        * Ensure that updates to cached btrees go to the key cache:
+        */
+       if (!(flags & BTREE_UPDATE_KEY_CACHE_RECLAIM) &&
+           !path->cached &&
+           !path->level &&
+           btree_id_cached(trans->c, path->btree_id)) {
+               if (!iter->key_cache_path ||
+                   !iter->key_cache_path->should_be_locked ||
+                   bpos_cmp(iter->key_cache_path->pos, k->k.p)) {
+                       if (!iter->key_cache_path)
+                               iter->key_cache_path =
+                                       bch2_path_get(trans, path->btree_id, path->pos, 1, 0,
+                                                     BTREE_ITER_INTENT|BTREE_ITER_CACHED);
+
+                       iter->key_cache_path =
+                               bch2_btree_path_set_pos(trans, iter->key_cache_path, path->pos,
+                                                       iter->flags & BTREE_ITER_INTENT);
+
+                       ret = bch2_btree_path_traverse(trans, iter->key_cache_path,
+                                                      BTREE_ITER_CACHED);
+                       if (unlikely(ret))
+                               return ret;
+
+                       ck = (void *) iter->key_cache_path->l[0].b;
+
+                       if (test_bit(BKEY_CACHED_DIRTY, &ck->flags)) {
+                               trace_trans_restart_key_cache_raced(trans->fn, _RET_IP_);
+                               btree_trans_restart(trans);
+                               return -EINTR;
+                       }
+
+                       iter->key_cache_path->should_be_locked = true;
+               }
+
+               path = iter->key_cache_path;
+       }
+
+       return bch2_trans_update_by_path(trans, path, k, flags);
 }
 
 void bch2_trans_commit_hook(struct btree_trans *trans,