enum rxe_mr_lookup_type type);
 int mr_check_range(struct rxe_mr *mr, u64 iova, size_t length);
 int advance_dma_data(struct rxe_dma_info *dma, unsigned int length);
-int rxe_invalidate_mr(struct rxe_qp *qp, u32 rkey);
+int rxe_invalidate_mr(struct rxe_qp *qp, u32 key);
 int rxe_reg_fast_mr(struct rxe_qp *qp, struct rxe_send_wqe *wqe);
 int rxe_mr_set_page(struct ib_mr *ibmr, u64 addr);
 int rxe_dereg_mr(struct ib_mr *ibmr, struct ib_udata *udata);
 
        return mr;
 }
 
-int rxe_invalidate_mr(struct rxe_qp *qp, u32 rkey)
+int rxe_invalidate_mr(struct rxe_qp *qp, u32 key)
 {
        struct rxe_dev *rxe = to_rdev(qp->ibqp.device);
        struct rxe_mr *mr;
        int ret;
 
-       mr = rxe_pool_get_index(&rxe->mr_pool, rkey >> 8);
+       mr = rxe_pool_get_index(&rxe->mr_pool, key >> 8);
        if (!mr) {
-               pr_err("%s: No MR for rkey %#x\n", __func__, rkey);
+               pr_err("%s: No MR for key %#x\n", __func__, key);
                ret = -EINVAL;
                goto err;
        }
 
-       if (rkey != mr->rkey) {
-               pr_err("%s: rkey (%#x) doesn't match mr->rkey (%#x)\n",
-                       __func__, rkey, mr->rkey);
+       if (mr->rkey ? (key != mr->rkey) : (key != mr->lkey)) {
+               pr_err("%s: wr key (%#x) doesn't match mr key (%#x)\n",
+                       __func__, key, (mr->rkey ? mr->rkey : mr->lkey));
                ret = -EINVAL;
                goto err_drop_ref;
        }