wifi: cfg80211: refactor RNR parsing
authorJohannes Berg <johannes.berg@intel.com>
Fri, 16 Feb 2024 11:54:30 +0000 (13:54 +0200)
committerJohannes Berg <johannes.berg@intel.com>
Wed, 21 Feb 2024 14:19:04 +0000 (15:19 +0100)
We'll need more parsing of the reduced neighbor report element,
and we already have two places doing pretty much the same.
Combine by refactoring the parsing into a separate function
with a callback for each item found.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Reviewed-by: Benjamin Berg <benjamin.berg@intel.com>
Signed-off-by: Miri Korenblit <miriam.rachel.korenblit@intel.com>
Link: https://msgid.link/20240216135047.cfff14b692fc.Ibe25be88a769eab29ebb17b9d19af666df6a2227@changeid
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
net/wireless/scan.c

index c2d85fa4b75d5dedfd73318cad10177903e8ae75..e46dfc71c4970981c7f417d61292625c6fcf3cf3 100644 (file)
@@ -611,104 +611,144 @@ static int cfg80211_parse_ap_info(struct cfg80211_colocated_ap *entry,
        return 0;
 }
 
-VISIBLE_IF_CFG80211_KUNIT int
-cfg80211_parse_colocated_ap(const struct cfg80211_bss_ies *ies,
-                           struct list_head *list)
+enum cfg80211_rnr_iter_ret {
+       RNR_ITER_CONTINUE,
+       RNR_ITER_BREAK,
+       RNR_ITER_ERROR,
+};
+
+static bool
+cfg80211_iter_rnr(const u8 *elems, size_t elems_len,
+                 enum cfg80211_rnr_iter_ret
+                 (*iter)(void *data, u8 type,
+                         const struct ieee80211_neighbor_ap_info *info,
+                         const u8 *tbtt_info, u8 tbtt_info_len),
+                 void *iter_data)
 {
-       struct ieee80211_neighbor_ap_info *ap_info;
-       const struct element *elem, *ssid_elem;
+       const struct element *rnr;
        const u8 *pos, *end;
-       u32 s_ssid_tmp;
-       int n_coloc = 0, ret;
-       LIST_HEAD(ap_list);
 
-       ret = cfg80211_calc_short_ssid(ies, &ssid_elem, &s_ssid_tmp);
-       if (ret)
-               return 0;
+       for_each_element_id(rnr, WLAN_EID_REDUCED_NEIGHBOR_REPORT,
+                           elems, elems_len) {
+               const struct ieee80211_neighbor_ap_info *info;
 
-       for_each_element_id(elem, WLAN_EID_REDUCED_NEIGHBOR_REPORT,
-                           ies->data, ies->len) {
-               pos = elem->data;
-               end = elem->data + elem->datalen;
+               pos = rnr->data;
+               end = rnr->data + rnr->datalen;
 
                /* RNR IE may contain more than one NEIGHBOR_AP_INFO */
-               while (pos + sizeof(*ap_info) <= end) {
-                       enum nl80211_band band;
-                       int freq;
+               while (sizeof(*info) <= end - pos) {
                        u8 length, i, count;
+                       u8 type;
 
-                       ap_info = (void *)pos;
-                       count = u8_get_bits(ap_info->tbtt_info_hdr,
-                                           IEEE80211_AP_INFO_TBTT_HDR_COUNT) + 1;
-                       length = ap_info->tbtt_info_len;
+                       info = (void *)pos;
+                       count = u8_get_bits(info->tbtt_info_hdr,
+                                           IEEE80211_AP_INFO_TBTT_HDR_COUNT) +
+                               1;
+                       length = info->tbtt_info_len;
 
-                       pos += sizeof(*ap_info);
+                       pos += sizeof(*info);
 
-                       if (!ieee80211_operating_class_to_band(ap_info->op_class,
-                                                              &band))
-                               break;
+                       if (count * length > end - pos)
+                               return false;
 
-                       freq = ieee80211_channel_to_frequency(ap_info->channel,
-                                                             band);
+                       type = u8_get_bits(info->tbtt_info_hdr,
+                                          IEEE80211_AP_INFO_TBTT_HDR_TYPE);
 
-                       if (end - pos < count * length)
-                               break;
+                       for (i = 0; i < count; i++) {
+                               switch (iter(iter_data, type, info,
+                                            pos, length)) {
+                               case RNR_ITER_CONTINUE:
+                                       break;
+                               case RNR_ITER_BREAK:
+                                       return true;
+                               case RNR_ITER_ERROR:
+                                       return false;
+                               }
 
-                       if (u8_get_bits(ap_info->tbtt_info_hdr,
-                                       IEEE80211_AP_INFO_TBTT_HDR_TYPE) !=
-                           IEEE80211_TBTT_INFO_TYPE_TBTT) {
-                               pos += count * length;
-                               continue;
+                               pos += length;
                        }
+               }
 
-                       /* TBTT info must include bss param + BSSID +
-                        * (short SSID or same_ssid bit to be set).
-                        * ignore other options, and move to the
-                        * next AP info
-                        */
-                       if (band != NL80211_BAND_6GHZ ||
-                           !(length == offsetofend(struct ieee80211_tbtt_info_7_8_9,
-                                                   bss_params) ||
-                             length == sizeof(struct ieee80211_tbtt_info_7_8_9) ||
-                             length >= offsetofend(struct ieee80211_tbtt_info_ge_11,
-                                                   bss_params))) {
-                               pos += count * length;
-                               continue;
-                       }
+               if (pos != end)
+                       return false;
+       }
 
-                       for (i = 0; i < count; i++) {
-                               struct cfg80211_colocated_ap *entry;
+       return true;
+}
+
+struct colocated_ap_data {
+       const struct element *ssid_elem;
+       struct list_head ap_list;
+       u32 s_ssid_tmp;
+       int n_coloc;
+};
 
-                               entry = kzalloc(sizeof(*entry) + IEEE80211_MAX_SSID_LEN,
-                                               GFP_ATOMIC);
+static enum cfg80211_rnr_iter_ret
+cfg80211_parse_colocated_ap_iter(void *_data, u8 type,
+                                const struct ieee80211_neighbor_ap_info *info,
+                                const u8 *tbtt_info, u8 tbtt_info_len)
+{
+       struct colocated_ap_data *data = _data;
+       struct cfg80211_colocated_ap *entry;
+       enum nl80211_band band;
 
-                               if (!entry)
-                                       goto error;
+       if (type != IEEE80211_TBTT_INFO_TYPE_TBTT)
+               return RNR_ITER_CONTINUE;
 
-                               entry->center_freq = freq;
+       if (!ieee80211_operating_class_to_band(info->op_class, &band))
+               return RNR_ITER_CONTINUE;
 
-                               if (!cfg80211_parse_ap_info(entry, pos, length,
-                                                           ssid_elem,
-                                                           s_ssid_tmp)) {
-                                       n_coloc++;
-                                       list_add_tail(&entry->list, &ap_list);
-                               } else {
-                                       kfree(entry);
-                               }
+       /* TBTT info must include bss param + BSSID + (short SSID or
+        * same_ssid bit to be set). Ignore other options, and move to
+        * the next AP info
+        */
+       if (band != NL80211_BAND_6GHZ ||
+           !(tbtt_info_len == offsetofend(struct ieee80211_tbtt_info_7_8_9,
+                                          bss_params) ||
+             tbtt_info_len == sizeof(struct ieee80211_tbtt_info_7_8_9) ||
+             tbtt_info_len >= offsetofend(struct ieee80211_tbtt_info_ge_11,
+                                          bss_params)))
+               return RNR_ITER_CONTINUE;
+
+       entry = kzalloc(sizeof(*entry) + IEEE80211_MAX_SSID_LEN, GFP_ATOMIC);
+       if (!entry)
+               return RNR_ITER_ERROR;
+
+       entry->center_freq =
+               ieee80211_channel_to_frequency(info->channel, band);
+
+       if (!cfg80211_parse_ap_info(entry, tbtt_info, tbtt_info_len,
+                                   data->ssid_elem, data->s_ssid_tmp)) {
+               data->n_coloc++;
+               list_add_tail(&entry->list, &data->ap_list);
+       } else {
+               kfree(entry);
+       }
 
-                               pos += length;
-                       }
-               }
+       return RNR_ITER_CONTINUE;
+}
 
-error:
-               if (pos != end) {
-                       cfg80211_free_coloc_ap_list(&ap_list);
-                       return 0;
-               }
+VISIBLE_IF_CFG80211_KUNIT int
+cfg80211_parse_colocated_ap(const struct cfg80211_bss_ies *ies,
+                           struct list_head *list)
+{
+       struct colocated_ap_data data = {};
+       int ret;
+
+       INIT_LIST_HEAD(&data.ap_list);
+
+       ret = cfg80211_calc_short_ssid(ies, &data.ssid_elem, &data.s_ssid_tmp);
+       if (ret)
+               return 0;
+
+       if (!cfg80211_iter_rnr(ies->data, ies->len,
+                              cfg80211_parse_colocated_ap_iter, &data)) {
+               cfg80211_free_coloc_ap_list(&data.ap_list);
+               return 0;
        }
 
-       list_splice_tail(&ap_list, list);
-       return n_coloc;
+       list_splice_tail(&data.ap_list, list);
+       return data.n_coloc;
 }
 EXPORT_SYMBOL_IF_CFG80211_KUNIT(cfg80211_parse_colocated_ap);
 
@@ -2607,79 +2647,71 @@ error:
        return NULL;
 }
 
-static u8
-cfg80211_rnr_info_for_mld_ap(const u8 *ie, size_t ielen, u8 mld_id, u8 link_id,
-                            const struct ieee80211_neighbor_ap_info **ap_info,
-                            u8 *param_ch_count)
-{
-       const struct ieee80211_neighbor_ap_info *info;
-       const struct element *rnr;
-       const u8 *pos, *end;
-
-       for_each_element_id(rnr, WLAN_EID_REDUCED_NEIGHBOR_REPORT, ie, ielen) {
-               pos = rnr->data;
-               end = rnr->data + rnr->datalen;
-
-               /* RNR IE may contain more than one NEIGHBOR_AP_INFO */
-               while (sizeof(*info) <= end - pos) {
-                       const struct ieee80211_rnr_mld_params *mld_params;
-                       u16 params;
-                       u8 length, i, count, mld_params_offset;
-                       u8 type, lid;
-                       u32 use_for;
-
-                       info = (void *)pos;
-                       count = u8_get_bits(info->tbtt_info_hdr,
-                                           IEEE80211_AP_INFO_TBTT_HDR_COUNT) + 1;
-                       length = info->tbtt_info_len;
+struct tbtt_info_iter_data {
+       const struct ieee80211_neighbor_ap_info *ap_info;
+       u8 param_ch_count;
+       u32 use_for;
+       u8 mld_id, link_id;
+};
 
-                       pos += sizeof(*info);
+static enum cfg80211_rnr_iter_ret
+cfg802121_mld_ap_rnr_iter(void *_data, u8 type,
+                         const struct ieee80211_neighbor_ap_info *info,
+                         const u8 *tbtt_info, u8 tbtt_info_len)
+{
+       const struct ieee80211_rnr_mld_params *mld_params;
+       struct tbtt_info_iter_data *data = _data;
+       u8 link_id;
+
+       if (type == IEEE80211_TBTT_INFO_TYPE_TBTT &&
+           tbtt_info_len >= offsetofend(struct ieee80211_tbtt_info_ge_11,
+                                        mld_params))
+               mld_params = (void *)(tbtt_info +
+                                     offsetof(struct ieee80211_tbtt_info_ge_11,
+                                              mld_params));
+       else if (type == IEEE80211_TBTT_INFO_TYPE_MLD &&
+                tbtt_info_len >= sizeof(struct ieee80211_rnr_mld_params))
+               mld_params = (void *)tbtt_info;
+       else
+               return RNR_ITER_CONTINUE;
 
-                       if (count * length > end - pos)
-                               return 0;
+       link_id = le16_get_bits(mld_params->params,
+                               IEEE80211_RNR_MLD_PARAMS_LINK_ID);
 
-                       type = u8_get_bits(info->tbtt_info_hdr,
-                                          IEEE80211_AP_INFO_TBTT_HDR_TYPE);
+       if (data->mld_id != mld_params->mld_id)
+               return RNR_ITER_CONTINUE;
 
-                       if (type == IEEE80211_TBTT_INFO_TYPE_TBTT &&
-                           length >=
-                           offsetofend(struct ieee80211_tbtt_info_ge_11,
-                                       mld_params)) {
-                               mld_params_offset =
-                                       offsetof(struct ieee80211_tbtt_info_ge_11, mld_params);
-                               use_for = NL80211_BSS_USE_FOR_ALL;
-                       } else if (type == IEEE80211_TBTT_INFO_TYPE_MLD &&
-                                  length >= sizeof(struct ieee80211_rnr_mld_params)) {
-                               mld_params_offset = 0;
-                               use_for = NL80211_BSS_USE_FOR_MLD_LINK;
-                       } else {
-                               pos += count * length;
-                               continue;
-                       }
+       if (data->link_id != link_id)
+               return RNR_ITER_CONTINUE;
 
-                       for (i = 0; i < count; i++) {
-                               mld_params = (void *)pos + mld_params_offset;
-                               params = le16_to_cpu(mld_params->params);
+       data->ap_info = info;
+       data->param_ch_count =
+               le16_get_bits(mld_params->params,
+                             IEEE80211_RNR_MLD_PARAMS_BSS_CHANGE_COUNT);
 
-                               lid = u16_get_bits(params,
-                                                  IEEE80211_RNR_MLD_PARAMS_LINK_ID);
+       if (type == IEEE80211_TBTT_INFO_TYPE_TBTT)
+               data->use_for = NL80211_BSS_USE_FOR_ALL;
+       else
+               data->use_for = NL80211_BSS_USE_FOR_MLD_LINK;
+       return RNR_ITER_BREAK;
+}
 
-                               if (mld_id == mld_params->mld_id &&
-                                   link_id == lid) {
-                                       *ap_info = info;
-                                       *param_ch_count =
-                                               le16_get_bits(mld_params->params,
-                                                             IEEE80211_RNR_MLD_PARAMS_BSS_CHANGE_COUNT);
+static u8
+cfg80211_rnr_info_for_mld_ap(const u8 *ie, size_t ielen, u8 mld_id, u8 link_id,
+                            const struct ieee80211_neighbor_ap_info **ap_info,
+                            u8 *param_ch_count)
+{
+       struct tbtt_info_iter_data data = {
+               .mld_id = mld_id,
+               .link_id = link_id,
+       };
 
-                                       return use_for;
-                               }
+       cfg80211_iter_rnr(ie, ielen, cfg802121_mld_ap_rnr_iter, &data);
 
-                               pos += length;
-                       }
-               }
-       }
+       *ap_info = data.ap_info;
+       *param_ch_count = data.param_ch_count;
 
-       return 0;
+       return data.use_for;
 }
 
 static struct element *