(iter->iomap.flags & IOMAP_F_DIRTY);
 }
 
-static bool dax_fault_is_cow(const struct iomap_iter *iter)
-{
-       return (iter->flags & IOMAP_WRITE) &&
-               (iter->iomap.flags & IOMAP_F_SHARED);
-}
-
 /*
  * By this point grab_mapping_entry() has ensured that we have a locked entry
  * of the appropriate size so we don't have to worry about downgrading PMDs to
 {
        struct address_space *mapping = vmf->vma->vm_file->f_mapping;
        void *new_entry = dax_make_entry(pfn, flags);
-       bool dirty = !dax_fault_is_synchronous(iter, vmf->vma);
-       bool cow = dax_fault_is_cow(iter);
+       bool write = iter->flags & IOMAP_WRITE;
+       bool dirty = write && !dax_fault_is_synchronous(iter, vmf->vma);
+       bool shared = iter->iomap.flags & IOMAP_F_SHARED;
 
        if (dirty)
                __mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
 
-       if (cow || (dax_is_zero_entry(entry) && !(flags & DAX_ZERO_PAGE))) {
+       if (shared || (dax_is_zero_entry(entry) && !(flags & DAX_ZERO_PAGE))) {
                unsigned long index = xas->xa_index;
                /* we are replacing a zero page with block mapping */
                if (dax_is_pmd_entry(entry))
 
        xas_reset(xas);
        xas_lock_irq(xas);
-       if (cow || dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
+       if (shared || dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
                void *old;
 
                dax_disassociate_entry(entry, mapping, false);
                dax_associate_entry(new_entry, mapping, vmf->vma, vmf->address,
-                               cow);
+                               shared);
                /*
                 * Only swap our new entry into the page cache if the current
                 * entry is a zero page or an empty entry.  If a normal PTE or
        if (dirty)
                xas_set_mark(xas, PAGECACHE_TAG_DIRTY);
 
-       if (cow)
+       if (write && shared)
                xas_set_mark(xas, PAGECACHE_TAG_TOWRITE);
 
        xas_unlock_irq(xas);
 
                return error;
        error = xfs_bmapi_read(ip, offset_fsb, end_fsb - offset_fsb, &imap,
                               &nimaps, 0);
-       if (!error && (flags & IOMAP_REPORT))
+       if (!error && ((flags & IOMAP_REPORT) || IS_DAX(inode)))
                error = xfs_reflink_trim_around_shared(ip, &imap, &shared);
        xfs_iunlock(ip, lockmode);