static u32 kvm_pmu_event_mask(struct kvm *kvm)
 {
-       switch (kvm->arch.pmuver) {
+       unsigned int pmuver;
+
+       pmuver = kvm->arch.arm_pmu->pmuver;
+
+       switch (pmuver) {
        case ID_AA64DFR0_PMUVER_8_0:
                return GENMASK(9, 0);
        case ID_AA64DFR0_PMUVER_8_1:
        case ID_AA64DFR0_PMUVER_8_4:
        case ID_AA64DFR0_PMUVER_8_5:
        case ID_AA64DFR0_PMUVER_8_7:
                return GENMASK(15, 0);
        default:                /* Shouldn't be here, just for sanity */
-               WARN_ONCE(1, "Unknown PMU version %d\n", kvm->arch.pmuver);
+               WARN_ONCE(1, "Unknown PMU version %d\n", pmuver);
                return 0;
        }
 }
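(For context: callers apply this mask to the event number a guest writes via
PMEVTYPERn_EL0, so only event IDs this PMU version can encode get programmed.
A minimal sketch of such a caller; the helper name is hypothetical:)

/* Keep only the event number bits this PMU version implements:
 * 10 bits for a baseline PMUv3, 16 bits from ARMv8.1 onwards.
 */
static u64 example_guest_eventsel(struct kvm *kvm, u64 evtyper)
{
        return evtyper & kvm_pmu_event_mask(kvm);
}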
  */
 static void kvm_pmu_create_perf_event(struct kvm_vcpu *vcpu, u64 select_idx)
 {
+       struct arm_pmu *arm_pmu = vcpu->kvm->arch.arm_pmu;
        struct kvm_pmu *pmu = &vcpu->arch.pmu;
        struct kvm_pmc *pmc;
        struct perf_event *event;
                return;
 
        memset(&attr, 0, sizeof(struct perf_event_attr));
-       attr.type = PERF_TYPE_RAW;
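+       /*
+        * Bind the event to the PMU backing this VM rather than letting perf
+        * pick a PMU via PERF_TYPE_RAW, so all of a guest's counters are
+        * scheduled on the same hardware PMU.
+        */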
+       attr.type = arm_pmu->pmu.type;
        attr.size = sizeof(attr);
        attr.pinned = 1;
        attr.disabled = !kvm_pmu_counter_is_enabled(vcpu, pmc->idx);
                static_branch_enable(&kvm_arm_pmu_available);
 }
 
-static int kvm_pmu_probe_pmuver(void)
+static struct arm_pmu *kvm_pmu_probe_armpmu(void)
 {
        struct perf_event_attr attr = { };
        struct perf_event *event;
-       struct arm_pmu *pmu;
-       int pmuver = ID_AA64DFR0_PMUVER_IMP_DEF;
+       struct arm_pmu *pmu = NULL;
 
        /*
         * Create a dummy event that only counts user cycles. As we'll never
        if (IS_ERR(event)) {
                pr_err_once("kvm: pmu event creation failed %ld\n",
                            PTR_ERR(event));
-               return ID_AA64DFR0_PMUVER_IMP_DEF;
+               return NULL;
        }
 
        if (event->pmu) {
                pmu = to_arm_pmu(event->pmu);
-               if (pmu->pmuver)
-                       pmuver = pmu->pmuver;
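+               /*
+                * pmuver == 0 means no PMUv3 at all, and IMP_DEF (0xf) is an
+                * IMPLEMENTATION DEFINED PMU that KVM cannot virtualize;
+                * treat both as "no usable PMU".
+                */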
+               if (pmu->pmuver == 0 ||
+                   pmu->pmuver == ID_AA64DFR0_PMUVER_IMP_DEF)
+                       pmu = NULL;
        }
 
        perf_event_disable(event);
        perf_event_release_kernel(event);
 
-       return pmuver;
+       return pmu;
 }
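(For reference, to_arm_pmu() works because struct arm_pmu embeds the generic
struct pmu that perf hands back, so the dummy event identifies the host PMU
driving the current CPU; roughly, from include/linux/perf/arm_pmu.h:)

#define to_arm_pmu(p)   (container_of(p, struct arm_pmu, pmu))

(The caller is then expected to cache the returned pointer per VM, as the
last hunk below does.)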
 
 u64 kvm_pmu_get_pmceid(struct kvm_vcpu *vcpu, bool pmceid1)
                 * Don't advertise STALL_SLOT, as PMMIR_EL0 is handled
                 * as RAZ
                 */
-               if (vcpu->kvm->arch.pmuver >= ID_AA64DFR0_PMUVER_8_4)
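+               /*
+                * STALL_SLOT is event 0x3F, i.e. bit (0x3F - 32) of the
+                * PMCEID1 half of the event ID space.
+                */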
+               if (vcpu->kvm->arch.arm_pmu->pmuver >= ID_AA64DFR0_PMUVER_8_4)
                        val &= ~BIT_ULL(ARMV8_PMUV3_PERFCTR_STALL_SLOT - 32);
                base = 32;
        }
        if (vcpu->arch.pmu.created)
                return -EBUSY;
 
-       if (!vcpu->kvm->arch.pmuver)
-               vcpu->kvm->arch.pmuver = kvm_pmu_probe_pmuver();
-
-       if (vcpu->kvm->arch.pmuver == ID_AA64DFR0_PMUVER_IMP_DEF)
-               return -ENODEV;
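+       /*
+        * Serialize with the other vCPUs of this VM: the first caller probes
+        * the host's default PMU and caches it in kvm->arch.arm_pmu, so every
+        * vCPU of the guest ends up backed by the same arm_pmu.
+        */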
+       mutex_lock(&kvm->lock);
+       if (!kvm->arch.arm_pmu) {
+               /* No PMU set, get the default one */
+               kvm->arch.arm_pmu = kvm_pmu_probe_armpmu();
+               if (!kvm->arch.arm_pmu) {
+                       mutex_unlock(&kvm->lock);
+                       return -ENODEV;
+               }
+       }
+       mutex_unlock(&kvm->lock);
 
        switch (attr->attr) {
        case KVM_ARM_VCPU_PMU_V3_IRQ: {