mm/hmm/test: add selftest driver for HMM
authorRalph Campbell <rcampbell@nvidia.com>
Wed, 22 Apr 2020 19:50:26 +0000 (12:50 -0700)
committerJason Gunthorpe <jgg@mellanox.com>
Tue, 19 May 2020 19:48:30 +0000 (16:48 -0300)
This driver is for testing device private memory migration and devices
which use hmm_range_fault() to access system memory via device page tables.

Link: https://lore.kernel.org/r/20200422195028.3684-2-rcampbell@nvidia.com
Link: https://lore.kernel.org/r/20200516010424.2013-1-rcampbell@nvidia.com
Signed-off-by: Ralph Campbell <rcampbell@nvidia.com>
Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Link: https://lore.kernel.org/r/20200509030225.14592-1-weiyongjun1@huawei.com
Link: https://lore.kernel.org/r/20200509030234.14747-1-weiyongjun1@huawei.com
Signed-off-by: Wei Yongjun <weiyongjun1@huawei.com>
Link: https://lore.kernel.org/r/20200511183704.GA225608@mwanda
Signed-off-by: Dan Carpenter <dan.carpenter@oracle.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
lib/Kconfig.debug
lib/Makefile
lib/test_hmm.c [new file with mode: 0644]
lib/test_hmm_uapi.h [new file with mode: 0644]

index 21d9c5f6e7ec7552be6f46fa352a29394465e3cc..6ddd646335dd7142214f9d136a0403e63aebbacf 100644 (file)
@@ -2201,6 +2201,19 @@ config TEST_MEMINIT
 
          If unsure, say N.
 
+config TEST_HMM
+       tristate "Test HMM (Heterogeneous Memory Management)"
+       depends on TRANSPARENT_HUGEPAGE
+       depends on DEVICE_PRIVATE
+       select HMM_MIRROR
+       select MMU_NOTIFIER
+       help
+         This is a pseudo device driver solely for testing HMM.
+         Say M here if you want to build the HMM test module.
+         Doing so will allow you to run tools/testing/selftest/vm/hmm-tests.
+
+         If unsure, say N.
+
 endif # RUNTIME_TESTING_MENU
 
 config MEMTEST
index 685aee60de1d5ea7d72041e4884e9d5cac1cdaf8..93d8ad358b44a841c06b8f3146f122481ee0a56c 100644 (file)
@@ -92,6 +92,7 @@ obj-$(CONFIG_TEST_STACKINIT) += test_stackinit.o
 obj-$(CONFIG_TEST_BLACKHOLE_DEV) += test_blackhole_dev.o
 obj-$(CONFIG_TEST_MEMINIT) += test_meminit.o
 obj-$(CONFIG_TEST_LOCKUP) += test_lockup.o
+obj-$(CONFIG_TEST_HMM) += test_hmm.o
 
 obj-$(CONFIG_TEST_LIVEPATCH) += livepatch/
 
diff --git a/lib/test_hmm.c b/lib/test_hmm.c
new file mode 100644 (file)
index 0000000..5c1858e
--- /dev/null
@@ -0,0 +1,1164 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * This is a module to test the HMM (Heterogeneous Memory Management)
+ * mirror and zone device private memory migration APIs of the kernel.
+ * Userspace programs can register with the driver to mirror their own address
+ * space and can use the device to read/write any valid virtual address.
+ */
+#include <linux/init.h>
+#include <linux/fs.h>
+#include <linux/mm.h>
+#include <linux/module.h>
+#include <linux/kernel.h>
+#include <linux/cdev.h>
+#include <linux/device.h>
+#include <linux/mutex.h>
+#include <linux/rwsem.h>
+#include <linux/sched.h>
+#include <linux/slab.h>
+#include <linux/highmem.h>
+#include <linux/delay.h>
+#include <linux/pagemap.h>
+#include <linux/hmm.h>
+#include <linux/vmalloc.h>
+#include <linux/swap.h>
+#include <linux/swapops.h>
+#include <linux/sched/mm.h>
+#include <linux/platform_device.h>
+
+#include "test_hmm_uapi.h"
+
+#define DMIRROR_NDEVICES               2
+#define DMIRROR_RANGE_FAULT_TIMEOUT    1000
+#define DEVMEM_CHUNK_SIZE              (256 * 1024 * 1024U)
+#define DEVMEM_CHUNKS_RESERVE          16
+
+static const struct dev_pagemap_ops dmirror_devmem_ops;
+static const struct mmu_interval_notifier_ops dmirror_min_ops;
+static dev_t dmirror_dev;
+static struct page *dmirror_zero_page;
+
+struct dmirror_device;
+
+struct dmirror_bounce {
+       void                    *ptr;
+       unsigned long           size;
+       unsigned long           addr;
+       unsigned long           cpages;
+};
+
+#define DPT_XA_TAG_WRITE 3UL
+
+/*
+ * Data structure to track address ranges and register for mmu interval
+ * notifier updates.
+ */
+struct dmirror_interval {
+       struct mmu_interval_notifier    notifier;
+       struct dmirror                  *dmirror;
+};
+
+/*
+ * Data attached to the open device file.
+ * Note that it might be shared after a fork().
+ */
+struct dmirror {
+       struct dmirror_device           *mdevice;
+       struct xarray                   pt;
+       struct mmu_interval_notifier    notifier;
+       struct mutex                    mutex;
+};
+
+/*
+ * ZONE_DEVICE pages for migration and simulating device memory.
+ */
+struct dmirror_chunk {
+       struct dev_pagemap      pagemap;
+       struct dmirror_device   *mdevice;
+};
+
+/*
+ * Per device data.
+ */
+struct dmirror_device {
+       struct cdev             cdevice;
+       struct hmm_devmem       *devmem;
+
+       unsigned int            devmem_capacity;
+       unsigned int            devmem_count;
+       struct dmirror_chunk    **devmem_chunks;
+       struct mutex            devmem_lock;    /* protects the above */
+
+       unsigned long           calloc;
+       unsigned long           cfree;
+       struct page             *free_pages;
+       spinlock_t              lock;           /* protects the above */
+};
+
+static struct dmirror_device dmirror_devices[DMIRROR_NDEVICES];
+
+static int dmirror_bounce_init(struct dmirror_bounce *bounce,
+                              unsigned long addr,
+                              unsigned long size)
+{
+       bounce->addr = addr;
+       bounce->size = size;
+       bounce->cpages = 0;
+       bounce->ptr = vmalloc(size);
+       if (!bounce->ptr)
+               return -ENOMEM;
+       return 0;
+}
+
+static void dmirror_bounce_fini(struct dmirror_bounce *bounce)
+{
+       vfree(bounce->ptr);
+}
+
+static int dmirror_fops_open(struct inode *inode, struct file *filp)
+{
+       struct cdev *cdev = inode->i_cdev;
+       struct dmirror *dmirror;
+       int ret;
+
+       /* Mirror this process address space */
+       dmirror = kzalloc(sizeof(*dmirror), GFP_KERNEL);
+       if (dmirror == NULL)
+               return -ENOMEM;
+
+       dmirror->mdevice = container_of(cdev, struct dmirror_device, cdevice);
+       mutex_init(&dmirror->mutex);
+       xa_init(&dmirror->pt);
+
+       ret = mmu_interval_notifier_insert(&dmirror->notifier, current->mm,
+                               0, ULONG_MAX & PAGE_MASK, &dmirror_min_ops);
+       if (ret) {
+               kfree(dmirror);
+               return ret;
+       }
+
+       filp->private_data = dmirror;
+       return 0;
+}
+
+static int dmirror_fops_release(struct inode *inode, struct file *filp)
+{
+       struct dmirror *dmirror = filp->private_data;
+
+       mmu_interval_notifier_remove(&dmirror->notifier);
+       xa_destroy(&dmirror->pt);
+       kfree(dmirror);
+       return 0;
+}
+
+static struct dmirror_device *dmirror_page_to_device(struct page *page)
+
+{
+       return container_of(page->pgmap, struct dmirror_chunk,
+                           pagemap)->mdevice;
+}
+
+static int dmirror_do_fault(struct dmirror *dmirror, struct hmm_range *range)
+{
+       unsigned long *pfns = range->hmm_pfns;
+       unsigned long pfn;
+
+       for (pfn = (range->start >> PAGE_SHIFT);
+            pfn < (range->end >> PAGE_SHIFT);
+            pfn++, pfns++) {
+               struct page *page;
+               void *entry;
+
+               /*
+                * Since we asked for hmm_range_fault() to populate pages,
+                * it shouldn't return an error entry on success.
+                */
+               WARN_ON(*pfns & HMM_PFN_ERROR);
+               WARN_ON(!(*pfns & HMM_PFN_VALID));
+
+               page = hmm_pfn_to_page(*pfns);
+               WARN_ON(!page);
+
+               entry = page;
+               if (*pfns & HMM_PFN_WRITE)
+                       entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
+               else if (WARN_ON(range->default_flags & HMM_PFN_WRITE))
+                       return -EFAULT;
+               entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
+               if (xa_is_err(entry))
+                       return xa_err(entry);
+       }
+
+       return 0;
+}
+
+static void dmirror_do_update(struct dmirror *dmirror, unsigned long start,
+                             unsigned long end)
+{
+       unsigned long pfn;
+       void *entry;
+
+       /*
+        * The XArray doesn't hold references to pages since it relies on
+        * the mmu notifier to clear page pointers when they become stale.
+        * Therefore, it is OK to just clear the entry.
+        */
+       xa_for_each_range(&dmirror->pt, pfn, entry, start >> PAGE_SHIFT,
+                         end >> PAGE_SHIFT)
+               xa_erase(&dmirror->pt, pfn);
+}
+
+static bool dmirror_interval_invalidate(struct mmu_interval_notifier *mni,
+                               const struct mmu_notifier_range *range,
+                               unsigned long cur_seq)
+{
+       struct dmirror *dmirror = container_of(mni, struct dmirror, notifier);
+
+       if (mmu_notifier_range_blockable(range))
+               mutex_lock(&dmirror->mutex);
+       else if (!mutex_trylock(&dmirror->mutex))
+               return false;
+
+       mmu_interval_set_seq(mni, cur_seq);
+       dmirror_do_update(dmirror, range->start, range->end);
+
+       mutex_unlock(&dmirror->mutex);
+       return true;
+}
+
+static const struct mmu_interval_notifier_ops dmirror_min_ops = {
+       .invalidate = dmirror_interval_invalidate,
+};
+
+static int dmirror_range_fault(struct dmirror *dmirror,
+                               struct hmm_range *range)
+{
+       struct mm_struct *mm = dmirror->notifier.mm;
+       unsigned long timeout =
+               jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
+       int ret;
+
+       while (true) {
+               if (time_after(jiffies, timeout)) {
+                       ret = -EBUSY;
+                       goto out;
+               }
+
+               range->notifier_seq = mmu_interval_read_begin(range->notifier);
+               down_read(&mm->mmap_sem);
+               ret = hmm_range_fault(range);
+               up_read(&mm->mmap_sem);
+               if (ret) {
+                       if (ret == -EBUSY)
+                               continue;
+                       goto out;
+               }
+
+               mutex_lock(&dmirror->mutex);
+               if (mmu_interval_read_retry(range->notifier,
+                                           range->notifier_seq)) {
+                       mutex_unlock(&dmirror->mutex);
+                       continue;
+               }
+               break;
+       }
+
+       ret = dmirror_do_fault(dmirror, range);
+
+       mutex_unlock(&dmirror->mutex);
+out:
+       return ret;
+}
+
+static int dmirror_fault(struct dmirror *dmirror, unsigned long start,
+                        unsigned long end, bool write)
+{
+       struct mm_struct *mm = dmirror->notifier.mm;
+       unsigned long addr;
+       unsigned long pfns[64];
+       struct hmm_range range = {
+               .notifier = &dmirror->notifier,
+               .hmm_pfns = pfns,
+               .pfn_flags_mask = 0,
+               .default_flags =
+                       HMM_PFN_REQ_FAULT | (write ? HMM_PFN_REQ_WRITE : 0),
+               .dev_private_owner = dmirror->mdevice,
+       };
+       int ret = 0;
+
+       /* Since the mm is for the mirrored process, get a reference first. */
+       if (!mmget_not_zero(mm))
+               return 0;
+
+       for (addr = start; addr < end; addr = range.end) {
+               range.start = addr;
+               range.end = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
+
+               ret = dmirror_range_fault(dmirror, &range);
+               if (ret)
+                       break;
+       }
+
+       mmput(mm);
+       return ret;
+}
+
+static int dmirror_do_read(struct dmirror *dmirror, unsigned long start,
+                          unsigned long end, struct dmirror_bounce *bounce)
+{
+       unsigned long pfn;
+       void *ptr;
+
+       ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
+
+       for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
+               void *entry;
+               struct page *page;
+               void *tmp;
+
+               entry = xa_load(&dmirror->pt, pfn);
+               page = xa_untag_pointer(entry);
+               if (!page)
+                       return -ENOENT;
+
+               tmp = kmap(page);
+               memcpy(ptr, tmp, PAGE_SIZE);
+               kunmap(page);
+
+               ptr += PAGE_SIZE;
+               bounce->cpages++;
+       }
+
+       return 0;
+}
+
+static int dmirror_read(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
+{
+       struct dmirror_bounce bounce;
+       unsigned long start, end;
+       unsigned long size = cmd->npages << PAGE_SHIFT;
+       int ret;
+
+       start = cmd->addr;
+       end = start + size;
+       if (end < start)
+               return -EINVAL;
+
+       ret = dmirror_bounce_init(&bounce, start, size);
+       if (ret)
+               return ret;
+
+       while (1) {
+               mutex_lock(&dmirror->mutex);
+               ret = dmirror_do_read(dmirror, start, end, &bounce);
+               mutex_unlock(&dmirror->mutex);
+               if (ret != -ENOENT)
+                       break;
+
+               start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
+               ret = dmirror_fault(dmirror, start, end, false);
+               if (ret)
+                       break;
+               cmd->faults++;
+       }
+
+       if (ret == 0) {
+               if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
+                                bounce.size))
+                       ret = -EFAULT;
+       }
+       cmd->cpages = bounce.cpages;
+       dmirror_bounce_fini(&bounce);
+       return ret;
+}
+
+static int dmirror_do_write(struct dmirror *dmirror, unsigned long start,
+                           unsigned long end, struct dmirror_bounce *bounce)
+{
+       unsigned long pfn;
+       void *ptr;
+
+       ptr = bounce->ptr + ((start - bounce->addr) & PAGE_MASK);
+
+       for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++) {
+               void *entry;
+               struct page *page;
+               void *tmp;
+
+               entry = xa_load(&dmirror->pt, pfn);
+               page = xa_untag_pointer(entry);
+               if (!page || xa_pointer_tag(entry) != DPT_XA_TAG_WRITE)
+                       return -ENOENT;
+
+               tmp = kmap(page);
+               memcpy(tmp, ptr, PAGE_SIZE);
+               kunmap(page);
+
+               ptr += PAGE_SIZE;
+               bounce->cpages++;
+       }
+
+       return 0;
+}
+
+static int dmirror_write(struct dmirror *dmirror, struct hmm_dmirror_cmd *cmd)
+{
+       struct dmirror_bounce bounce;
+       unsigned long start, end;
+       unsigned long size = cmd->npages << PAGE_SHIFT;
+       int ret;
+
+       start = cmd->addr;
+       end = start + size;
+       if (end < start)
+               return -EINVAL;
+
+       ret = dmirror_bounce_init(&bounce, start, size);
+       if (ret)
+               return ret;
+       if (copy_from_user(bounce.ptr, u64_to_user_ptr(cmd->ptr),
+                          bounce.size)) {
+               ret = -EFAULT;
+               goto fini;
+       }
+
+       while (1) {
+               mutex_lock(&dmirror->mutex);
+               ret = dmirror_do_write(dmirror, start, end, &bounce);
+               mutex_unlock(&dmirror->mutex);
+               if (ret != -ENOENT)
+                       break;
+
+               start = cmd->addr + (bounce.cpages << PAGE_SHIFT);
+               ret = dmirror_fault(dmirror, start, end, true);
+               if (ret)
+                       break;
+               cmd->faults++;
+       }
+
+fini:
+       cmd->cpages = bounce.cpages;
+       dmirror_bounce_fini(&bounce);
+       return ret;
+}
+
+static bool dmirror_allocate_chunk(struct dmirror_device *mdevice,
+                                  struct page **ppage)
+{
+       struct dmirror_chunk *devmem;
+       struct resource *res;
+       unsigned long pfn;
+       unsigned long pfn_first;
+       unsigned long pfn_last;
+       void *ptr;
+
+       mutex_lock(&mdevice->devmem_lock);
+
+       if (mdevice->devmem_count == mdevice->devmem_capacity) {
+               struct dmirror_chunk **new_chunks;
+               unsigned int new_capacity;
+
+               new_capacity = mdevice->devmem_capacity +
+                               DEVMEM_CHUNKS_RESERVE;
+               new_chunks = krealloc(mdevice->devmem_chunks,
+                               sizeof(new_chunks[0]) * new_capacity,
+                               GFP_KERNEL);
+               if (!new_chunks)
+                       goto err;
+               mdevice->devmem_capacity = new_capacity;
+               mdevice->devmem_chunks = new_chunks;
+       }
+
+       res = request_free_mem_region(&iomem_resource, DEVMEM_CHUNK_SIZE,
+                                       "hmm_dmirror");
+       if (IS_ERR(res))
+               goto err;
+
+       devmem = kzalloc(sizeof(*devmem), GFP_KERNEL);
+       if (!devmem)
+               goto err_release;
+
+       devmem->pagemap.type = MEMORY_DEVICE_PRIVATE;
+       devmem->pagemap.res = *res;
+       devmem->pagemap.ops = &dmirror_devmem_ops;
+       devmem->pagemap.owner = mdevice;
+
+       ptr = memremap_pages(&devmem->pagemap, numa_node_id());
+       if (IS_ERR(ptr))
+               goto err_free;
+
+       devmem->mdevice = mdevice;
+       pfn_first = devmem->pagemap.res.start >> PAGE_SHIFT;
+       pfn_last = pfn_first +
+               (resource_size(&devmem->pagemap.res) >> PAGE_SHIFT);
+       mdevice->devmem_chunks[mdevice->devmem_count++] = devmem;
+
+       mutex_unlock(&mdevice->devmem_lock);
+
+       pr_info("added new %u MB chunk (total %u chunks, %u MB) PFNs [0x%lx 0x%lx)\n",
+               DEVMEM_CHUNK_SIZE / (1024 * 1024),
+               mdevice->devmem_count,
+               mdevice->devmem_count * (DEVMEM_CHUNK_SIZE / (1024 * 1024)),
+               pfn_first, pfn_last);
+
+       spin_lock(&mdevice->lock);
+       for (pfn = pfn_first; pfn < pfn_last; pfn++) {
+               struct page *page = pfn_to_page(pfn);
+
+               page->zone_device_data = mdevice->free_pages;
+               mdevice->free_pages = page;
+       }
+       if (ppage) {
+               *ppage = mdevice->free_pages;
+               mdevice->free_pages = (*ppage)->zone_device_data;
+               mdevice->calloc++;
+       }
+       spin_unlock(&mdevice->lock);
+
+       return true;
+
+err_free:
+       kfree(devmem);
+err_release:
+       release_mem_region(devmem->pagemap.res.start,
+                          resource_size(&devmem->pagemap.res));
+err:
+       mutex_unlock(&mdevice->devmem_lock);
+       return false;
+}
+
+static struct page *dmirror_devmem_alloc_page(struct dmirror_device *mdevice)
+{
+       struct page *dpage = NULL;
+       struct page *rpage;
+
+       /*
+        * This is a fake device so we alloc real system memory to store
+        * our device memory.
+        */
+       rpage = alloc_page(GFP_HIGHUSER);
+       if (!rpage)
+               return NULL;
+
+       spin_lock(&mdevice->lock);
+
+       if (mdevice->free_pages) {
+               dpage = mdevice->free_pages;
+               mdevice->free_pages = dpage->zone_device_data;
+               mdevice->calloc++;
+               spin_unlock(&mdevice->lock);
+       } else {
+               spin_unlock(&mdevice->lock);
+               if (!dmirror_allocate_chunk(mdevice, &dpage))
+                       goto error;
+       }
+
+       dpage->zone_device_data = rpage;
+       get_page(dpage);
+       lock_page(dpage);
+       return dpage;
+
+error:
+       __free_page(rpage);
+       return NULL;
+}
+
+static void dmirror_migrate_alloc_and_copy(struct migrate_vma *args,
+                                          struct dmirror *dmirror)
+{
+       struct dmirror_device *mdevice = dmirror->mdevice;
+       const unsigned long *src = args->src;
+       unsigned long *dst = args->dst;
+       unsigned long addr;
+
+       for (addr = args->start; addr < args->end; addr += PAGE_SIZE,
+                                                  src++, dst++) {
+               struct page *spage;
+               struct page *dpage;
+               struct page *rpage;
+
+               if (!(*src & MIGRATE_PFN_MIGRATE))
+                       continue;
+
+               /*
+                * Note that spage might be NULL which is OK since it is an
+                * unallocated pte_none() or read-only zero page.
+                */
+               spage = migrate_pfn_to_page(*src);
+
+               /*
+                * Don't migrate device private pages from our own driver or
+                * others. For our own we would do a device private memory copy
+                * not a migration and for others, we would need to fault the
+                * other device's page into system memory first.
+                */
+               if (spage && is_zone_device_page(spage))
+                       continue;
+
+               dpage = dmirror_devmem_alloc_page(mdevice);
+               if (!dpage)
+                       continue;
+
+               rpage = dpage->zone_device_data;
+               if (spage)
+                       copy_highpage(rpage, spage);
+               else
+                       clear_highpage(rpage);
+
+               /*
+                * Normally, a device would use the page->zone_device_data to
+                * point to the mirror but here we use it to hold the page for
+                * the simulated device memory and that page holds the pointer
+                * to the mirror.
+                */
+               rpage->zone_device_data = dmirror;
+
+               *dst = migrate_pfn(page_to_pfn(dpage)) |
+                           MIGRATE_PFN_LOCKED;
+               if ((*src & MIGRATE_PFN_WRITE) ||
+                   (!spage && args->vma->vm_flags & VM_WRITE))
+                       *dst |= MIGRATE_PFN_WRITE;
+       }
+}
+
+static int dmirror_migrate_finalize_and_map(struct migrate_vma *args,
+                                           struct dmirror *dmirror)
+{
+       unsigned long start = args->start;
+       unsigned long end = args->end;
+       const unsigned long *src = args->src;
+       const unsigned long *dst = args->dst;
+       unsigned long pfn;
+
+       /* Map the migrated pages into the device's page tables. */
+       mutex_lock(&dmirror->mutex);
+
+       for (pfn = start >> PAGE_SHIFT; pfn < (end >> PAGE_SHIFT); pfn++,
+                                                               src++, dst++) {
+               struct page *dpage;
+               void *entry;
+
+               if (!(*src & MIGRATE_PFN_MIGRATE))
+                       continue;
+
+               dpage = migrate_pfn_to_page(*dst);
+               if (!dpage)
+                       continue;
+
+               /*
+                * Store the page that holds the data so the page table
+                * doesn't have to deal with ZONE_DEVICE private pages.
+                */
+               entry = dpage->zone_device_data;
+               if (*dst & MIGRATE_PFN_WRITE)
+                       entry = xa_tag_pointer(entry, DPT_XA_TAG_WRITE);
+               entry = xa_store(&dmirror->pt, pfn, entry, GFP_ATOMIC);
+               if (xa_is_err(entry)) {
+                       mutex_unlock(&dmirror->mutex);
+                       return xa_err(entry);
+               }
+       }
+
+       mutex_unlock(&dmirror->mutex);
+       return 0;
+}
+
+static int dmirror_migrate(struct dmirror *dmirror,
+                          struct hmm_dmirror_cmd *cmd)
+{
+       unsigned long start, end, addr;
+       unsigned long size = cmd->npages << PAGE_SHIFT;
+       struct mm_struct *mm = dmirror->notifier.mm;
+       struct vm_area_struct *vma;
+       unsigned long src_pfns[64];
+       unsigned long dst_pfns[64];
+       struct dmirror_bounce bounce;
+       struct migrate_vma args;
+       unsigned long next;
+       int ret;
+
+       start = cmd->addr;
+       end = start + size;
+       if (end < start)
+               return -EINVAL;
+
+       /* Since the mm is for the mirrored process, get a reference first. */
+       if (!mmget_not_zero(mm))
+               return -EINVAL;
+
+       down_read(&mm->mmap_sem);
+       for (addr = start; addr < end; addr = next) {
+               vma = find_vma(mm, addr);
+               if (!vma || addr < vma->vm_start ||
+                   !(vma->vm_flags & VM_READ)) {
+                       ret = -EINVAL;
+                       goto out;
+               }
+               next = min(end, addr + (ARRAY_SIZE(src_pfns) << PAGE_SHIFT));
+               if (next > vma->vm_end)
+                       next = vma->vm_end;
+
+               args.vma = vma;
+               args.src = src_pfns;
+               args.dst = dst_pfns;
+               args.start = addr;
+               args.end = next;
+               args.src_owner = NULL;
+               ret = migrate_vma_setup(&args);
+               if (ret)
+                       goto out;
+
+               dmirror_migrate_alloc_and_copy(&args, dmirror);
+               migrate_vma_pages(&args);
+               dmirror_migrate_finalize_and_map(&args, dmirror);
+               migrate_vma_finalize(&args);
+       }
+       up_read(&mm->mmap_sem);
+       mmput(mm);
+
+       /* Return the migrated data for verification. */
+       ret = dmirror_bounce_init(&bounce, start, size);
+       if (ret)
+               return ret;
+       mutex_lock(&dmirror->mutex);
+       ret = dmirror_do_read(dmirror, start, end, &bounce);
+       mutex_unlock(&dmirror->mutex);
+       if (ret == 0) {
+               if (copy_to_user(u64_to_user_ptr(cmd->ptr), bounce.ptr,
+                                bounce.size))
+                       ret = -EFAULT;
+       }
+       cmd->cpages = bounce.cpages;
+       dmirror_bounce_fini(&bounce);
+       return ret;
+
+out:
+       up_read(&mm->mmap_sem);
+       mmput(mm);
+       return ret;
+}
+
+static void dmirror_mkentry(struct dmirror *dmirror, struct hmm_range *range,
+                           unsigned char *perm, unsigned long entry)
+{
+       struct page *page;
+
+       if (entry & HMM_PFN_ERROR) {
+               *perm = HMM_DMIRROR_PROT_ERROR;
+               return;
+       }
+       if (!(entry & HMM_PFN_VALID)) {
+               *perm = HMM_DMIRROR_PROT_NONE;
+               return;
+       }
+
+       page = hmm_pfn_to_page(entry);
+       if (is_device_private_page(page)) {
+               /* Is the page migrated to this device or some other? */
+               if (dmirror->mdevice == dmirror_page_to_device(page))
+                       *perm = HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL;
+               else
+                       *perm = HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE;
+       } else if (is_zero_pfn(page_to_pfn(page)))
+               *perm = HMM_DMIRROR_PROT_ZERO;
+       else
+               *perm = HMM_DMIRROR_PROT_NONE;
+       if (entry & HMM_PFN_WRITE)
+               *perm |= HMM_DMIRROR_PROT_WRITE;
+       else
+               *perm |= HMM_DMIRROR_PROT_READ;
+}
+
+static bool dmirror_snapshot_invalidate(struct mmu_interval_notifier *mni,
+                               const struct mmu_notifier_range *range,
+                               unsigned long cur_seq)
+{
+       struct dmirror_interval *dmi =
+               container_of(mni, struct dmirror_interval, notifier);
+       struct dmirror *dmirror = dmi->dmirror;
+
+       if (mmu_notifier_range_blockable(range))
+               mutex_lock(&dmirror->mutex);
+       else if (!mutex_trylock(&dmirror->mutex))
+               return false;
+
+       /*
+        * Snapshots only need to set the sequence number since any
+        * invalidation in the interval invalidates the whole snapshot.
+        */
+       mmu_interval_set_seq(mni, cur_seq);
+
+       mutex_unlock(&dmirror->mutex);
+       return true;
+}
+
+static const struct mmu_interval_notifier_ops dmirror_mrn_ops = {
+       .invalidate = dmirror_snapshot_invalidate,
+};
+
+static int dmirror_range_snapshot(struct dmirror *dmirror,
+                                 struct hmm_range *range,
+                                 unsigned char *perm)
+{
+       struct mm_struct *mm = dmirror->notifier.mm;
+       struct dmirror_interval notifier;
+       unsigned long timeout =
+               jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
+       unsigned long i;
+       unsigned long n;
+       int ret = 0;
+
+       notifier.dmirror = dmirror;
+       range->notifier = &notifier.notifier;
+
+       ret = mmu_interval_notifier_insert(range->notifier, mm,
+                       range->start, range->end - range->start,
+                       &dmirror_mrn_ops);
+       if (ret)
+               return ret;
+
+       while (true) {
+               if (time_after(jiffies, timeout)) {
+                       ret = -EBUSY;
+                       goto out;
+               }
+
+               range->notifier_seq = mmu_interval_read_begin(range->notifier);
+
+               down_read(&mm->mmap_sem);
+               ret = hmm_range_fault(range);
+               up_read(&mm->mmap_sem);
+               if (ret) {
+                       if (ret == -EBUSY)
+                               continue;
+                       goto out;
+               }
+
+               mutex_lock(&dmirror->mutex);
+               if (mmu_interval_read_retry(range->notifier,
+                                           range->notifier_seq)) {
+                       mutex_unlock(&dmirror->mutex);
+                       continue;
+               }
+               break;
+       }
+
+       n = (range->end - range->start) >> PAGE_SHIFT;
+       for (i = 0; i < n; i++)
+               dmirror_mkentry(dmirror, range, perm + i, range->hmm_pfns[i]);
+
+       mutex_unlock(&dmirror->mutex);
+out:
+       mmu_interval_notifier_remove(range->notifier);
+       return ret;
+}
+
+static int dmirror_snapshot(struct dmirror *dmirror,
+                           struct hmm_dmirror_cmd *cmd)
+{
+       struct mm_struct *mm = dmirror->notifier.mm;
+       unsigned long start, end;
+       unsigned long size = cmd->npages << PAGE_SHIFT;
+       unsigned long addr;
+       unsigned long next;
+       unsigned long pfns[64];
+       unsigned char perm[64];
+       char __user *uptr;
+       struct hmm_range range = {
+               .hmm_pfns = pfns,
+               .dev_private_owner = dmirror->mdevice,
+       };
+       int ret = 0;
+
+       start = cmd->addr;
+       end = start + size;
+       if (end < start)
+               return -EINVAL;
+
+       /* Since the mm is for the mirrored process, get a reference first. */
+       if (!mmget_not_zero(mm))
+               return -EINVAL;
+
+       /*
+        * Register a temporary notifier to detect invalidations even if it
+        * overlaps with other mmu_interval_notifiers.
+        */
+       uptr = u64_to_user_ptr(cmd->ptr);
+       for (addr = start; addr < end; addr = next) {
+               unsigned long n;
+
+               next = min(addr + (ARRAY_SIZE(pfns) << PAGE_SHIFT), end);
+               range.start = addr;
+               range.end = next;
+
+               ret = dmirror_range_snapshot(dmirror, &range, perm);
+               if (ret)
+                       break;
+
+               n = (range.end - range.start) >> PAGE_SHIFT;
+               if (copy_to_user(uptr, perm, n)) {
+                       ret = -EFAULT;
+                       break;
+               }
+
+               cmd->cpages += n;
+               uptr += n;
+       }
+       mmput(mm);
+
+       return ret;
+}
+
+static long dmirror_fops_unlocked_ioctl(struct file *filp,
+                                       unsigned int command,
+                                       unsigned long arg)
+{
+       void __user *uarg = (void __user *)arg;
+       struct hmm_dmirror_cmd cmd;
+       struct dmirror *dmirror;
+       int ret;
+
+       dmirror = filp->private_data;
+       if (!dmirror)
+               return -EINVAL;
+
+       if (copy_from_user(&cmd, uarg, sizeof(cmd)))
+               return -EFAULT;
+
+       if (cmd.addr & ~PAGE_MASK)
+               return -EINVAL;
+       if (cmd.addr >= (cmd.addr + (cmd.npages << PAGE_SHIFT)))
+               return -EINVAL;
+
+       cmd.cpages = 0;
+       cmd.faults = 0;
+
+       switch (command) {
+       case HMM_DMIRROR_READ:
+               ret = dmirror_read(dmirror, &cmd);
+               break;
+
+       case HMM_DMIRROR_WRITE:
+               ret = dmirror_write(dmirror, &cmd);
+               break;
+
+       case HMM_DMIRROR_MIGRATE:
+               ret = dmirror_migrate(dmirror, &cmd);
+               break;
+
+       case HMM_DMIRROR_SNAPSHOT:
+               ret = dmirror_snapshot(dmirror, &cmd);
+               break;
+
+       default:
+               return -EINVAL;
+       }
+       if (ret)
+               return ret;
+
+       if (copy_to_user(uarg, &cmd, sizeof(cmd)))
+               return -EFAULT;
+
+       return 0;
+}
+
+static const struct file_operations dmirror_fops = {
+       .open           = dmirror_fops_open,
+       .release        = dmirror_fops_release,
+       .unlocked_ioctl = dmirror_fops_unlocked_ioctl,
+       .llseek         = default_llseek,
+       .owner          = THIS_MODULE,
+};
+
+static void dmirror_devmem_free(struct page *page)
+{
+       struct page *rpage = page->zone_device_data;
+       struct dmirror_device *mdevice;
+
+       if (rpage)
+               __free_page(rpage);
+
+       mdevice = dmirror_page_to_device(page);
+
+       spin_lock(&mdevice->lock);
+       mdevice->cfree++;
+       page->zone_device_data = mdevice->free_pages;
+       mdevice->free_pages = page;
+       spin_unlock(&mdevice->lock);
+}
+
+static vm_fault_t dmirror_devmem_fault_alloc_and_copy(struct migrate_vma *args,
+                                               struct dmirror_device *mdevice)
+{
+       const unsigned long *src = args->src;
+       unsigned long *dst = args->dst;
+       unsigned long start = args->start;
+       unsigned long end = args->end;
+       unsigned long addr;
+
+       for (addr = start; addr < end; addr += PAGE_SIZE,
+                                      src++, dst++) {
+               struct page *dpage, *spage;
+
+               spage = migrate_pfn_to_page(*src);
+               if (!spage || !(*src & MIGRATE_PFN_MIGRATE))
+                       continue;
+               spage = spage->zone_device_data;
+
+               dpage = alloc_page_vma(GFP_HIGHUSER_MOVABLE, args->vma, addr);
+               if (!dpage)
+                       continue;
+
+               lock_page(dpage);
+               copy_highpage(dpage, spage);
+               *dst = migrate_pfn(page_to_pfn(dpage)) | MIGRATE_PFN_LOCKED;
+               if (*src & MIGRATE_PFN_WRITE)
+                       *dst |= MIGRATE_PFN_WRITE;
+       }
+       return 0;
+}
+
+static void dmirror_devmem_fault_finalize_and_map(struct migrate_vma *args,
+                                                 struct dmirror *dmirror)
+{
+       /* Invalidate the device's page table mapping. */
+       mutex_lock(&dmirror->mutex);
+       dmirror_do_update(dmirror, args->start, args->end);
+       mutex_unlock(&dmirror->mutex);
+}
+
+static vm_fault_t dmirror_devmem_fault(struct vm_fault *vmf)
+{
+       struct migrate_vma args;
+       unsigned long src_pfns;
+       unsigned long dst_pfns;
+       struct page *rpage;
+       struct dmirror *dmirror;
+       vm_fault_t ret;
+
+       /*
+        * Normally, a device would use the page->zone_device_data to point to
+        * the mirror but here we use it to hold the page for the simulated
+        * device memory and that page holds the pointer to the mirror.
+        */
+       rpage = vmf->page->zone_device_data;
+       dmirror = rpage->zone_device_data;
+
+       /* FIXME demonstrate how we can adjust migrate range */
+       args.vma = vmf->vma;
+       args.start = vmf->address;
+       args.end = args.start + PAGE_SIZE;
+       args.src = &src_pfns;
+       args.dst = &dst_pfns;
+       args.src_owner = dmirror->mdevice;
+
+       if (migrate_vma_setup(&args))
+               return VM_FAULT_SIGBUS;
+
+       ret = dmirror_devmem_fault_alloc_and_copy(&args, dmirror->mdevice);
+       if (ret)
+               return ret;
+       migrate_vma_pages(&args);
+       dmirror_devmem_fault_finalize_and_map(&args, dmirror);
+       migrate_vma_finalize(&args);
+       return 0;
+}
+
+static const struct dev_pagemap_ops dmirror_devmem_ops = {
+       .page_free      = dmirror_devmem_free,
+       .migrate_to_ram = dmirror_devmem_fault,
+};
+
+static int dmirror_device_init(struct dmirror_device *mdevice, int id)
+{
+       dev_t dev;
+       int ret;
+
+       dev = MKDEV(MAJOR(dmirror_dev), id);
+       mutex_init(&mdevice->devmem_lock);
+       spin_lock_init(&mdevice->lock);
+
+       cdev_init(&mdevice->cdevice, &dmirror_fops);
+       mdevice->cdevice.owner = THIS_MODULE;
+       ret = cdev_add(&mdevice->cdevice, dev, 1);
+       if (ret)
+               return ret;
+
+       /* Build a list of free ZONE_DEVICE private struct pages */
+       dmirror_allocate_chunk(mdevice, NULL);
+
+       return 0;
+}
+
+static void dmirror_device_remove(struct dmirror_device *mdevice)
+{
+       unsigned int i;
+
+       if (mdevice->devmem_chunks) {
+               for (i = 0; i < mdevice->devmem_count; i++) {
+                       struct dmirror_chunk *devmem =
+                               mdevice->devmem_chunks[i];
+
+                       memunmap_pages(&devmem->pagemap);
+                       release_mem_region(devmem->pagemap.res.start,
+                                          resource_size(&devmem->pagemap.res));
+                       kfree(devmem);
+               }
+               kfree(mdevice->devmem_chunks);
+       }
+
+       cdev_del(&mdevice->cdevice);
+}
+
+static int __init hmm_dmirror_init(void)
+{
+       int ret;
+       int id;
+
+       ret = alloc_chrdev_region(&dmirror_dev, 0, DMIRROR_NDEVICES,
+                                 "HMM_DMIRROR");
+       if (ret)
+               goto err_unreg;
+
+       for (id = 0; id < DMIRROR_NDEVICES; id++) {
+               ret = dmirror_device_init(dmirror_devices + id, id);
+               if (ret)
+                       goto err_chrdev;
+       }
+
+       /*
+        * Allocate a zero page to simulate a reserved page of device private
+        * memory which is always zero. The zero_pfn page isn't used just to
+        * make the code here simpler (i.e., we need a struct page for it).
+        */
+       dmirror_zero_page = alloc_page(GFP_HIGHUSER | __GFP_ZERO);
+       if (!dmirror_zero_page) {
+               ret = -ENOMEM;
+               goto err_chrdev;
+       }
+
+       pr_info("HMM test module loaded. This is only for testing HMM.\n");
+       return 0;
+
+err_chrdev:
+       while (--id >= 0)
+               dmirror_device_remove(dmirror_devices + id);
+       unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
+err_unreg:
+       return ret;
+}
+
+static void __exit hmm_dmirror_exit(void)
+{
+       int id;
+
+       if (dmirror_zero_page)
+               __free_page(dmirror_zero_page);
+       for (id = 0; id < DMIRROR_NDEVICES; id++)
+               dmirror_device_remove(dmirror_devices + id);
+       unregister_chrdev_region(dmirror_dev, DMIRROR_NDEVICES);
+}
+
+module_init(hmm_dmirror_init);
+module_exit(hmm_dmirror_exit);
+MODULE_LICENSE("GPL");
diff --git a/lib/test_hmm_uapi.h b/lib/test_hmm_uapi.h
new file mode 100644 (file)
index 0000000..67b3b2e
--- /dev/null
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: GPL-2.0 WITH Linux-syscall-note */
+/*
+ * This is a module to test the HMM (Heterogeneous Memory Management) API
+ * of the kernel. It allows a userspace program to expose its entire address
+ * space through the HMM test module device file.
+ */
+#ifndef _LIB_TEST_HMM_UAPI_H
+#define _LIB_TEST_HMM_UAPI_H
+
+#include <linux/types.h>
+#include <linux/ioctl.h>
+
+/*
+ * Structure to pass to the HMM test driver to mimic a device accessing
+ * system memory and ZONE_DEVICE private memory through device page tables.
+ *
+ * @addr: (in) user address the device will read/write
+ * @ptr: (in) user address where device data is copied to/from
+ * @npages: (in) number of pages to read/write
+ * @cpages: (out) number of pages copied
+ * @faults: (out) number of device page faults seen
+ */
+struct hmm_dmirror_cmd {
+       __u64           addr;
+       __u64           ptr;
+       __u64           npages;
+       __u64           cpages;
+       __u64           faults;
+};
+
+/* Expose the address space of the calling process through hmm device file */
+#define HMM_DMIRROR_READ               _IOWR('H', 0x00, struct hmm_dmirror_cmd)
+#define HMM_DMIRROR_WRITE              _IOWR('H', 0x01, struct hmm_dmirror_cmd)
+#define HMM_DMIRROR_MIGRATE            _IOWR('H', 0x02, struct hmm_dmirror_cmd)
+#define HMM_DMIRROR_SNAPSHOT           _IOWR('H', 0x03, struct hmm_dmirror_cmd)
+
+/*
+ * Values returned in hmm_dmirror_cmd.ptr for HMM_DMIRROR_SNAPSHOT.
+ * HMM_DMIRROR_PROT_ERROR: no valid mirror PTE for this page
+ * HMM_DMIRROR_PROT_NONE: unpopulated PTE or PTE with no access
+ * HMM_DMIRROR_PROT_READ: read-only PTE
+ * HMM_DMIRROR_PROT_WRITE: read/write PTE
+ * HMM_DMIRROR_PROT_ZERO: special read-only zero page
+ * HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL: Migrated device private page on the
+ *                                     device the ioctl() is made
+ * HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE: Migrated device private page on some
+ *                                     other device
+ */
+enum {
+       HMM_DMIRROR_PROT_ERROR                  = 0xFF,
+       HMM_DMIRROR_PROT_NONE                   = 0x00,
+       HMM_DMIRROR_PROT_READ                   = 0x01,
+       HMM_DMIRROR_PROT_WRITE                  = 0x02,
+       HMM_DMIRROR_PROT_ZERO                   = 0x10,
+       HMM_DMIRROR_PROT_DEV_PRIVATE_LOCAL      = 0x20,
+       HMM_DMIRROR_PROT_DEV_PRIVATE_REMOTE     = 0x30,
+};
+
+#endif /* _LIB_TEST_HMM_UAPI_H */