#include "vendor.h"
 #include "bus.h"
 #include "common.h"
+#include "fwvid.h"
 
 #define BRCMF_SCAN_IE_LEN_MAX          2048
 
        return reason;
 }
 
-static int brcmf_set_pmk(struct brcmf_if *ifp, const u8 *pmk_data, u16 pmk_len)
+int brcmf_set_wsec(struct brcmf_if *ifp, const u8 *key, u16 key_len, u16 flags)
 {
        struct brcmf_pub *drvr = ifp->drvr;
        struct brcmf_wsec_pmk_le pmk;
        int err;
 
+       if (key_len > sizeof(pmk.key)) {
+               bphy_err(drvr, "key must be less than %zu bytes\n",
+                        sizeof(pmk.key));
+               return -EINVAL;
+       }
+
        memset(&pmk, 0, sizeof(pmk));
 
-       /* pass pmk directly */
-       pmk.key_len = cpu_to_le16(pmk_len);
-       pmk.flags = cpu_to_le16(0);
-       memcpy(pmk.key, pmk_data, pmk_len);
+       /* pass key material directly */
+       pmk.key_len = cpu_to_le16(key_len);
+       pmk.flags = cpu_to_le16(flags);
+       memcpy(pmk.key, key, key_len);
 
-       /* store psk in firmware */
+       /* store key material in firmware */
        err = brcmf_fil_cmd_data_set(ifp, BRCMF_C_SET_WSEC_PMK,
                                     &pmk, sizeof(pmk));
        if (err < 0)
                bphy_err(drvr, "failed to change PSK in firmware (len=%u)\n",
-                        pmk_len);
+                        key_len);
 
        return err;
 }
+BRCMF_EXPORT_SYMBOL_GPL(brcmf_set_wsec);
 
-static int brcmf_set_sae_password(struct brcmf_if *ifp, const u8 *pwd_data,
-                                 u16 pwd_len)
+static int brcmf_set_pmk(struct brcmf_if *ifp, const u8 *pmk_data, u16 pmk_len)
 {
-       struct brcmf_pub *drvr = ifp->drvr;
-       struct brcmf_wsec_sae_pwd_le sae_pwd;
-       int err;
-
-       if (pwd_len > BRCMF_WSEC_MAX_SAE_PASSWORD_LEN) {
-               bphy_err(drvr, "sae_password must be less than %d\n",
-                        BRCMF_WSEC_MAX_SAE_PASSWORD_LEN);
-               return -EINVAL;
-       }
-
-       sae_pwd.key_len = cpu_to_le16(pwd_len);
-       memcpy(sae_pwd.key, pwd_data, pwd_len);
-
-       err = brcmf_fil_iovar_data_set(ifp, "sae_password", &sae_pwd,
-                                      sizeof(sae_pwd));
-       if (err < 0)
-               bphy_err(drvr, "failed to set SAE password in firmware (len=%u)\n",
-                        pwd_len);
-
-       return err;
+       return brcmf_set_wsec(ifp, pmk_data, pmk_len, 0);
 }
 
 static void brcmf_link_down(struct brcmf_cfg80211_vif *vif, u16 reason,
                        bphy_err(drvr, "failed to clean up user-space RSNE\n");
                        goto done;
                }
-               err = brcmf_set_sae_password(ifp, sme->crypto.sae_pwd,
-                                            sme->crypto.sae_pwd_len);
+               err = brcmf_fwvid_set_sae_password(ifp, &sme->crypto);
                if (!err && sme->crypto.psk)
                        err = brcmf_set_pmk(ifp, sme->crypto.psk,
                                            BRCMF_WSEC_MAX_PSK_LEN);
                if (crypto->sae_pwd) {
                        brcmf_dbg(INFO, "using SAE offload\n");
                        profile->use_fwauth |= BIT(BRCMF_PROFILE_FWAUTH_SAE);
-                       err = brcmf_set_sae_password(ifp, crypto->sae_pwd,
-                                                    crypto->sae_pwd_len);
+                       err = brcmf_fwvid_set_sae_password(ifp, crypto);
                        if (err < 0)
                                goto exit;
                }
                msleep(400);
 
                if (profile->use_fwauth != BIT(BRCMF_PROFILE_FWAUTH_NONE)) {
+                       struct cfg80211_crypto_settings crypto = {};
+
                        if (profile->use_fwauth & BIT(BRCMF_PROFILE_FWAUTH_PSK))
                                brcmf_set_pmk(ifp, NULL, 0);
                        if (profile->use_fwauth & BIT(BRCMF_PROFILE_FWAUTH_SAE))
-                               brcmf_set_sae_password(ifp, NULL, 0);
+                               brcmf_fwvid_set_sae_password(ifp, &crypto);
                        profile->use_fwauth = BIT(BRCMF_PROFILE_FWAUTH_NONE);
                }
 
 
 #include <core.h>
 #include <bus.h>
 #include <fwvid.h>
+#include <fwil.h>
 
 #include "vops.h"
 
        pr_err("%s: executing\n", __func__);
 }
 
+static int brcmf_cyw_set_sae_pwd(struct brcmf_if *ifp,
+                                struct cfg80211_crypto_settings *crypto)
+{
+       struct brcmf_pub *drvr = ifp->drvr;
+       struct brcmf_wsec_sae_pwd_le sae_pwd;
+       u16 pwd_len = crypto->sae_pwd_len;
+       int err;
+
+       if (pwd_len > BRCMF_WSEC_MAX_SAE_PASSWORD_LEN) {
+               bphy_err(drvr, "sae_password must be less than %d\n",
+                        BRCMF_WSEC_MAX_SAE_PASSWORD_LEN);
+               return -EINVAL;
+       }
+
+       sae_pwd.key_len = cpu_to_le16(pwd_len);
+       memcpy(sae_pwd.key, crypto->sae_pwd, pwd_len);
+
+       err = brcmf_fil_iovar_data_set(ifp, "sae_password", &sae_pwd,
+                                      sizeof(sae_pwd));
+       if (err < 0)
+               bphy_err(drvr, "failed to set SAE password in firmware (len=%u)\n",
+                        pwd_len);
+
+       return err;
+}
+
 const struct brcmf_fwvid_ops brcmf_cyw_ops = {
        .attach = brcmf_cyw_attach,
        .detach = brcmf_cyw_detach,
+       .set_sae_password = brcmf_cyw_set_sae_pwd,
 };