habanalabs: protect access to dynamic mem 'user_mappings'
authorKoby Elbaz <kelbaz@habana.ai>
Fri, 23 Dec 2022 13:02:05 +0000 (15:02 +0200)
committerOded Gabbay <ogabbay@kernel.org>
Thu, 26 Jan 2023 09:52:11 +0000 (11:52 +0200)
When HL_INFO_USER_MAPPINGS IOCTL is called, we copy_to_user from
a dynamically allocated memory - 'user_mappings'.
Since freeing/allocating it happens in runtime (upon a page fault),
it not unlikely to access it even before being initially allocated
(i.e., accessing a NULL pointer).

The solution is to simply mark the spot when the err info has been
collected, and that way to know whether err info (either page fault
or RAZWI) is available to be read.

Signed-off-by: Koby Elbaz <kelbaz@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
drivers/accel/habanalabs/common/device.c
drivers/accel/habanalabs/common/habanalabs.h
drivers/accel/habanalabs/common/habanalabs_drv.c
drivers/accel/habanalabs/common/habanalabs_ioctl.c

index e1b5a2c3498603d3e8415e3515ad09f5e4a13c6c..6a05ab3fda236f618d7eb01dc863a98e91b364d4 100644 (file)
@@ -2441,6 +2441,8 @@ void hl_capture_razwi(struct hl_device *hdev, u64 addr, u16 *engine_id, u16 num_
        memcpy(&razwi_info->razwi.engine_id[0], &engine_id[0],
                        num_of_engines * sizeof(u16));
        razwi_info->razwi.flags = flags;
+
+       razwi_info->razwi_info_available = true;
 }
 
 void hl_handle_razwi(struct hl_device *hdev, u64 addr, u16 *engine_id, u16 num_of_engines,
@@ -2526,6 +2528,8 @@ void hl_capture_page_fault(struct hl_device *hdev, u64 addr, u16 eng_id, bool is
        pgf_info->page_fault.addr = addr;
        pgf_info->page_fault.engine_id = eng_id;
        hl_capture_user_mappings(hdev, is_pmmu);
+
+       pgf_info->page_fault_info_available = true;
 }
 
 void hl_handle_page_fault(struct hl_device *hdev, u64 addr, u16 eng_id, bool is_pmmu,
index e578645acba96465a91316da35b374bec1c4e962..cd474422163d82669e60df36559a7a3589268a82 100644 (file)
@@ -2984,12 +2984,14 @@ struct undefined_opcode_info {
  *                       Since we're looking for the page-fault's root cause,
  *                       we don't care of the others that might follow it-
  *                       so once changed to 1, it will remain that way.
+ * @page_fault_info_available: indicates that a page fault info is now available.
  */
 struct page_fault_info {
        struct hl_page_fault_info       page_fault;
        struct hl_user_mapping          *user_mappings;
        u64                             num_of_user_mappings;
        atomic_t                        page_fault_detected;
+       bool                            page_fault_info_available;
 };
 
 /**
@@ -3000,10 +3002,12 @@ struct page_fault_info {
  *                  Since we're looking for the RAZWI's root cause,
  *                  we don't care of the others that might follow it-
  *                  so once changed to 1, it will remain that way.
+ * @razwi_info_available: indicates that a RAZWI info is now available.
  */
 struct razwi_info {
        struct hl_info_razwi_event      razwi;
        atomic_t                        razwi_detected;
+       bool                            razwi_info_available;
 };
 
 /**
index d7fe0af33bca3a77cea9109db1c574603dd6cc1d..03dae57dc8386306b3ea34b7ec8337f87c365cba 100644 (file)
@@ -225,6 +225,8 @@ int hl_device_open(struct inode *inode, struct file *filp)
        atomic_set(&hdev->captured_err_info.razwi_info.razwi_detected, 0);
        atomic_set(&hdev->captured_err_info.page_fault_info.page_fault_detected, 0);
        hdev->captured_err_info.undef_opcode.write_enable = true;
+       hdev->captured_err_info.razwi_info.razwi_info_available = false;
+       hdev->captured_err_info.page_fault_info.page_fault_info_available = false;
 
        hdev->open_counter++;
        hdev->last_successful_open_jif = jiffies;
index 949d3852716030518b1cfbf250ada0c9d68eee0d..72493bf94ba317f5a6e61398a8d3923b53836390 100644 (file)
@@ -607,16 +607,20 @@ static int cs_timeout_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
 
 static int razwi_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
 {
+       void __user *out = (void __user *) (uintptr_t) args->return_pointer;
        struct hl_device *hdev = hpriv->hdev;
        u32 max_size = args->return_size;
-       struct hl_info_razwi_event *info = &hdev->captured_err_info.razwi_info.razwi;
-       void __user *out = (void __user *) (uintptr_t) args->return_pointer;
+       struct razwi_info *razwi_info;
 
        if ((!max_size) || (!out))
                return -EINVAL;
 
-       return copy_to_user(out, info, min_t(size_t, max_size, sizeof(struct hl_info_razwi_event)))
-                               ? -EFAULT : 0;
+       razwi_info = &hdev->captured_err_info.razwi_info;
+       if (!razwi_info->razwi_info_available)
+               return 0;
+
+       return copy_to_user(out, &razwi_info->razwi,
+                       min_t(size_t, max_size, sizeof(struct hl_info_razwi_event))) ? -EFAULT : 0;
 }
 
 static int undefined_opcode_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
@@ -786,16 +790,20 @@ static int engine_status_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
 
 static int page_fault_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
 {
+       void __user *out = (void __user *) (uintptr_t) args->return_pointer;
        struct hl_device *hdev = hpriv->hdev;
        u32 max_size = args->return_size;
-       struct hl_page_fault_info *info = &hdev->captured_err_info.page_fault_info.page_fault;
-       void __user *out = (void __user *) (uintptr_t) args->return_pointer;
+       struct page_fault_info *pgf_info;
 
        if ((!max_size) || (!out))
                return -EINVAL;
 
-       return copy_to_user(out, info, min_t(size_t, max_size, sizeof(struct hl_page_fault_info)))
-                               ? -EFAULT : 0;
+       pgf_info = &hdev->captured_err_info.page_fault_info;
+       if (!pgf_info->page_fault_info_available)
+               return 0;
+
+       return copy_to_user(out, &pgf_info->page_fault,
+                       min_t(size_t, max_size, sizeof(struct hl_page_fault_info))) ? -EFAULT : 0;
 }
 
 static int user_mappings_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
@@ -806,18 +814,20 @@ static int user_mappings_info(struct hl_fpriv *hpriv, struct hl_info_args *args)
        struct page_fault_info *pgf_info;
        u64 actual_size;
 
-       pgf_info = &hdev->captured_err_info.page_fault_info;
-       args->array_size = pgf_info->num_of_user_mappings;
-
        if (!out)
                return -EINVAL;
 
+       pgf_info = &hdev->captured_err_info.page_fault_info;
+       if (!pgf_info->page_fault_info_available)
+               return 0;
+
+       args->array_size = pgf_info->num_of_user_mappings;
+
        actual_size = pgf_info->num_of_user_mappings * sizeof(struct hl_user_mapping);
        if (user_buf_size < actual_size)
                return -ENOMEM;
 
-       return copy_to_user(out, pgf_info->user_mappings, min_t(size_t, user_buf_size, actual_size))
-                               ? -EFAULT : 0;
+       return copy_to_user(out, pgf_info->user_mappings, actual_size) ? -EFAULT : 0;
 }
 
 static int send_fw_generic_request(struct hl_device *hdev, struct hl_info_args *info_args)