return dev_is_pci(dev) && to_pci_dev(dev)->untrusted;
 }
 
-static bool dev_use_swiotlb(struct device *dev)
+static bool dev_use_swiotlb(struct device *dev, size_t size,
+                           enum dma_data_direction dir)
 {
-       return IS_ENABLED(CONFIG_SWIOTLB) && dev_is_untrusted(dev);
+       return IS_ENABLED(CONFIG_SWIOTLB) &&
+               (dev_is_untrusted(dev) ||
+                dma_kmalloc_needs_bounce(dev, size, dir));
+}
+
+static bool dev_use_sg_swiotlb(struct device *dev, struct scatterlist *sg,
+                              int nents, enum dma_data_direction dir)
+{
+       struct scatterlist *s;
+       int i;
+
+       if (!IS_ENABLED(CONFIG_SWIOTLB))
+               return false;
+
+       if (dev_is_untrusted(dev))
+               return true;
+
+       /*
+        * If kmalloc() buffers are not DMA-safe for this device and
+        * direction, check the individual lengths in the sg list. If any
+        * element is deemed unsafe, use the swiotlb for bouncing.
+        */
+       if (!dma_kmalloc_safe(dev, dir)) {
+               for_each_sg(sg, s, nents, i)
+                       if (!dma_kmalloc_size_aligned(s->length))
+                               return true;
+       }
+
+       return false;
 }
 
 /**
 {
        phys_addr_t phys;
 
-       if (dev_is_dma_coherent(dev) && !dev_use_swiotlb(dev))
+       if (dev_is_dma_coherent(dev) && !dev_use_swiotlb(dev, size, dir))
                return;
 
        phys = iommu_iova_to_phys(iommu_get_dma_domain(dev), dma_handle);
 {
        phys_addr_t phys;
 
-       if (dev_is_dma_coherent(dev) && !dev_use_swiotlb(dev))
+       if (dev_is_dma_coherent(dev) && !dev_use_swiotlb(dev, size, dir))
                return;
 
        phys = iommu_iova_to_phys(iommu_get_dma_domain(dev), dma_handle);
        struct scatterlist *sg;
        int i;
 
-       if (dev_use_swiotlb(dev))
+       if (sg_dma_is_swiotlb(sgl))
                for_each_sg(sgl, sg, nelems, i)
                        iommu_dma_sync_single_for_cpu(dev, sg_dma_address(sg),
                                                      sg->length, dir);
        struct scatterlist *sg;
        int i;
 
-       if (dev_use_swiotlb(dev))
+       if (sg_dma_is_swiotlb(sgl))
                for_each_sg(sgl, sg, nelems, i)
                        iommu_dma_sync_single_for_device(dev,
                                                         sg_dma_address(sg),
         * If both the physical buffer start address and size are
         * page aligned, we don't need to use a bounce page.
         */
-       if (dev_use_swiotlb(dev) && iova_offset(iovad, phys | size)) {
+       if (dev_use_swiotlb(dev, size, dir) &&
+           iova_offset(iovad, phys | size)) {
                void *padding_start;
                size_t padding_size, aligned_size;
 
        struct scatterlist *s;
        int i;
 
+       sg_dma_mark_swiotlb(sg);
+
        for_each_sg(sg, s, nents, i) {
                sg_dma_address(s) = iommu_dma_map_page(dev, sg_page(s),
                                s->offset, s->length, dir, attrs);
                        goto out;
        }
 
-       if (dev_use_swiotlb(dev))
+       if (dev_use_sg_swiotlb(dev, sg, nents, dir))
                return iommu_dma_map_sg_swiotlb(dev, sg, nents, dir, attrs);
 
        if (!(attrs & DMA_ATTR_SKIP_CPU_SYNC))
        struct scatterlist *tmp;
        int i;
 
-       if (dev_use_swiotlb(dev)) {
+       if (sg_dma_is_swiotlb(sg)) {
                iommu_dma_unmap_sg_swiotlb(dev, sg, nents, dir, attrs);
                return;
        }
 
 /*
  * One 64-bit architectures there is a 4-byte padding in struct scatterlist
  * (assuming also CONFIG_NEED_SG_DMA_LENGTH is set). Use this padding for DMA
- * flags bits to indicate when a specific dma address is a bus address.
+ * flags bits to indicate when a specific dma address is a bus address or the
+ * buffer may have been bounced via SWIOTLB.
  */
 #ifdef CONFIG_NEED_SG_DMA_FLAGS
 
-#define SG_DMA_BUS_ADDRESS (1 << 0)
+#define SG_DMA_BUS_ADDRESS     (1 << 0)
+#define SG_DMA_SWIOTLB         (1 << 1)
 
 /**
  * sg_dma_is_bus_address - Return whether a given segment was marked
        sg->dma_flags &= ~SG_DMA_BUS_ADDRESS;
 }
 
+/**
+ * sg_dma_is_swiotlb - Return whether the scatterlist was marked for SWIOTLB
+ *                     bouncing
+ * @sg:                SG entry
+ *
+ * Description:
+ *   Returns true if the scatterlist was marked for SWIOTLB bouncing. Not all
+ *   elements may have been bounced, so the caller would have to check
+ *   individual SG entries with is_swiotlb_buffer().
+ */
+static inline bool sg_dma_is_swiotlb(struct scatterlist *sg)
+{
+       return sg->dma_flags & SG_DMA_SWIOTLB;
+}
+
+/**
+ * sg_dma_mark_swiotlb - Mark the scatterlist for SWIOTLB bouncing
+ * @sg:                SG entry
+ *
+ * Description:
+ *   Marks a a scatterlist for SWIOTLB bounce. Not all SG entries may be
+ *   bounced.
+ */
+static inline void sg_dma_mark_swiotlb(struct scatterlist *sg)
+{
+       sg->dma_flags |= SG_DMA_SWIOTLB;
+}
+
 #else
 
 static inline bool sg_dma_is_bus_address(struct scatterlist *sg)
 static inline void sg_dma_unmark_bus_address(struct scatterlist *sg)
 {
 }
+static inline bool sg_dma_is_swiotlb(struct scatterlist *sg)
+{
+       return false;
+}
+static inline void sg_dma_mark_swiotlb(struct scatterlist *sg)
+{
+}
 
 #endif /* CONFIG_NEED_SG_DMA_FLAGS */