static void venus_reset_cpu(struct venus_core *core)
 {
+       u32 fw_size = core->fw.mapped_mem_size;
        void __iomem *base = core->base;
 
        writel(0, base + WRAPPER_FW_START_ADDR);
-       writel(VENUS_FW_MEM_SIZE, base + WRAPPER_FW_END_ADDR);
+       writel(fw_size, base + WRAPPER_FW_END_ADDR);
        writel(0, base + WRAPPER_CPA_START_ADDR);
-       writel(VENUS_FW_MEM_SIZE, base + WRAPPER_CPA_END_ADDR);
-       writel(VENUS_FW_MEM_SIZE, base + WRAPPER_NONPIX_START_ADDR);
-       writel(VENUS_FW_MEM_SIZE, base + WRAPPER_NONPIX_END_ADDR);
+       writel(fw_size, base + WRAPPER_CPA_END_ADDR);
+       writel(fw_size, base + WRAPPER_NONPIX_START_ADDR);
+       writel(fw_size, base + WRAPPER_NONPIX_END_ADDR);
        writel(0x0, base + WRAPPER_CPU_CGC_DIS);
        writel(0x0, base + WRAPPER_CPU_CLOCK_CONFIG);
 
        void *mem_va;
        int ret;
 
+       *mem_phys = 0;
+       *mem_size = 0;
+
        dev = core->dev;
        node = of_parse_phandle(dev->of_node, "memory-region", 0);
        if (!node) {
        if (ret)
                return ret;
 
+       ret = request_firmware(&mdt, fwname, dev);
+       if (ret < 0)
+               return ret;
+
+       fw_size = qcom_mdt_get_size(mdt);
+       if (fw_size < 0) {
+               ret = fw_size;
+               goto err_release_fw;
+       }
+
        *mem_phys = r.start;
        *mem_size = resource_size(&r);
 
-       if (*mem_size < VENUS_FW_MEM_SIZE)
-               return -EINVAL;
+       if (*mem_size < fw_size || fw_size > VENUS_FW_MEM_SIZE) {
+               ret = -EINVAL;
+               goto err_release_fw;
+       }
 
        mem_va = memremap(r.start, *mem_size, MEMREMAP_WC);
        if (!mem_va) {
                dev_err(dev, "unable to map memory region: %pa+%zx\n",
                        &r.start, *mem_size);
-               return -ENOMEM;
-       }
-
-       ret = request_firmware(&mdt, fwname, dev);
-       if (ret < 0)
-               goto err_unmap;
-
-       fw_size = qcom_mdt_get_size(mdt);
-       if (fw_size < 0) {
-               ret = fw_size;
-               release_firmware(mdt);
-               goto err_unmap;
+               ret = -ENOMEM;
+               goto err_release_fw;
        }
 
        if (core->use_tz)
                ret = qcom_mdt_load_no_init(dev, mdt, fwname, VENUS_PAS_ID,
                                            mem_va, *mem_phys, *mem_size, NULL);
 
-       release_firmware(mdt);
-
-err_unmap:
        memunmap(mem_va);
+err_release_fw:
+       release_firmware(mdt);
        return ret;
 }
 
                return -EPROBE_DEFER;
 
        iommu = core->fw.iommu_domain;
+       core->fw.mapped_mem_size = mem_size;
 
        ret = iommu_map(iommu, VENUS_FW_START_ADDR, mem_phys, mem_size,
                        IOMMU_READ | IOMMU_WRITE | IOMMU_PRIV);
 
 static int venus_shutdown_no_tz(struct venus_core *core)
 {
+       const size_t mapped = core->fw.mapped_mem_size;
        struct iommu_domain *iommu;
        size_t unmapped;
        u32 reg;
 
        iommu = core->fw.iommu_domain;
 
-       unmapped = iommu_unmap(iommu, VENUS_FW_START_ADDR, VENUS_FW_MEM_SIZE);
-       if (unmapped != VENUS_FW_MEM_SIZE)
+       unmapped = iommu_unmap(iommu, VENUS_FW_START_ADDR, mapped);
+       if (unmapped != mapped)
                dev_err(dev, "failed to unmap firmware\n");
 
        return 0;