return match;
 }
 
+static int dax_match_type(struct dax_device_driver *dax_drv, struct device *dev)
+{
+       enum dax_driver_type type = DAXDRV_DEVICE_TYPE;
+       struct dev_dax *dev_dax = to_dev_dax(dev);
+
+       if (dev_dax->region->res.flags & IORESOURCE_DAX_KMEM)
+               type = DAXDRV_KMEM_TYPE;
+
+       if (dax_drv->type == type)
+               return 1;
+
+       /* default to device mode if dax_kmem is disabled */
+       if (dax_drv->type == DAXDRV_DEVICE_TYPE &&
+           !IS_ENABLED(CONFIG_DEV_DAX_KMEM))
+               return 1;
+
+       return 0;
+}
+
 enum id_action {
        ID_REMOVE,
        ID_ADD,
 {
        struct dax_device_driver *dax_drv = to_dax_drv(drv);
 
-       /*
-        * All but the 'device-dax' driver, which has 'match_always'
-        * set, requires an exact id match.
-        */
-       if (dax_drv->match_always)
+       if (dax_match_id(dax_drv, dev))
                return 1;
-
-       return dax_match_id(dax_drv, dev);
+       return dax_match_type(dax_drv, dev);
 }
 
 /*
 }
 EXPORT_SYMBOL_GPL(devm_create_dev_dax);
 
-static int match_always_count;
-
 int __dax_driver_register(struct dax_device_driver *dax_drv,
                struct module *module, const char *mod_name)
 {
        struct device_driver *drv = &dax_drv->drv;
-       int rc = 0;
 
        /*
         * dax_bus_probe() calls dax_drv->probe() unconditionally.
        drv->mod_name = mod_name;
        drv->bus = &dax_bus_type;
 
-       /* there can only be one default driver */
-       mutex_lock(&dax_bus_lock);
-       match_always_count += dax_drv->match_always;
-       if (match_always_count > 1) {
-               match_always_count--;
-               WARN_ON(1);
-               rc = -EINVAL;
-       }
-       mutex_unlock(&dax_bus_lock);
-       if (rc)
-               return rc;
-
-       rc = driver_register(drv);
-       if (rc && dax_drv->match_always) {
-               mutex_lock(&dax_bus_lock);
-               match_always_count -= dax_drv->match_always;
-               mutex_unlock(&dax_bus_lock);
-       }
-
-       return rc;
+       return driver_register(drv);
 }
 EXPORT_SYMBOL_GPL(__dax_driver_register);
 
        struct dax_id *dax_id, *_id;
 
        mutex_lock(&dax_bus_lock);
-       match_always_count -= dax_drv->match_always;
        list_for_each_entry_safe(dax_id, _id, &dax_drv->ids, list) {
                list_del(&dax_id->list);
                kfree(dax_id);
 
 struct dax_region;
 void dax_region_put(struct dax_region *dax_region);
 
-#define IORESOURCE_DAX_STATIC (1UL << 0)
+/* dax bus specific ioresource flags */
+#define IORESOURCE_DAX_STATIC BIT(0)
+#define IORESOURCE_DAX_KMEM BIT(1)
+
 struct dax_region *alloc_dax_region(struct device *parent, int region_id,
                struct range *range, int target_node, unsigned int align,
                unsigned long flags);
 
 struct dev_dax *devm_create_dev_dax(struct dev_dax_data *data);
 
+enum dax_driver_type {
+       DAXDRV_KMEM_TYPE,
+       DAXDRV_DEVICE_TYPE,
+};
+
 struct dax_device_driver {
        struct device_driver drv;
        struct list_head ids;
-       int match_always;
+       enum dax_driver_type type;
        int (*probe)(struct dev_dax *dev);
        void (*remove)(struct dev_dax *dev);
 };
 
 
 static int dax_hmem_probe(struct platform_device *pdev)
 {
+       unsigned long flags = IORESOURCE_DAX_KMEM;
        struct device *dev = &pdev->dev;
        struct dax_region *dax_region;
        struct memregion_info *mri;
        struct dev_dax_data data;
        struct dev_dax *dev_dax;
 
+       /*
+        * @region_idle == true indicates that an administrative agent
+        * wants to manipulate the range partitioning before the devices
+        * are created, so do not send them to the dax_kmem driver by
+        * default.
+        */
+       if (region_idle)
+               flags = 0;
+
        mri = dev->platform_data;
        dax_region = alloc_dax_region(dev, pdev->id, &mri->range,
-                                     mri->target_node, PMD_SIZE, 0);
+                                     mri->target_node, PMD_SIZE, flags);
        if (!dax_region)
                return -ENOMEM;