From f36fe0a2df03209f4d681fa954f20bfa4eefec45 Mon Sep 17 00:00:00 2001
From: Johannes Berg <johannes.berg@intel.com>
Date: Thu, 14 Jul 2022 23:40:47 +0200
Subject: [PATCH] wifi: mac80211: fix up link station creation/insertion

When we create a station with a non-default link, then
we should have a link address, and we definitely need
to insert it into the link hash table on insertion.

Split the API into with and without link creation and
if it has a link, insert the link into the link hash
table on sta_info_insert().

Fixes: ba6ddab94fc6 ("wifi: mac80211: maintain link-sta hash table")
Signed-off-by: Johannes Berg <johannes.berg@intel.com>
---
 net/mac80211/cfg.c        | 10 ++++++--
 net/mac80211/ibss.c       |  4 ++--
 net/mac80211/mesh_plink.c |  2 +-
 net/mac80211/mlme.c       |  2 +-
 net/mac80211/ocb.c        |  2 +-
 net/mac80211/sta_info.c   | 49 +++++++++++++++++++++++++++++----------
 net/mac80211/sta_info.h   |  7 +++++-
 7 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c
index 1545648e2031e..1cacd1e0fc852 100644
--- a/net/mac80211/cfg.c
+++ b/net/mac80211/cfg.c
@@ -1850,8 +1850,14 @@ static int ieee80211_add_station(struct wiphy *wiphy, struct net_device *dev,
 	    !sdata->u.mgd.associated)
 		return -EINVAL;
 
-	sta = sta_info_alloc(sdata, mac, params->link_sta_params.link_id,
-			     GFP_KERNEL);
+	if (params->link_sta_params.link_id >= 0)
+		sta = sta_info_alloc_with_link(sdata, mac,
+					       params->link_sta_params.link_id,
+					       params->link_sta_params.link_mac,
+					       GFP_KERNEL);
+	else
+		sta = sta_info_alloc(sdata, mac, GFP_KERNEL);
+
 	if (!sta)
 		return -ENOMEM;
 
diff --git a/net/mac80211/ibss.c b/net/mac80211/ibss.c
index 60b5230778a3d..d56890e3fabb3 100644
--- a/net/mac80211/ibss.c
+++ b/net/mac80211/ibss.c
@@ -625,7 +625,7 @@ ieee80211_ibss_add_sta(struct ieee80211_sub_if_data *sdata, const u8 *bssid,
 	scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def);
 	rcu_read_unlock();
 
-	sta = sta_info_alloc(sdata, addr, -1, GFP_KERNEL);
+	sta = sta_info_alloc(sdata, addr, GFP_KERNEL);
 	if (!sta) {
 		rcu_read_lock();
 		return NULL;
@@ -1226,7 +1226,7 @@ void ieee80211_ibss_rx_no_sta(struct ieee80211_sub_if_data *sdata,
 	scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def);
 	rcu_read_unlock();
 
-	sta = sta_info_alloc(sdata, addr, -1, GFP_ATOMIC);
+	sta = sta_info_alloc(sdata, addr, GFP_ATOMIC);
 	if (!sta)
 		return;
 
diff --git a/net/mac80211/mesh_plink.c b/net/mac80211/mesh_plink.c
index 84e3f43fd5c66..ddfe5102b9a43 100644
--- a/net/mac80211/mesh_plink.c
+++ b/net/mac80211/mesh_plink.c
@@ -511,7 +511,7 @@ __mesh_sta_info_alloc(struct ieee80211_sub_if_data *sdata, u8 *hw_addr)
 	if (aid < 0)
 		return NULL;
 
-	sta = sta_info_alloc(sdata, hw_addr, -1, GFP_KERNEL);
+	sta = sta_info_alloc(sdata, hw_addr, GFP_KERNEL);
 	if (!sta)
 		return NULL;
 
diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c
index 3263bb1882841..c7fb12d6fb173 100644
--- a/net/mac80211/mlme.c
+++ b/net/mac80211/mlme.c
@@ -5909,7 +5909,7 @@ static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata,
 	}
 
 	if (!have_sta) {
-		new_sta = sta_info_alloc(sdata, cbss->bssid, -1, GFP_KERNEL);
+		new_sta = sta_info_alloc(sdata, cbss->bssid, GFP_KERNEL);
 		if (!new_sta)
 			return -ENOMEM;
 	}
diff --git a/net/mac80211/ocb.c b/net/mac80211/ocb.c
index 8664fee699e9a..a57dcbe99a0dc 100644
--- a/net/mac80211/ocb.c
+++ b/net/mac80211/ocb.c
@@ -69,7 +69,7 @@ void ieee80211_ocb_rx_no_sta(struct ieee80211_sub_if_data *sdata,
 	scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def);
 	rcu_read_unlock();
 
-	sta = sta_info_alloc(sdata, addr, -1, GFP_ATOMIC);
+	sta = sta_info_alloc(sdata, addr, GFP_ATOMIC);
 	if (!sta)
 		return;
 
diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c
index eed88630594f7..75122eced1044 100644
--- a/net/mac80211/sta_info.c
+++ b/net/mac80211/sta_info.c
@@ -96,6 +96,14 @@ static int sta_info_hash_del(struct ieee80211_local *local,
 			       sta_rht_params);
 }
 
+static int link_sta_info_hash_add(struct ieee80211_local *local,
+				  struct link_sta_info *link_sta)
+{
+	return rhltable_insert(&local->link_sta_hash,
+			       &link_sta->link_hash_node,
+			       link_sta_rht_params);
+}
+
 static int link_sta_info_hash_del(struct ieee80211_local *local,
 				  struct link_sta_info *link_sta)
 {
@@ -466,8 +474,10 @@ static void sta_info_add_link(struct sta_info *sta,
 	rcu_assign_pointer(sta->sta.link[link_id], link_sta);
 }
 
-struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
-				const u8 *addr, int link_id, gfp_t gfp)
+static struct sta_info *
+__sta_info_alloc(struct ieee80211_sub_if_data *sdata,
+		 const u8 *addr, int link_id, const u8 *link_addr,
+		 gfp_t gfp)
 {
 	struct ieee80211_local *local = sdata->local;
 	struct ieee80211_hw *hw = &local->hw;
@@ -513,8 +523,8 @@ struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
 
 	memcpy(sta->addr, addr, ETH_ALEN);
 	memcpy(sta->sta.addr, addr, ETH_ALEN);
-	memcpy(sta->deflink.addr, addr, ETH_ALEN);
-	memcpy(sta->sta.deflink.addr, addr, ETH_ALEN);
+	memcpy(sta->deflink.addr, link_addr, ETH_ALEN);
+	memcpy(sta->sta.deflink.addr, link_addr, ETH_ALEN);
 	sta->sta.max_rx_aggregation_subframes =
 		local->hw.max_rx_aggregation_subframes;
 
@@ -641,6 +651,21 @@ free:
 	return NULL;
 }
 
+struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
+				const u8 *addr, gfp_t gfp)
+{
+	return __sta_info_alloc(sdata, addr, -1, addr, gfp);
+}
+
+struct sta_info *sta_info_alloc_with_link(struct ieee80211_sub_if_data *sdata,
+					  const u8 *mld_addr,
+					  unsigned int link_id,
+					  const u8 *link_addr,
+					  gfp_t gfp)
+{
+	return __sta_info_alloc(sdata, mld_addr, link_id, link_addr, gfp);
+}
+
 static int sta_info_insert_check(struct sta_info *sta)
 {
 	struct ieee80211_sub_if_data *sdata = sta->sdata;
@@ -774,6 +799,14 @@ static int sta_info_insert_finish(struct sta_info *sta) __acquires(RCU)
 	if (err)
 		goto out_drop_sta;
 
+	if (sta->sta.valid_links) {
+		err = link_sta_info_hash_add(local, &sta->deflink);
+		if (err) {
+			sta_info_hash_del(local, sta);
+			goto out_drop_sta;
+		}
+	}
+
 	list_add_tail_rcu(&sta->list, &local->sta_list);
 
 	/* update channel context before notifying the driver about state
@@ -2697,14 +2730,6 @@ int ieee80211_sta_allocate_link(struct sta_info *sta, unsigned int link_id)
 	return 0;
 }
 
-static int link_sta_info_hash_add(struct ieee80211_local *local,
-				  struct link_sta_info *link_sta)
-{
-	return rhltable_insert(&local->link_sta_hash,
-			       &link_sta->link_hash_node,
-			       link_sta_rht_params);
-}
-
 void ieee80211_sta_free_link(struct sta_info *sta, unsigned int link_id)
 {
 	lockdep_assert_held(&sta->sdata->local->sta_mtx);
diff --git a/net/mac80211/sta_info.h b/net/mac80211/sta_info.h
index 6bc26a2c46073..2eb3a9452e075 100644
--- a/net/mac80211/sta_info.h
+++ b/net/mac80211/sta_info.h
@@ -840,7 +840,12 @@ struct sta_info *sta_info_get_by_idx(struct ieee80211_sub_if_data *sdata,
  * until sta_info_insert().
  */
 struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata,
-				const u8 *addr, int link_id, gfp_t gfp);
+				const u8 *addr, gfp_t gfp);
+struct sta_info *sta_info_alloc_with_link(struct ieee80211_sub_if_data *sdata,
+					  const u8 *mld_addr,
+					  unsigned int link_id,
+					  const u8 *link_addr,
+					  gfp_t gfp);
 
 void sta_info_free(struct ieee80211_local *local, struct sta_info *sta);
 
-- 
2.30.2