#include "ps.h"
 #include "reg.h"
 #include "sar.h"
+#include "txrx.h"
 #include "util.h"
 
 static u16 get_max_amsdu_len(struct rtw89_dev *rtwdev,
                rtw89_phy_read32_mask(rtwdev, R_BANDEDGE, B_BANDEDGE_EN);
 }
 
+static
+void rtw89_phy_antdiv_sts_instance_reset(struct rtw89_antdiv_stats *antdiv_sts)
+{
+       ewma_rssi_init(&antdiv_sts->cck_rssi_avg);
+       ewma_rssi_init(&antdiv_sts->ofdm_rssi_avg);
+       ewma_rssi_init(&antdiv_sts->non_legacy_rssi_avg);
+       antdiv_sts->pkt_cnt_cck = 0;
+       antdiv_sts->pkt_cnt_ofdm = 0;
+       antdiv_sts->pkt_cnt_non_legacy = 0;
+}
+
+static void rtw89_phy_antdiv_sts_instance_add(struct rtw89_dev *rtwdev,
+                                             struct rtw89_rx_phy_ppdu *phy_ppdu,
+                                             struct rtw89_antdiv_stats *stats)
+{
+       if (GET_DATA_RATE_MODE(phy_ppdu->rate) == DATA_RATE_MODE_NON_HT) {
+               if (phy_ppdu->rate < RTW89_HW_RATE_OFDM6) {
+                       ewma_rssi_add(&stats->cck_rssi_avg, phy_ppdu->rssi_avg);
+                       stats->pkt_cnt_cck++;
+               } else {
+                       ewma_rssi_add(&stats->ofdm_rssi_avg, phy_ppdu->rssi_avg);
+                       stats->pkt_cnt_ofdm++;
+               }
+       } else {
+               ewma_rssi_add(&stats->non_legacy_rssi_avg, phy_ppdu->rssi_avg);
+               stats->pkt_cnt_non_legacy++;
+       }
+}
+
+static u8 rtw89_phy_antdiv_sts_instance_get_rssi(struct rtw89_antdiv_stats *stats)
+{
+       if (stats->pkt_cnt_non_legacy >= stats->pkt_cnt_cck &&
+           stats->pkt_cnt_non_legacy >= stats->pkt_cnt_ofdm)
+               return ewma_rssi_read(&stats->non_legacy_rssi_avg);
+       else if (stats->pkt_cnt_ofdm >= stats->pkt_cnt_cck &&
+                stats->pkt_cnt_ofdm >= stats->pkt_cnt_non_legacy)
+               return ewma_rssi_read(&stats->ofdm_rssi_avg);
+       else
+               return ewma_rssi_read(&stats->cck_rssi_avg);
+}
+
+void rtw89_phy_antdiv_parse(struct rtw89_dev *rtwdev,
+                           struct rtw89_rx_phy_ppdu *phy_ppdu)
+{
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+       struct rtw89_hal *hal = &rtwdev->hal;
+
+       if (!hal->ant_diversity || hal->ant_diversity_fixed)
+               return;
+
+       rtw89_phy_antdiv_sts_instance_add(rtwdev, phy_ppdu, &antdiv->target_stats);
+
+       if (!antdiv->get_stats)
+               return;
+
+       if (hal->antenna_rx == RF_A)
+               rtw89_phy_antdiv_sts_instance_add(rtwdev, phy_ppdu, &antdiv->main_stats);
+       else if (hal->antenna_rx == RF_B)
+               rtw89_phy_antdiv_sts_instance_add(rtwdev, phy_ppdu, &antdiv->aux_stats);
+}
+
 static void rtw89_phy_antdiv_reg_init(struct rtw89_dev *rtwdev)
 {
        rtw89_phy_write32_idx(rtwdev, R_P0_TRSW, B_P0_ANT_TRAIN_EN,
                              0x0, RTW89_PHY_0);
 }
 
+static void rtw89_phy_antdiv_sts_reset(struct rtw89_dev *rtwdev)
+{
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+
+       rtw89_phy_antdiv_sts_instance_reset(&antdiv->target_stats);
+       rtw89_phy_antdiv_sts_instance_reset(&antdiv->main_stats);
+       rtw89_phy_antdiv_sts_instance_reset(&antdiv->aux_stats);
+}
+
 static void rtw89_phy_antdiv_init(struct rtw89_dev *rtwdev)
 {
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
        struct rtw89_hal *hal = &rtwdev->hal;
 
        if (!hal->ant_diversity)
                return;
 
+       antdiv->get_stats = false;
+       antdiv->rssi_pre = 0;
+       rtw89_phy_antdiv_sts_reset(rtwdev);
        rtw89_phy_antdiv_reg_init(rtwdev);
 }
 
                              default_ant, RTW89_PHY_0);
 }
 
+static void rtw89_phy_swap_hal_antenna(struct rtw89_dev *rtwdev)
+{
+       struct rtw89_hal *hal = &rtwdev->hal;
+
+       hal->antenna_rx = hal->antenna_rx == RF_A ? RF_B : RF_A;
+       hal->antenna_tx = hal->antenna_rx;
+}
+
+static void rtw89_phy_antdiv_decision_state(struct rtw89_dev *rtwdev)
+{
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+       struct rtw89_hal *hal = &rtwdev->hal;
+       bool no_change = false;
+       u8 main_rssi, aux_rssi;
+       u32 candidate;
+
+       antdiv->get_stats = false;
+       antdiv->training_count = 0;
+
+       main_rssi = rtw89_phy_antdiv_sts_instance_get_rssi(&antdiv->main_stats);
+       aux_rssi = rtw89_phy_antdiv_sts_instance_get_rssi(&antdiv->aux_stats);
+
+       if (main_rssi > aux_rssi + RTW89_TX_DIV_RSSI_RAW_TH)
+               candidate = RF_A;
+       else if (aux_rssi > main_rssi + RTW89_TX_DIV_RSSI_RAW_TH)
+               candidate = RF_B;
+       else
+               no_change = true;
+
+       if (no_change) {
+               /* swap back from training antenna to original */
+               rtw89_phy_swap_hal_antenna(rtwdev);
+               return;
+       }
+
+       hal->antenna_tx = candidate;
+       hal->antenna_rx = candidate;
+}
+
+static void rtw89_phy_antdiv_training_state(struct rtw89_dev *rtwdev)
+{
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+       u64 state_period;
+
+       if (antdiv->training_count % 2 == 0) {
+               if (antdiv->training_count == 0)
+                       rtw89_phy_antdiv_sts_reset(rtwdev);
+
+               antdiv->get_stats = true;
+               state_period = msecs_to_jiffies(ANTDIV_TRAINNING_INTVL);
+       } else {
+               antdiv->get_stats = false;
+               state_period = msecs_to_jiffies(ANTDIV_DELAY);
+
+               rtw89_phy_swap_hal_antenna(rtwdev);
+               rtw89_phy_antdiv_set_ant(rtwdev);
+       }
+
+       antdiv->training_count++;
+       ieee80211_queue_delayed_work(rtwdev->hw, &rtwdev->antdiv_work,
+                                    state_period);
+}
+
+void rtw89_phy_antdiv_work(struct work_struct *work)
+{
+       struct rtw89_dev *rtwdev = container_of(work, struct rtw89_dev,
+                                               antdiv_work.work);
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+
+       mutex_lock(&rtwdev->mutex);
+
+       if (antdiv->training_count <= ANTDIV_TRAINNING_CNT) {
+               rtw89_phy_antdiv_training_state(rtwdev);
+       } else {
+               rtw89_phy_antdiv_decision_state(rtwdev);
+               rtw89_phy_antdiv_set_ant(rtwdev);
+       }
+
+       mutex_unlock(&rtwdev->mutex);
+}
+
+void rtw89_phy_antdiv_track(struct rtw89_dev *rtwdev)
+{
+       struct rtw89_antdiv_info *antdiv = &rtwdev->antdiv;
+       struct rtw89_hal *hal = &rtwdev->hal;
+       u8 rssi, rssi_pre;
+
+       if (!hal->ant_diversity || hal->ant_diversity_fixed)
+               return;
+
+       rssi = rtw89_phy_antdiv_sts_instance_get_rssi(&antdiv->target_stats);
+       rssi_pre = antdiv->rssi_pre;
+       antdiv->rssi_pre = rssi;
+       rtw89_phy_antdiv_sts_instance_reset(&antdiv->target_stats);
+
+       if (abs((int)rssi - (int)rssi_pre) < ANTDIV_RSSI_DIFF_TH)
+               return;
+
+       antdiv->training_count = 0;
+       ieee80211_queue_delayed_work(rtwdev->hw, &rtwdev->antdiv_work, 0);
+}
+
 static void rtw89_phy_env_monitor_init(struct rtw89_dev *rtwdev)
 {
        rtw89_phy_ccx_top_setting_init(rtwdev);