!!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH));
 }
 
+/*
+ * Look up an MCS group index based on new cfg80211 rate_info.
+ */
+static int
+minstrel_ht_ri_get_group_idx(struct rate_info *rate)
+{
+       return GROUP_IDX((rate->mcs / 8) + 1,
+                        !!(rate->flags & RATE_INFO_FLAGS_SHORT_GI),
+                        !!(rate->bw & RATE_INFO_BW_40));
+}
+
 static int
 minstrel_vht_get_group_idx(struct ieee80211_tx_rate *rate)
 {
                             2*!!(rate->flags & IEEE80211_TX_RC_80_MHZ_WIDTH));
 }
 
+/*
+ * Look up an MCS group index based on new cfg80211 rate_info.
+ */
+static int
+minstrel_vht_ri_get_group_idx(struct rate_info *rate)
+{
+       return VHT_GROUP_IDX(rate->nss,
+                            !!(rate->flags & RATE_INFO_FLAGS_SHORT_GI),
+                            !!(rate->bw & RATE_INFO_BW_40) +
+                            2*!!(rate->bw & RATE_INFO_BW_80));
+}
+
 static struct minstrel_rate_stats *
 minstrel_ht_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
                      struct ieee80211_tx_rate *rate)
        return &mi->groups[group].rates[idx];
 }
 
+/*
+ * Get the minstrel rate statistics for specified STA and rate info.
+ */
+static struct minstrel_rate_stats *
+minstrel_ht_ri_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi,
+                         struct ieee80211_rate_status *rate_status)
+{
+       int group, idx;
+       struct rate_info *rate = &rate_status->rate_idx;
+
+       if (rate->flags & RATE_INFO_FLAGS_MCS) {
+               group = minstrel_ht_ri_get_group_idx(rate);
+               idx = rate->mcs % 8;
+               goto out;
+       }
+
+       if (rate->flags & RATE_INFO_FLAGS_VHT_MCS) {
+               group = minstrel_vht_ri_get_group_idx(rate);
+               idx = rate->mcs;
+               goto out;
+       }
+
+       group = MINSTREL_CCK_GROUP;
+       for (idx = 0; idx < ARRAY_SIZE(mp->cck_rates); idx++) {
+               if (rate->legacy != minstrel_cck_bitrates[ mp->cck_rates[idx] ])
+                       continue;
+
+               /* short preamble */
+               if ((mi->supported[group] & BIT(idx + 4)) &&
+                                                       mi->use_short_preamble)
+                       idx += 4;
+               goto out;
+       }
+
+       group = MINSTREL_OFDM_GROUP;
+       for (idx = 0; idx < ARRAY_SIZE(mp->ofdm_rates[0]); idx++)
+               if (rate->legacy == minstrel_ofdm_bitrates[ mp->ofdm_rates[mi->band][idx] ])
+                       goto out;
+
+       idx = 0;
+out:
+       return &mi->groups[group].rates[idx];
+}
+
 static inline struct minstrel_rate_stats *
 minstrel_get_ratestats(struct minstrel_ht_sta *mi, int index)
 {
        return false;
 }
 
+/*
+ * Check whether rate_status contains valid information.
+ */
+static bool
+minstrel_ht_ri_txstat_valid(struct minstrel_priv *mp,
+                           struct minstrel_ht_sta *mi,
+                           struct ieee80211_rate_status *rate_status)
+{
+       int i;
+
+       if (!rate_status)
+               return false;
+       if (!rate_status->try_count)
+               return false;
+
+       if (rate_status->rate_idx.flags & RATE_INFO_FLAGS_MCS ||
+           rate_status->rate_idx.flags & RATE_INFO_FLAGS_VHT_MCS)
+               return true;
+
+       for (i = 0; i < ARRAY_SIZE(mp->cck_rates); i++) {
+               if (rate_status->rate_idx.legacy ==
+                   minstrel_cck_bitrates[ mp->cck_rates[i] ])
+                       return true;
+       }
+
+       for (i = 0; i < ARRAY_SIZE(mp->ofdm_rates); i++) {
+               if (rate_status->rate_idx.legacy ==
+                   minstrel_ofdm_bitrates[ mp->ofdm_rates[mi->band][i] ])
+                       return true;
+       }
+
+       return false;
+}
+
 static void
 minstrel_downgrade_rate(struct minstrel_ht_sta *mi, u16 *idx, bool primary)
 {
        mi->ampdu_packets++;
        mi->ampdu_len += info->status.ampdu_len;
 
-       last = !minstrel_ht_txstat_valid(mp, mi, &ar[0]);
-       for (i = 0; !last; i++) {
-               last = (i == IEEE80211_TX_MAX_RATES - 1) ||
-                      !minstrel_ht_txstat_valid(mp, mi, &ar[i + 1]);
+       if (st->rates && st->n_rates) {
+               last = !minstrel_ht_ri_txstat_valid(mp, mi, &(st->rates[0]));
+               for (i = 0; !last; i++) {
+                       last = (i == st->n_rates - 1) ||
+                               !minstrel_ht_ri_txstat_valid(mp, mi,
+                                                       &(st->rates[i + 1]));
+
+                       rate = minstrel_ht_ri_get_stats(mp, mi,
+                                                       &(st->rates[i]));
+
+                       if (last)
+                               rate->success += info->status.ampdu_ack_len;
+
+                       rate->attempts += st->rates[i].try_count *
+                                         info->status.ampdu_len;
+               }
+       } else {
+               last = !minstrel_ht_txstat_valid(mp, mi, &ar[0]);
+               for (i = 0; !last; i++) {
+                       last = (i == IEEE80211_TX_MAX_RATES - 1) ||
+                               !minstrel_ht_txstat_valid(mp, mi, &ar[i + 1]);
 
-               rate = minstrel_ht_get_stats(mp, mi, &ar[i]);
-               if (last)
-                       rate->success += info->status.ampdu_ack_len;
+                       rate = minstrel_ht_get_stats(mp, mi, &ar[i]);
+                       if (last)
+                               rate->success += info->status.ampdu_ack_len;
 
-               rate->attempts += ar[i].count * info->status.ampdu_len;
+                       rate->attempts += ar[i].count * info->status.ampdu_len;
+               }
        }
 
        if (mp->hw->max_rates > 1) {
        u16 ht_cap = sta->deflink.ht_cap.cap;
        struct ieee80211_sta_vht_cap *vht_cap = &sta->deflink.vht_cap;
        const struct ieee80211_rate *ctl_rate;
+       struct sta_info *sta_info;
        bool ldpc, erp;
        int use_vht;
        int n_supported = 0;
                        n_supported++;
        }
 
+       sta_info = container_of(sta, struct sta_info, sta);
+       mi->use_short_preamble = test_sta_flag(sta_info, WLAN_STA_SHORT_PREAMBLE) &&
+                                sta_info->sdata->vif.bss_conf.use_short_preamble;
+
        minstrel_ht_update_cck(mp, mi, sband, sta);
        minstrel_ht_update_ofdm(mp, mi, sband, sta);