#include <asm/csr.h>
 #include <asm/kvm_vcpu_sbi.h>
 #include <asm/kvm_vcpu_pmu.h>
+#include <asm/sbi.h>
 #include <linux/bitops.h>
 
 #define kvm_pmu_num_counters(pmu) ((pmu)->num_hw_ctrs + (pmu)->num_fw_ctrs)
        return ret;
 }
 
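+/*
+ * Zero out the guest's snapshot area (when one is still mapped), free the
+ * host-side staging buffer, and mark the snapshot address invalid.
+ */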
+static void kvm_pmu_clear_snapshot_area(struct kvm_vcpu *vcpu)
+{
+       struct kvm_pmu *kvpmu = vcpu_to_pmu(vcpu);
+       int snapshot_area_size = sizeof(struct riscv_pmu_snapshot_data);
+
+       if (kvpmu->sdata) {
+               if (kvpmu->snapshot_addr != INVALID_GPA) {
+                       memset(kvpmu->sdata, 0, snapshot_area_size);
+                       kvm_vcpu_write_guest(vcpu, kvpmu->snapshot_addr,
+                                            kvpmu->sdata, snapshot_area_size);
+               } else {
+                       pr_warn("snapshot address invalid\n");
+               }
+               kfree(kvpmu->sdata);
+               kvpmu->sdata = NULL;
+       }
+       kvpmu->snapshot_addr = INVALID_GPA;
+}
+
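+/*
+ * Handle the SBI PMU snapshot set_shmem call: validate the guest-supplied
+ * shared memory address, allocate a host staging buffer, and zero the
+ * guest's snapshot area. Passing SBI_SHMEM_DISABLE in both address halves
+ * tears down any existing snapshot mapping.
+ */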
+int kvm_riscv_vcpu_pmu_snapshot_set_shmem(struct kvm_vcpu *vcpu, unsigned long saddr_low,
+                                     unsigned long saddr_high, unsigned long flags,
+                                     struct kvm_vcpu_sbi_return *retdata)
+{
+       struct kvm_pmu *kvpmu = vcpu_to_pmu(vcpu);
+       int snapshot_area_size = sizeof(struct riscv_pmu_snapshot_data);
+       int sbiret = 0;
+       gpa_t saddr;
+       unsigned long hva;
+       bool writable;
+
+       if (!kvpmu || flags) {
+               sbiret = SBI_ERR_INVALID_PARAM;
+               goto out;
+       }
+
+       if (saddr_low == SBI_SHMEM_DISABLE && saddr_high == SBI_SHMEM_DISABLE) {
+               kvm_pmu_clear_snapshot_area(vcpu);
+               return 0;
+       }
+
+       saddr = saddr_low;
+
+       if (saddr_high != 0) {
+               if (IS_ENABLED(CONFIG_32BIT)) {
+                       saddr |= ((gpa_t)saddr_high << 32);
+               } else {
+                       sbiret = SBI_ERR_INVALID_ADDRESS;
+                       goto out;
+               }
+       }
+
+       hva = kvm_vcpu_gfn_to_hva_prot(vcpu, saddr >> PAGE_SHIFT, &writable);
+       if (kvm_is_error_hva(hva) || !writable) {
+               sbiret = SBI_ERR_INVALID_ADDRESS;
+               goto out;
+       }
+
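+       /* Host-side staging buffer mirroring the guest's snapshot area */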
+       kvpmu->sdata = kzalloc(snapshot_area_size, GFP_ATOMIC);
+       if (!kvpmu->sdata)
+               return -ENOMEM;
+
+       if (kvm_vcpu_write_guest(vcpu, saddr, kvpmu->sdata, snapshot_area_size)) {
+               kfree(kvpmu->sdata);
+               kvpmu->sdata = NULL;
+               sbiret = SBI_ERR_FAILURE;
+               goto out;
+       }
+
+       kvpmu->snapshot_addr = saddr;
+
+out:
+       retdata->err_val = sbiret;
+
+       return 0;
+}
+
 int kvm_riscv_vcpu_pmu_num_ctrs(struct kvm_vcpu *vcpu,
                                struct kvm_vcpu_sbi_return *retdata)
 {
        int i, pmc_index, sbiret = 0;
        struct kvm_pmc *pmc;
        int fevent_code;
+       bool snap_flag_set = flags & SBI_PMU_START_FLAG_INIT_SNAPSHOT;
 
        if (kvm_pmu_validate_counter_mask(kvpmu, ctr_base, ctr_mask) < 0) {
                sbiret = SBI_ERR_INVALID_PARAM;
                goto out;
        }
 
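+       /*
+        * With SBI_PMU_START_FLAG_INIT_SNAPSHOT, initial counter values come
+        * from the snapshot area, so refresh the host copy from guest memory.
+        */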
+       if (snap_flag_set) {
+               if (kvpmu->snapshot_addr == INVALID_GPA) {
+                       sbiret = SBI_ERR_NO_SHMEM;
+                       goto out;
+               }
+               if (kvm_vcpu_read_guest(vcpu, kvpmu->snapshot_addr, kvpmu->sdata,
+                                       sizeof(struct riscv_pmu_snapshot_data))) {
+                       pr_warn("Unable to read snapshot shared memory while starting counters\n");
+                       sbiret = SBI_ERR_FAILURE;
+                       goto out;
+               }
+       }
        /* Start the counters that have been configured and requested by the guest */
        for_each_set_bit(i, &ctr_mask, RISCV_MAX_COUNTERS) {
                pmc_index = i + ctr_base;
                if (!test_bit(pmc_index, kvpmu->pmc_in_use))
                        continue;
                pmc = &kvpmu->pmc[pmc_index];
-               if (flags & SBI_PMU_START_FLAG_SET_INIT_VALUE)
+               if (flags & SBI_PMU_START_FLAG_SET_INIT_VALUE) {
                        pmc->counter_val = ival;
+               } else if (snap_flag_set) {
+                       /* Counter indices in the snapshot area are relative to the counter base */
+                       pmc->counter_val = kvpmu->sdata->ctr_values[i];
+               }
+
                if (pmc->cinfo.type == SBI_PMU_CTR_TYPE_FW) {
                        fevent_code = get_event_code(pmc->event_idx);
                        if (fevent_code >= SBI_PMU_FW_MAX) {
 {
        struct kvm_pmu *kvpmu = vcpu_to_pmu(vcpu);
        int i, pmc_index, sbiret = 0;
+       u64 enabled, running;
        struct kvm_pmc *pmc;
        int fevent_code;
+       bool snap_flag_set = flags & SBI_PMU_STOP_FLAG_TAKE_SNAPSHOT;
+       bool shmem_needs_update = false;
 
        if (kvm_pmu_validate_counter_mask(kvpmu, ctr_base, ctr_mask) < 0) {
                sbiret = SBI_ERR_INVALID_PARAM;
                goto out;
        }
 
+       if (snap_flag_set && kvpmu->snapshot_addr == INVALID_GPA) {
+               sbiret = SBI_ERR_NO_SHMEM;
+               goto out;
+       }
+
        /* Stop the counters that have been configured and requested by the guest */
        for_each_set_bit(i, &ctr_mask, RISCV_MAX_COUNTERS) {
                pmc_index = i + ctr_base;
                } else {
                        sbiret = SBI_ERR_INVALID_PARAM;
                }
+
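+               /*
+                * With SBI_PMU_STOP_FLAG_TAKE_SNAPSHOT, record the final
+                * counter value in the staging buffer; it is written back to
+                * guest memory once all requested counters are stopped.
+                */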
+               if (snap_flag_set && !sbiret) {
+                       if (pmc->cinfo.type == SBI_PMU_CTR_TYPE_FW)
+                               pmc->counter_val = kvpmu->fw_event[fevent_code].value;
+                       else if (pmc->perf_event)
+                               pmc->counter_val += perf_event_read_value(pmc->perf_event,
+                                                                         &enabled, &running);
+                       /* TODO: Add counter overflow support when sscofpmf support is added */
+                       kvpmu->sdata->ctr_values[i] = pmc->counter_val;
+                       shmem_needs_update = true;
+               }
+
                if (flags & SBI_PMU_STOP_FLAG_RESET) {
                        pmc->event_idx = SBI_PMU_EVENT_IDX_INVALID;
                        clear_bit(pmc_index, kvpmu->pmc_in_use);
                }
        }
 
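+       /* Publish the updated counter values to the guest's snapshot area */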
+       if (shmem_needs_update)
+               kvm_vcpu_write_guest(vcpu, kvpmu->snapshot_addr, kvpmu->sdata,
+                                    sizeof(struct riscv_pmu_snapshot_data));
+
 out:
        retdata->err_val = sbiret;
 
        kvpmu->num_hw_ctrs = num_hw_ctrs + 1;
        kvpmu->num_fw_ctrs = SBI_PMU_FW_MAX;
        memset(&kvpmu->fw_event, 0, SBI_PMU_FW_MAX * sizeof(struct kvm_fw_event));
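+       /* No snapshot area until the guest configures shared memory via SBI */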
+       kvpmu->snapshot_addr = INVALID_GPA;
 
        if (kvpmu->num_hw_ctrs > RISCV_KVM_MAX_HW_CTRS) {
                pr_warn_once("Limiting the hardware counters to 32 as specified by the ISA");
        }
        bitmap_zero(kvpmu->pmc_in_use, RISCV_MAX_COUNTERS);
        memset(&kvpmu->fw_event, 0, SBI_PMU_FW_MAX * sizeof(struct kvm_fw_event));
+       kvm_pmu_clear_snapshot_area(vcpu);
 }
 
 void kvm_riscv_vcpu_pmu_reset(struct kvm_vcpu *vcpu)