#include "../../../../mm/gup_test.h"
 #include "../kselftest.h"
 #include "vm_util.h"
+#include "thp_settings.h"
 
 static size_t pagesize;
 static int pagemap_fd;
 static size_t pmdsize;
+static int nr_thpsizes;
+static size_t thpsizes[20];
 static int nr_hugetlbsizes;
 static size_t hugetlbsizes[10];
 static int gup_fd;
 static bool has_huge_zeropage;
 
+static int sz2ord(size_t size)
+{
+       return __builtin_ctzll(size / pagesize);
+}
+
+static int detect_thp_sizes(size_t sizes[], int max)
+{
+       int count = 0;
+       unsigned long orders;
+       size_t kb;
+       int i;
+
+       /* thp not supported at all. */
+       if (!pmdsize)
+               return 0;
+
+       orders = 1UL << sz2ord(pmdsize);
+       orders |= thp_supported_orders();
+
+       for (i = 0; orders && count < max; i++) {
+               if (!(orders & (1UL << i)))
+                       continue;
+               orders &= ~(1UL << i);
+               kb = (pagesize >> 10) << i;
+               sizes[count++] = kb * 1024;
+               ksft_print_msg("[INFO] detected THP size: %zu KiB\n", kb);
+       }
+
+       return count;
+}
+
 static void detect_huge_zeropage(void)
 {
        int fd = open("/sys/kernel/mm/transparent_hugepage/use_zero_page",
 
        run_with_base_page(test_case->fn, test_case->desc);
        run_with_base_page_swap(test_case->fn, test_case->desc);
-       if (pmdsize) {
-               run_with_thp(test_case->fn, test_case->desc, pmdsize);
-               run_with_thp_swap(test_case->fn, test_case->desc, pmdsize);
-               run_with_pte_mapped_thp(test_case->fn, test_case->desc, pmdsize);
-               run_with_pte_mapped_thp_swap(test_case->fn, test_case->desc, pmdsize);
-               run_with_single_pte_of_thp(test_case->fn, test_case->desc, pmdsize);
-               run_with_single_pte_of_thp_swap(test_case->fn, test_case->desc, pmdsize);
-               run_with_partial_mremap_thp(test_case->fn, test_case->desc, pmdsize);
-               run_with_partial_shared_thp(test_case->fn, test_case->desc, pmdsize);
+       for (i = 0; i < nr_thpsizes; i++) {
+               size_t size = thpsizes[i];
+               struct thp_settings settings = *thp_current_settings();
+
+               settings.hugepages[sz2ord(pmdsize)].enabled = THP_NEVER;
+               settings.hugepages[sz2ord(size)].enabled = THP_ALWAYS;
+               thp_push_settings(&settings);
+
+               if (size == pmdsize) {
+                       run_with_thp(test_case->fn, test_case->desc, size);
+                       run_with_thp_swap(test_case->fn, test_case->desc, size);
+               }
+
+               run_with_pte_mapped_thp(test_case->fn, test_case->desc, size);
+               run_with_pte_mapped_thp_swap(test_case->fn, test_case->desc, size);
+               run_with_single_pte_of_thp(test_case->fn, test_case->desc, size);
+               run_with_single_pte_of_thp_swap(test_case->fn, test_case->desc, size);
+               run_with_partial_mremap_thp(test_case->fn, test_case->desc, size);
+               run_with_partial_shared_thp(test_case->fn, test_case->desc, size);
+
+               thp_pop_settings();
        }
        for (i = 0; i < nr_hugetlbsizes; i++)
                run_with_hugetlb(test_case->fn, test_case->desc,
 {
        int tests = 2 + nr_hugetlbsizes;
 
+       tests += 6 * nr_thpsizes;
        if (pmdsize)
-               tests += 8;
+               tests += 2;
        return tests;
 }
 
 int main(int argc, char **argv)
 {
        int err;
+       struct thp_settings default_settings;
 
        ksft_print_header();
 
        pagesize = getpagesize();
        pmdsize = read_pmd_pagesize();
        if (pmdsize) {
+               /* Only if THP is supported. */
+               thp_read_settings(&default_settings);
+               default_settings.hugepages[sz2ord(pmdsize)].enabled = THP_INHERIT;
+               thp_save_settings();
+               thp_push_settings(&default_settings);
+
                ksft_print_msg("[INFO] detected PMD size: %zu KiB\n",
                               pmdsize / 1024);
-               ksft_print_msg("[INFO] detected THP size: %zu KiB\n",
-                              pmdsize / 1024);
+               nr_thpsizes = detect_thp_sizes(thpsizes, ARRAY_SIZE(thpsizes));
        }
        nr_hugetlbsizes = detect_hugetlb_page_sizes(hugetlbsizes,
                                                    ARRAY_SIZE(hugetlbsizes));
        run_anon_thp_test_cases();
        run_non_anon_test_cases();
 
+       if (pmdsize) {
+               /* Only if THP is supported. */
+               thp_restore_settings();
+       }
+
        err = ksft_get_fail_cnt();
        if (err)
                ksft_exit_fail_msg("%d out of %d tests failed\n",