                ib_dma_unmap_sg(dev, sg, sg_cnt, dir);
 }
 
-static int rdma_rw_map_sg(struct ib_device *dev, struct scatterlist *sg,
-                         u32 sg_cnt, enum dma_data_direction dir)
+static int rdma_rw_map_sgtable(struct ib_device *dev, struct sg_table *sgt,
+                              enum dma_data_direction dir)
 {
-       if (is_pci_p2pdma_page(sg_page(sg))) {
+       int nents;
+
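+       /*
+        * Peer-to-peer pages need the dedicated P2PDMA mapping helper, and
+        * devices that do virtual DMA cannot reach peer memory at all.
+        */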
+       if (is_pci_p2pdma_page(sg_page(sgt->sgl))) {
                if (WARN_ON_ONCE(ib_uses_virt_dma(dev)))
                        return 0;
-               return pci_p2pdma_map_sg(dev->dma_device, sg, sg_cnt, dir);
+               nents = pci_p2pdma_map_sg(dev->dma_device, sgt->sgl,
+                                         sgt->orig_nents, dir);
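+               /* zero mapped entries means failure; convert to an errno */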
+               if (!nents)
+                       return -EIO;
+               sgt->nents = nents;
+               return 0;
        }
-       return ib_dma_map_sg(dev, sg, sg_cnt, dir);
+       return ib_dma_map_sgtable_attrs(dev, sgt, dir, 0);
 }
 
 /**
                u64 remote_addr, u32 rkey, enum dma_data_direction dir)
 {
        struct ib_device *dev = qp->pd->device;
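+       /* Wrap the caller's scatterlist for the sgtable DMA helpers */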
+       struct sg_table sgt = {
+               .sgl = sg,
+               .orig_nents = sg_cnt,
+       };
        int ret;
 
-       ret = rdma_rw_map_sg(dev, sg, sg_cnt, dir);
-       if (!ret)
-               return -ENOMEM;
-       sg_cnt = ret;
+       ret = rdma_rw_map_sgtable(dev, &sgt, dir);
+       if (ret)
+               return ret;
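+       /* continue with the number of DMA-mapped entries, not orig_nents */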
+       sg_cnt = sgt.nents;
 
        /*
         * Skip to the S/G entry that sg_offset falls into:
        return ret;
 
 out_unmap_sg:
-       rdma_rw_unmap_sg(dev, sg, sg_cnt, dir);
+       rdma_rw_unmap_sg(dev, sgt.sgl, sgt.orig_nents, dir);
        return ret;
 }
 EXPORT_SYMBOL(rdma_rw_ctx_init);
        struct ib_device *dev = qp->pd->device;
        u32 pages_per_mr = rdma_rw_fr_page_list_len(qp->pd->device,
                                                    qp->integrity_en);
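+       /* Wrap the data and protection SG lists for the sgtable DMA helpers */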
+       struct sg_table sgt = {
+               .sgl = sg,
+               .orig_nents = sg_cnt,
+       };
+       struct sg_table prot_sgt = {
+               .sgl = prot_sg,
+               .orig_nents = prot_sg_cnt,
+       };
        struct ib_rdma_wr *rdma_wr;
        int count = 0, ret;
 
                return -EINVAL;
        }
 
-       ret = rdma_rw_map_sg(dev, sg, sg_cnt, dir);
-       if (!ret)
-               return -ENOMEM;
-       sg_cnt = ret;
+       ret = rdma_rw_map_sgtable(dev, &sgt, dir);
+       if (ret)
+               return ret;
 
        if (prot_sg_cnt) {
-               ret = rdma_rw_map_sg(dev, prot_sg, prot_sg_cnt, dir);
-               if (!ret) {
-                       ret = -ENOMEM;
+               ret = rdma_rw_map_sgtable(dev, &prot_sgt, dir);
+               if (ret)
                        goto out_unmap_sg;
-               }
-               prot_sg_cnt = ret;
        }
 
        ctx->type = RDMA_RW_SIG_MR;
 
        memcpy(ctx->reg->mr->sig_attrs, sig_attrs, sizeof(struct ib_sig_attrs));
 
-       ret = ib_map_mr_sg_pi(ctx->reg->mr, sg, sg_cnt, NULL, prot_sg,
-                             prot_sg_cnt, NULL, SZ_4K);
+       ret = ib_map_mr_sg_pi(ctx->reg->mr, sg, sgt.nents, NULL, prot_sg,
+                             prot_sgt.nents, NULL, SZ_4K);
        if (unlikely(ret)) {
-               pr_err("failed to map PI sg (%u)\n", sg_cnt + prot_sg_cnt);
+               pr_err("failed to map PI sg (%u)\n",
+                      sgt.nents + prot_sgt.nents);
                goto out_destroy_sig_mr;
        }
 
 out_free_ctx:
        kfree(ctx->reg);
 out_unmap_prot_sg:
-       if (prot_sg_cnt)
-               rdma_rw_unmap_sg(dev, prot_sg, prot_sg_cnt, dir);
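+       /* prot_sgt.nents stays 0 if the protection SG was never mapped */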
+       if (prot_sgt.nents)
+               rdma_rw_unmap_sg(dev, prot_sgt.sgl, prot_sgt.orig_nents, dir);
 out_unmap_sg:
-       rdma_rw_unmap_sg(dev, sg, sg_cnt, dir);
+       rdma_rw_unmap_sg(dev, sgt.sgl, sgt.orig_nents, dir);
        return ret;
 }
 EXPORT_SYMBOL(rdma_rw_ctx_signature_init);