{
        unsigned long remaining, bytes;
 
-       /* Cap to one page in the first iteration, if PAGE_SIZE unaligned. */
-       bytes = !bitmap->mapped.pgoff ? bitmap->mapped.npages << PAGE_SHIFT :
-                                       PAGE_SIZE - bitmap->mapped.pgoff;
+       bytes = (bitmap->mapped.npages << PAGE_SHIFT) - bitmap->mapped.pgoff;
 
        remaining = bitmap->mapped_total_index - bitmap->mapped_base_index;
        remaining = min_t(unsigned long, remaining,
  * Set the bits corresponding to the range [iova .. iova+length-1] in
  * the user bitmap.
  *
- * Return: The number of bits set.
  */
 void iova_bitmap_set(struct iova_bitmap *bitmap,
                     unsigned long iova, size_t length)
 {
        struct iova_bitmap_map *mapped = &bitmap->mapped;
-       unsigned long offset = (iova - mapped->iova) >> mapped->pgshift;
-       unsigned long nbits = max_t(unsigned long, 1, length >> mapped->pgshift);
-       unsigned long page_idx = offset / BITS_PER_PAGE;
-       unsigned long page_offset = mapped->pgoff;
-       void *kaddr;
-
-       offset = offset % BITS_PER_PAGE;
+       unsigned long cur_bit = ((iova - mapped->iova) >>
+                       mapped->pgshift) + mapped->pgoff * BITS_PER_BYTE;
+       unsigned long last_bit = (((iova + length - 1) - mapped->iova) >>
+                       mapped->pgshift) + mapped->pgoff * BITS_PER_BYTE;
 
        do {
-               unsigned long size = min(BITS_PER_PAGE - offset, nbits);
+               unsigned int page_idx = cur_bit / BITS_PER_PAGE;
+               unsigned int offset = cur_bit % BITS_PER_PAGE;
+               unsigned int nbits = min(BITS_PER_PAGE - offset,
+                                        last_bit - cur_bit + 1);
+               void *kaddr;
 
                kaddr = kmap_local_page(mapped->pages[page_idx]);
-               bitmap_set(kaddr + page_offset, offset, size);
+               bitmap_set(kaddr, offset, nbits);
                kunmap_local(kaddr);
-               page_offset = offset = 0;
-               nbits -= size;
-               page_idx++;
-       } while (nbits > 0);
+               cur_bit += nbits;
+       } while (cur_bit <= last_bit);
 }
 EXPORT_SYMBOL_GPL(iova_bitmap_set);