xa_lock_irqsave(&ent->mkeys, flags);
        push_to_reserved(ent, mr);
-       ent->total_mrs++;
        /* If we are doing fill_to_high_water then keep going. */
        queue_adjust_cache_locked(ent);
        xa_unlock_irqrestore(&ent->mkeys, flags);
        init_waitqueue_head(&mr->mmkey.wait);
        mr->mmkey.type = MLX5_MKEY_MR;
        WRITE_ONCE(ent->dev->cache.last_add, jiffies);
-       xa_lock_irq(&ent->mkeys);
-       ent->total_mrs++;
-       xa_unlock_irq(&ent->mkeys);
        kfree(in);
        return mr;
 free_mr:
        if (!ent->stored)
                return;
        mr = pop_stored_mkey(ent);
-       ent->total_mrs--;
        xa_unlock_irq(&ent->mkeys);
        mlx5_core_destroy_mkey(ent->dev->mdev, mr->mmkey.key);
        kfree(mr);
         * mkeys.
         */
        xa_lock_irq(&ent->mkeys);
-       if (target < ent->total_mrs - ent->stored) {
+       if (target < ent->in_use) {
                err = -EINVAL;
                goto err_unlock;
        }
-       target = target - (ent->total_mrs - ent->stored);
+       target = target - ent->in_use;
        if (target < ent->limit || target > ent->limit*2) {
                err = -EINVAL;
                goto err_unlock;
        char lbuf[20];
        int err;
 
-       err = snprintf(lbuf, sizeof(lbuf), "%d\n", ent->total_mrs);
+       err = snprintf(lbuf, sizeof(lbuf), "%ld\n", ent->stored + ent->in_use);
        if (err < 0)
                return err;
 
                return ERR_PTR(-EOPNOTSUPP);
 
        xa_lock_irq(&ent->mkeys);
+       ent->in_use++;
+
        if (!ent->stored) {
                queue_adjust_cache_locked(ent);
                ent->miss++;
                xa_unlock_irq(&ent->mkeys);
                mr = create_cache_mr(ent);
-               if (IS_ERR(mr))
+               if (IS_ERR(mr)) {
+                       xa_lock_irq(&ent->mkeys);
+                       ent->in_use--;
+                       xa_unlock_irq(&ent->mkeys);
                        return mr;
+               }
        } else {
                mr = pop_stored_mkey(ent);
                queue_adjust_cache_locked(ent);
        xa_lock_irq(&ent->mkeys);
        while (ent->stored) {
                mr = pop_stored_mkey(ent);
-               ent->total_mrs--;
                xa_unlock_irq(&ent->mkeys);
                mlx5_core_destroy_mkey(dev->mdev, mr->mmkey.key);
                kfree(mr);
 
        /* Stop DMA */
        if (mr->cache_ent) {
+               xa_lock_irq(&mr->cache_ent->mkeys);
+               mr->cache_ent->in_use--;
+               xa_unlock_irq(&mr->cache_ent->mkeys);
+
                if (mlx5r_umr_revoke_mr(mr) ||
-                   push_mkey(mr->cache_ent, false, mr)) {
-                       xa_lock_irq(&mr->cache_ent->mkeys);
-                       mr->cache_ent->total_mrs--;
-                       xa_unlock_irq(&mr->cache_ent->mkeys);
+                   push_mkey(mr->cache_ent, false, mr))
                        mr->cache_ent = NULL;
-               }
        }
        if (!mr->cache_ent) {
                rc = destroy_mkey(to_mdev(mr->ibmr.device), mr);