* Copyright 2008 Johannes Berg <johannes@sipsolutions.net>
  * Copyright 2013-2014  Intel Mobile Communications GmbH
  * Copyright 2016      Intel Deutschland GmbH
- * Copyright (C) 2018-2019 Intel Corporation
+ * Copyright (C) 2018-2020 Intel Corporation
  */
 #include <linux/kernel.h>
 #include <linux/slab.h>
 #include <linux/wireless.h>
 #include <linux/nl80211.h>
 #include <linux/etherdevice.h>
+#include <linux/crc32.h>
+#include <linux/bitfield.h>
 #include <net/arp.h>
 #include <net/cfg80211.h>
 #include <net/cfg80211-wext.h>
 
 #define IEEE80211_SCAN_RESULT_EXPIRE   (30 * HZ)
 
+/**
+ * struct cfg80211_colocated_ap - colocated AP information
+ *
+ * @list: linked list to all colocated aPS
+ * @bssid: BSSID of the reported AP
+ * @ssid: SSID of the reported AP
+ * @ssid_len: length of the ssid
+ * @center_freq: frequency the reported AP is on
+ * @unsolicited_probe: the reported AP is part of an ESS, where all the APs
+ *     that operate in the same channel as the reported AP and that might be
+ *     detected by a STA receiving this frame, are transmitting unsolicited
+ *     Probe Response frames every 20 TUs
+ * @oct_recommended: OCT is recommended to exchange MMPDUs with the reported AP
+ * @same_ssid: the reported AP has the same SSID as the reporting AP
+ * @multi_bss: the reported AP is part of a multiple BSSID set
+ * @transmitted_bssid: the reported AP is the transmitting BSSID
+ * @colocated_ess: all the APs that share the same ESS as the reported AP are
+ *     colocated and can be discovered via legacy bands.
+ * @short_ssid_valid: short_ssid is valid and can be used
+ * @short_ssid: the short SSID for this SSID
+ */
+struct cfg80211_colocated_ap {
+       struct list_head list;
+       u8 bssid[ETH_ALEN];
+       u8 ssid[IEEE80211_MAX_SSID_LEN];
+       size_t ssid_len;
+       u32 short_ssid;
+       u32 center_freq;
+       u8 unsolicited_probe:1,
+          oct_recommended:1,
+          same_ssid:1,
+          multi_bss:1,
+          transmitted_bssid:1,
+          colocated_ess:1,
+          short_ssid_valid:1;
+};
+
 static void bss_free(struct cfg80211_internal_bss *bss)
 {
        struct cfg80211_bss_ies *ies;
        return ret;
 }
 
+static u8 cfg80211_parse_bss_param(u8 data,
+                                  struct cfg80211_colocated_ap *coloc_ap)
+{
+       coloc_ap->oct_recommended =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_OCT_RECOMMENDED);
+       coloc_ap->same_ssid =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_SAME_SSID);
+       coloc_ap->multi_bss =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_MULTI_BSSID);
+       coloc_ap->transmitted_bssid =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_TRANSMITTED_BSSID);
+       coloc_ap->unsolicited_probe =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_PROBE_ACTIVE);
+       coloc_ap->colocated_ess =
+               u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_COLOC_ESS);
+
+       return u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_COLOC_AP);
+}
+
+static int cfg80211_calc_short_ssid(const struct cfg80211_bss_ies *ies,
+                                   const struct element **elem, u32 *s_ssid)
+{
+
+       *elem = cfg80211_find_elem(WLAN_EID_SSID, ies->data, ies->len);
+       if (!*elem || (*elem)->datalen > IEEE80211_MAX_SSID_LEN)
+               return -EINVAL;
+
+       *s_ssid = ~crc32_le(~0, (*elem)->data, (*elem)->datalen);
+       return 0;
+}
+
+static void cfg80211_free_coloc_ap_list(struct list_head *coloc_ap_list)
+{
+       struct cfg80211_colocated_ap *ap, *tmp_ap;
+
+       list_for_each_entry_safe(ap, tmp_ap, coloc_ap_list, list) {
+               list_del(&ap->list);
+               kfree(ap);
+       }
+}
+
+static int cfg80211_parse_ap_info(struct cfg80211_colocated_ap *entry,
+                                 const u8 *pos, u8 length,
+                                 const struct element *ssid_elem,
+                                 int s_ssid_tmp)
+{
+       /* skip the TBTT offset */
+       pos++;
+
+       memcpy(entry->bssid, pos, ETH_ALEN);
+       pos += ETH_ALEN;
+
+       if (length == IEEE80211_TBTT_INFO_OFFSET_BSSID_SSSID_BSS_PARAM) {
+               memcpy(&entry->short_ssid, pos,
+                      sizeof(entry->short_ssid));
+               entry->short_ssid_valid = true;
+               pos += 4;
+       }
+
+       /* skip non colocated APs */
+       if (!cfg80211_parse_bss_param(*pos, entry))
+               return -EINVAL;
+       pos++;
+
+       if (length == IEEE80211_TBTT_INFO_OFFSET_BSSID_BSS_PARAM) {
+               /*
+                * no information about the short ssid. Consider the entry valid
+                * for now. It would later be dropped in case there are explicit
+                * SSIDs that need to be matched
+                */
+               if (!entry->same_ssid)
+                       return 0;
+       }
+
+       if (entry->same_ssid) {
+               entry->short_ssid = s_ssid_tmp;
+               entry->short_ssid_valid = true;
+
+               /*
+                * This is safe because we validate datalen in
+                * cfg80211_parse_colocated_ap(), before calling this
+                * function.
+                */
+               memcpy(&entry->ssid, &ssid_elem->data,
+                      ssid_elem->datalen);
+               entry->ssid_len = ssid_elem->datalen;
+       }
+       return 0;
+}
+
+static int cfg80211_parse_colocated_ap(const struct cfg80211_bss_ies *ies,
+                                      struct list_head *list)
+{
+       struct ieee80211_neighbor_ap_info *ap_info;
+       const struct element *elem, *ssid_elem;
+       const u8 *pos, *end;
+       u32 s_ssid_tmp;
+       int n_coloc = 0, ret;
+       LIST_HEAD(ap_list);
+
+       elem = cfg80211_find_elem(WLAN_EID_REDUCED_NEIGHBOR_REPORT, ies->data,
+                                 ies->len);
+       if (!elem || elem->datalen > IEEE80211_MAX_SSID_LEN)
+               return 0;
+
+       pos = elem->data;
+       end = pos + elem->datalen;
+
+       ret = cfg80211_calc_short_ssid(ies, &ssid_elem, &s_ssid_tmp);
+       if (ret)
+               return ret;
+
+       /* RNR IE may contain more than one NEIGHBOR_AP_INFO */
+       while (pos + sizeof(*ap_info) <= end) {
+               enum nl80211_band band;
+               int freq;
+               u8 length, i, count;
+
+               ap_info = (void *)pos;
+               count = u8_get_bits(ap_info->tbtt_info_hdr,
+                                   IEEE80211_AP_INFO_TBTT_HDR_COUNT) + 1;
+               length = ap_info->tbtt_info_len;
+
+               pos += sizeof(*ap_info);
+
+               if (!ieee80211_operating_class_to_band(ap_info->op_class,
+                                                      &band))
+                       break;
+
+               freq = ieee80211_channel_to_frequency(ap_info->channel, band);
+
+               if (end - pos < count * ap_info->tbtt_info_len)
+                       break;
+
+               /*
+                * TBTT info must include bss param + BSSID +
+                * (short SSID or same_ssid bit to be set).
+                * ignore other options, and move to the
+                * next AP info
+                */
+               if (band != NL80211_BAND_6GHZ ||
+                   (length != IEEE80211_TBTT_INFO_OFFSET_BSSID_BSS_PARAM &&
+                    length < IEEE80211_TBTT_INFO_OFFSET_BSSID_SSSID_BSS_PARAM)) {
+                       pos += count * ap_info->tbtt_info_len;
+                       continue;
+               }
+
+               for (i = 0; i < count; i++) {
+                       struct cfg80211_colocated_ap *entry;
+
+                       entry = kzalloc(sizeof(*entry) + IEEE80211_MAX_SSID_LEN,
+                                       GFP_ATOMIC);
+
+                       if (!entry)
+                               break;
+
+                       entry->center_freq = freq;
+
+                       if (!cfg80211_parse_ap_info(entry, pos, length,
+                                                   ssid_elem, s_ssid_tmp)) {
+                               n_coloc++;
+                               list_add_tail(&entry->list, &ap_list);
+                       } else {
+                               kfree(entry);
+                       }
+
+                       pos += ap_info->tbtt_info_len;
+               }
+       }
+
+       if (pos != end) {
+               cfg80211_free_coloc_ap_list(&ap_list);
+               return 0;
+       }
+
+       list_splice_tail(&ap_list, list);
+       return n_coloc;
+}
+
+static  void cfg80211_scan_req_add_chan(struct cfg80211_scan_request *request,
+                                       struct ieee80211_channel *chan,
+                                       bool add_to_6ghz)
+{
+       int i;
+       u32 n_channels = request->n_channels;
+       struct cfg80211_scan_6ghz_params *params =
+               &request->scan_6ghz_params[request->n_6ghz_params];
+
+       for (i = 0; i < n_channels; i++) {
+               if (request->channels[i] == chan) {
+                       if (add_to_6ghz)
+                               params->channel_idx = i;
+                       return;
+               }
+       }
+
+       request->channels[n_channels] = chan;
+       if (add_to_6ghz)
+               request->scan_6ghz_params[request->n_6ghz_params].channel_idx =
+                       n_channels;
+
+       request->n_channels++;
+}
+
+static bool cfg80211_find_ssid_match(struct cfg80211_colocated_ap *ap,
+                                    struct cfg80211_scan_request *request)
+{
+       u8 i;
+       u32 s_ssid;
+
+       for (i = 0; i < request->n_ssids; i++) {
+               /* wildcard ssid in the scan request */
+               if (!request->ssids[i].ssid_len)
+                       return true;
+
+               if (ap->ssid_len &&
+                   ap->ssid_len == request->ssids[i].ssid_len) {
+                       if (!memcmp(request->ssids[i].ssid, ap->ssid,
+                                   ap->ssid_len))
+                               return true;
+               } else if (ap->short_ssid_valid) {
+                       s_ssid = ~crc32_le(~0, request->ssids[i].ssid,
+                                          request->ssids[i].ssid_len);
+
+                       if (ap->short_ssid == s_ssid)
+                               return true;
+               }
+       }
+
+       return false;
+}
+
+static int cfg80211_scan_6ghz(struct cfg80211_registered_device *rdev)
+{
+       u8 i;
+       struct cfg80211_colocated_ap *ap;
+       int n_channels, count = 0, err;
+       struct cfg80211_scan_request *request, *rdev_req = rdev->scan_req;
+       LIST_HEAD(coloc_ap_list);
+       bool need_scan_psc;
+       const struct ieee80211_sband_iftype_data *iftd;
+
+       rdev_req->scan_6ghz = true;
+
+       if (!rdev->wiphy.bands[NL80211_BAND_6GHZ])
+               return -EOPNOTSUPP;
+
+       iftd = ieee80211_get_sband_iftype_data(rdev->wiphy.bands[NL80211_BAND_6GHZ],
+                                              rdev_req->wdev->iftype);
+       if (!iftd || !iftd->he_cap.has_he)
+               return -EOPNOTSUPP;
+
+       n_channels = rdev->wiphy.bands[NL80211_BAND_6GHZ]->n_channels;
+
+       if (rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ) {
+               struct cfg80211_internal_bss *intbss;
+
+               spin_lock_bh(&rdev->bss_lock);
+               list_for_each_entry(intbss, &rdev->bss_list, list) {
+                       struct cfg80211_bss *res = &intbss->pub;
+                       const struct cfg80211_bss_ies *ies;
+
+                       ies = rcu_access_pointer(res->ies);
+                       count += cfg80211_parse_colocated_ap(ies,
+                                                            &coloc_ap_list);
+               }
+               spin_unlock_bh(&rdev->bss_lock);
+       }
+
+       request = kzalloc(struct_size(request, channels, n_channels) +
+                         sizeof(*request->scan_6ghz_params) * count,
+                         GFP_KERNEL);
+       if (!request) {
+               cfg80211_free_coloc_ap_list(&coloc_ap_list);
+               return -ENOMEM;
+       }
+
+       *request = *rdev_req;
+       request->n_channels = 0;
+       request->scan_6ghz_params =
+               (void *)&request->channels[n_channels];
+
+       /*
+        * PSC channels should not be scanned if all the reported co-located APs
+        * are indicating that all APs in the same ESS are co-located
+        */
+       if (count) {
+               need_scan_psc = false;
+
+               list_for_each_entry(ap, &coloc_ap_list, list) {
+                       if (!ap->colocated_ess) {
+                               need_scan_psc = true;
+                               break;
+                       }
+               }
+       } else {
+               need_scan_psc = true;
+       }
+
+       /*
+        * add to the scan request the channels that need to be scanned
+        * regardless of the collocated APs (PSC channels or all channels
+        * in case that NL80211_SCAN_FLAG_COLOCATED_6GHZ is not set)
+        */
+       for (i = 0; i < rdev_req->n_channels; i++) {
+               if (rdev_req->channels[i]->band == NL80211_BAND_6GHZ &&
+                   ((need_scan_psc &&
+                     cfg80211_channel_is_psc(rdev_req->channels[i])) ||
+                    !(rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ))) {
+                       cfg80211_scan_req_add_chan(request,
+                                                  rdev_req->channels[i],
+                                                  false);
+               }
+       }
+
+       if (!(rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ))
+               goto skip;
+
+       list_for_each_entry(ap, &coloc_ap_list, list) {
+               bool found = false;
+               struct cfg80211_scan_6ghz_params *scan_6ghz_params =
+                       &request->scan_6ghz_params[request->n_6ghz_params];
+               struct ieee80211_channel *chan =
+                       ieee80211_get_channel(&rdev->wiphy, ap->center_freq);
+
+               if (!chan || chan->flags & IEEE80211_CHAN_DISABLED)
+                       continue;
+
+               for (i = 0; i < rdev_req->n_channels; i++) {
+                       if (rdev_req->channels[i] == chan)
+                               found = true;
+               }
+
+               if (!found)
+                       continue;
+
+               if (request->n_ssids > 0 &&
+                   !cfg80211_find_ssid_match(ap, request))
+                       continue;
+
+               cfg80211_scan_req_add_chan(request, chan, true);
+               memcpy(scan_6ghz_params->bssid, ap->bssid, ETH_ALEN);
+               scan_6ghz_params->short_ssid = ap->short_ssid;
+               scan_6ghz_params->short_ssid_valid = ap->short_ssid_valid;
+               scan_6ghz_params->unsolicited_probe = ap->unsolicited_probe;
+
+               /*
+                * If a PSC channel is added to the scan and 'need_scan_psc' is
+                * set to false, then all the APs that the scan logic is
+                * interested with on the channel are collocated and thus there
+                * is no need to perform the initial PSC channel listen.
+                */
+               if (cfg80211_channel_is_psc(chan) && !need_scan_psc)
+                       scan_6ghz_params->psc_no_listen = true;
+
+               request->n_6ghz_params++;
+       }
+
+skip:
+       cfg80211_free_coloc_ap_list(&coloc_ap_list);
+
+       if (request->n_channels) {
+               struct cfg80211_scan_request *old = rdev->int_scan_req;
+
+               rdev->int_scan_req = request;
+
+               /*
+                * If this scan follows a previous scan, save the scan start
+                * info from the first part of the scan
+                */
+               if (old)
+                       rdev->int_scan_req->info = old->info;
+
+               err = rdev_scan(rdev, request);
+               if (err) {
+                       rdev->int_scan_req = old;
+                       kfree(request);
+               } else {
+                       kfree(old);
+               }
+
+               return err;
+       }
+
+       kfree(request);
+       return -EINVAL;
+}
+
+int cfg80211_scan(struct cfg80211_registered_device *rdev)
+{
+       struct cfg80211_scan_request *request;
+       struct cfg80211_scan_request *rdev_req = rdev->scan_req;
+       u32 n_channels = 0, idx, i;
+
+       if (!(rdev->wiphy.flags & WIPHY_FLAG_SPLIT_SCAN_6GHZ))
+               return rdev_scan(rdev, rdev_req);
+
+       for (i = 0; i < rdev_req->n_channels; i++) {
+               if (rdev_req->channels[i]->band != NL80211_BAND_6GHZ)
+                       n_channels++;
+       }
+
+       if (!n_channels)
+               return cfg80211_scan_6ghz(rdev);
+
+       request = kzalloc(struct_size(request, channels, n_channels),
+                         GFP_KERNEL);
+       if (!request)
+               return -ENOMEM;
+
+       *request = *rdev_req;
+       request->n_channels = n_channels;
+
+       for (i = idx = 0; i < rdev_req->n_channels; i++) {
+               if (rdev_req->channels[i]->band != NL80211_BAND_6GHZ)
+                       request->channels[idx++] = rdev_req->channels[i];
+       }
+
+       rdev_req->scan_6ghz = false;
+       rdev->int_scan_req = request;
+       return rdev_scan(rdev, request);
+}
+
 void ___cfg80211_scan_done(struct cfg80211_registered_device *rdev,
                           bool send_message)
 {
-       struct cfg80211_scan_request *request;
+       struct cfg80211_scan_request *request, *rdev_req;
        struct wireless_dev *wdev;
        struct sk_buff *msg;
 #ifdef CONFIG_CFG80211_WEXT
                return;
        }
 
-       request = rdev->scan_req;
-       if (!request)
+       rdev_req = rdev->scan_req;
+       if (!rdev_req)
                return;
 
-       wdev = request->wdev;
+       wdev = rdev_req->wdev;
+       request = rdev->int_scan_req ? rdev->int_scan_req : rdev_req;
+
+       if (wdev_running(wdev) &&
+           (rdev->wiphy.flags & WIPHY_FLAG_SPLIT_SCAN_6GHZ) &&
+           !rdev_req->scan_6ghz && !request->info.aborted &&
+           !cfg80211_scan_6ghz(rdev))
+               return;
 
        /*
         * This must be before sending the other events!
        if (wdev->netdev)
                dev_put(wdev->netdev);
 
+       kfree(rdev->int_scan_req);
+       rdev->int_scan_req = NULL;
+
+       kfree(rdev->scan_req);
        rdev->scan_req = NULL;
-       kfree(request);
 
        if (!send_message)
                rdev->scan_msg = msg;
 void cfg80211_scan_done(struct cfg80211_scan_request *request,
                        struct cfg80211_scan_info *info)
 {
+       struct cfg80211_scan_info old_info = request->info;
+
        trace_cfg80211_scan_done(request, info);
-       WARN_ON(request != wiphy_to_rdev(request->wiphy)->scan_req);
+       WARN_ON(request != wiphy_to_rdev(request->wiphy)->scan_req &&
+               request != wiphy_to_rdev(request->wiphy)->int_scan_req);
 
        request->info = *info;
+
+       /*
+        * In case the scan is split, the scan_start_tsf and tsf_bssid should
+        * be of the first part. In such a case old_info.scan_start_tsf should
+        * be non zero.
+        */
+       if (request->scan_6ghz && old_info.scan_start_tsf) {
+               request->info.scan_start_tsf = old_info.scan_start_tsf;
+               memcpy(request->info.tsf_bssid, old_info.tsf_bssid,
+                      sizeof(request->info.tsf_bssid));
+       }
+
        request->notified = true;
        queue_work(cfg80211_wq, &wiphy_to_rdev(request->wiphy)->scan_done_wk);
 }