return -ENOMEM;
        }
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        rc = hl_mmu_map_contiguous(ctx, cb->virtual_addr, cb->bus_address, cb->roundup_size);
        if (rc) {
                dev_err(hdev->dev, "Failed to map VA %#llx to CB\n", cb->virtual_addr);
                goto err_va_umap;
        }
        rc = hl_mmu_invalidate_cache(hdev, false, MMU_OP_USERPTR | MMU_OP_SKIP_LOW_CACHE_INV);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        cb->is_mmu_mapped = true;
        return rc;
 
 err_va_umap:
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
        gen_pool_free(ctx->cb_va_pool, cb->virtual_addr, cb->roundup_size);
        return rc;
 }
 {
        struct hl_device *hdev = ctx->hdev;
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        hl_mmu_unmap_contiguous(ctx, cb->virtual_addr, cb->roundup_size);
        hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        gen_pool_free(ctx->cb_va_pool, cb->virtual_addr, cb->roundup_size);
 }
 
  *                 command submissions for a long time after CS id wraparound.
  * @va_range: holds available virtual addresses for host and dram mappings.
  * @mem_hash_lock: protects the mem_hash.
- * @mmu_lock: protects the MMU page tables. Any change to the PGT, modifying the
- *            MMU hash or walking the PGT requires talking this lock.
  * @hw_block_list_lock: protects the HW block memory list.
  * @debugfs_list: node in debugfs list of contexts.
  * @hw_block_mem_list: list of HW block virtual mapped addresses.
        struct hl_cs_outcome_store      outcome_store;
        struct hl_va_range              *va_range[HL_VA_RANGE_TYPE_MAX];
        struct mutex                    mem_hash_lock;
-       struct mutex                    mmu_lock;
        struct mutex                    hw_block_list_lock;
        struct list_head                debugfs_list;
        struct list_head                hw_block_mem_list;
  * @asid_mutex: protects asid_bitmap.
  * @send_cpu_message_lock: enforces only one message in Host <-> CPU-CP queue.
  * @debug_lock: protects critical section of setting debug mode for device
+ * @mmu_lock: protects the MMU page tables and invalidation h/w. Although the
+ *            page tables are per context, the invalidation h/w is per MMU.
+ *            Therefore, we can't allow multiple contexts (we only have two,
+ *            user and kernel) to access the invalidation h/w at the same time.
+ *            In addition, any change to the PGT, modifying the MMU hash or
+ *            walking the PGT requires talking this lock.
  * @asic_prop: ASIC specific immutable properties.
  * @asic_funcs: ASIC specific functions.
  * @asic_specific: ASIC specific information to use only from ASIC files.
        struct mutex                    asid_mutex;
        struct mutex                    send_cpu_message_lock;
        struct mutex                    debug_lock;
+       struct mutex                    mmu_lock;
        struct asic_fixed_properties    asic_prop;
        const struct hl_asic_funcs      *asic_funcs;
        void                            *asic_specific;
 
                goto va_block_err;
        }
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
 
        rc = map_phys_pg_pack(ctx, ret_vaddr, phys_pg_pack);
        if (rc) {
                dev_err(hdev->dev, "mapping page pack failed for handle %u\n", handle);
-               mutex_unlock(&ctx->mmu_lock);
+               mutex_unlock(&hdev->mmu_lock);
                goto map_err;
        }
 
        rc = hl_mmu_invalidate_cache_range(hdev, false, *vm_type | MMU_OP_SKIP_LOW_CACHE_INV,
                                ctx->asid, ret_vaddr, phys_pg_pack->total_size);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
        if (rc)
                goto map_err;
 
        else
                vaddr &= ~(((u64) phys_pg_pack->page_size) - 1);
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
 
        unmap_phys_pg_pack(ctx, vaddr, phys_pg_pack);
 
                rc = hl_mmu_invalidate_cache_range(hdev, true, *vm_type, ctx->asid, vaddr,
                                                        phys_pg_pack->total_size);
 
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        /*
         * If the context is closing we don't need to check for the MMU cache
                unmap_device_va(ctx, &args, true);
        }
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
 
        /* invalidate the cache once after the unmapping loop */
        hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
        hl_mmu_invalidate_cache(hdev, true, MMU_OP_PHYS_PACK);
 
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        INIT_LIST_HEAD(&free_list);
 
 
        if (!hdev->mmu_enable)
                return 0;
 
+       mutex_init(&hdev->mmu_lock);
+
        if (hdev->mmu_func[MMU_DR_PGT].init != NULL) {
                rc = hdev->mmu_func[MMU_DR_PGT].init(hdev);
                if (rc)
 
        if (hdev->mmu_func[MMU_HR_PGT].fini != NULL)
                hdev->mmu_func[MMU_HR_PGT].fini(hdev);
+
+       mutex_destroy(&hdev->mmu_lock);
 }
 
 /**
        if (!hdev->mmu_enable)
                return 0;
 
-       mutex_init(&ctx->mmu_lock);
-
        if (hdev->mmu_func[MMU_DR_PGT].ctx_init != NULL) {
                rc = hdev->mmu_func[MMU_DR_PGT].ctx_init(ctx);
                if (rc)
 
        if (hdev->mmu_func[MMU_HR_PGT].ctx_fini != NULL)
                hdev->mmu_func[MMU_HR_PGT].ctx_fini(ctx);
-
-       mutex_destroy(&ctx->mmu_lock);
 }
 
 /*
        pgt_residency = mmu_prop->host_resident ? MMU_HR_PGT : MMU_DR_PGT;
        mmu_funcs = hl_mmu_get_funcs(hdev, pgt_residency, is_dram_addr);
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        rc = mmu_funcs->get_tlb_info(ctx, virt_addr, hops);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        if (rc)
                return rc;
 {
        struct hl_prefetch_work *pfw = container_of(work, struct hl_prefetch_work, pf_work);
        struct hl_ctx *ctx = pfw->ctx;
+       struct hl_device *hdev = ctx->hdev;
 
-       if (!hl_device_operational(ctx->hdev, NULL))
+       if (!hl_device_operational(hdev, NULL))
                goto put_ctx;
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
 
-       ctx->hdev->asic_funcs->mmu_prefetch_cache_range(ctx, pfw->flags, pfw->asid,
-                                                               pfw->va, pfw->size);
+       hdev->asic_funcs->mmu_prefetch_cache_range(ctx, pfw->flags, pfw->asid, pfw->va, pfw->size);
 
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
 put_ctx:
        /*
 
                goto destroy_internal_cb_pool;
        }
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        rc = hl_mmu_map_contiguous(ctx, hdev->internal_cb_va_base,
                        hdev->internal_cb_pool_dma_addr,
                        HOST_SPACE_INTERNAL_CB_SZ);
 
        hl_mmu_invalidate_cache(hdev, false, MMU_OP_USERPTR);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        if (rc)
                goto unreserve_internal_cb_pool;
        if (!(gaudi->hw_cap_initialized & HW_CAP_MMU))
                return;
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        hl_mmu_unmap_contiguous(ctx, hdev->internal_cb_va_base,
                        HOST_SPACE_INTERNAL_CB_SZ);
        hl_unreserve_va_block(hdev, ctx, hdev->internal_cb_va_base,
                        HOST_SPACE_INTERNAL_CB_SZ);
        hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        gen_pool_destroy(hdev->internal_cb_pool);
 
 
        }
 
        /* Create mapping on asic side */
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        rc = hl_mmu_map_contiguous(ctx, reserved_va_base, host_mem_dma_addr, SZ_2M);
        hl_mmu_invalidate_cache_range(hdev, false,
                                      MMU_OP_USERPTR | MMU_OP_SKIP_LOW_CACHE_INV,
                                      ctx->asid, reserved_va_base, SZ_2M);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
        if (rc) {
                dev_err(hdev->dev, "Failed to create mapping on asic mmu\n");
                goto unreserve_va;
 
        gaudi2_kdma_set_mmbp_asid(hdev, true, HL_KERNEL_ASID_ID);
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        hl_mmu_unmap_contiguous(ctx, reserved_va_base, SZ_2M);
        hl_mmu_invalidate_cache_range(hdev, false, MMU_OP_USERPTR,
                                      ctx->asid, reserved_va_base, SZ_2M);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 unreserve_va:
        hl_unreserve_va_block(hdev, ctx, reserved_va_base, SZ_2M);
 free_data_buffer:
                goto destroy_internal_cb_pool;
        }
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        rc = hl_mmu_map_contiguous(ctx, hdev->internal_cb_va_base, hdev->internal_cb_pool_dma_addr,
                                        HOST_SPACE_INTERNAL_CB_SZ);
        hl_mmu_invalidate_cache(hdev, false, MMU_OP_USERPTR);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        if (rc)
                goto unreserve_internal_cb_pool;
        if (!(gaudi2->hw_cap_initialized & HW_CAP_PMMU))
                return;
 
-       mutex_lock(&ctx->mmu_lock);
+       mutex_lock(&hdev->mmu_lock);
        hl_mmu_unmap_contiguous(ctx, hdev->internal_cb_va_base, HOST_SPACE_INTERNAL_CB_SZ);
        hl_unreserve_va_block(hdev, ctx, hdev->internal_cb_va_base, HOST_SPACE_INTERNAL_CB_SZ);
        hl_mmu_invalidate_cache(hdev, true, MMU_OP_USERPTR);
-       mutex_unlock(&ctx->mmu_lock);
+       mutex_unlock(&hdev->mmu_lock);
 
        gen_pool_destroy(hdev->internal_cb_pool);