fid->rif = rif;
 }
 
+struct mlxsw_sp_rif *mlxsw_sp_fid_rif(const struct mlxsw_sp_fid *fid)
+{
+       return fid->rif;
+}
+
 enum mlxsw_sp_rif_type
 mlxsw_sp_fid_type_rif_type(const struct mlxsw_sp *mlxsw_sp,
                           enum mlxsw_sp_fid_type type)
        struct mlxsw_sp_fid_family *fid_family = fid->fid_family;
        struct mlxsw_sp *mlxsw_sp = fid_family->mlxsw_sp;
 
-       if (--fid->ref_count == 1 && fid->rif) {
-               /* Destroy the associated RIF and let it drop the last
-                * reference on the FID.
-                */
-               return mlxsw_sp_rif_destroy(fid->rif);
-       } else if (fid->ref_count == 0) {
-               list_del(&fid->list);
-               rhashtable_remove_fast(&mlxsw_sp->fid_core->fid_ht,
-                                      &fid->ht_node, mlxsw_sp_fid_ht_params);
-               fid->fid_family->ops->deconfigure(fid);
-               __clear_bit(fid->fid_index - fid_family->start_index,
-                           fid_family->fids_bitmap);
-               kfree(fid);
-       }
+       if (--fid->ref_count != 0)
+               return;
+
+       list_del(&fid->list);
+       rhashtable_remove_fast(&mlxsw_sp->fid_core->fid_ht,
+                              &fid->ht_node, mlxsw_sp_fid_ht_params);
+       fid->fid_family->ops->deconfigure(fid);
+       __clear_bit(fid->fid_index - fid_family->start_index,
+                   fid_family->fids_bitmap);
+       kfree(fid);
 }
 
 struct mlxsw_sp_fid *mlxsw_sp_fid_8021q_get(struct mlxsw_sp *mlxsw_sp, u16 vid)
 
 #include <linux/gcd.h>
 #include <linux/random.h>
 #include <linux/if_macvlan.h>
+#include <linux/refcount.h>
 #include <net/netevent.h>
 #include <net/neighbour.h>
 #include <net/arp.h>
 
 struct mlxsw_sp_rif_subport {
        struct mlxsw_sp_rif common;
+       refcount_t ref_count;
        union {
                u16 system_port;
                u16 lag_id;
        void (*fdb_del)(struct mlxsw_sp_rif *rif, const char *mac);
 };
 
+static void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif);
 static void mlxsw_sp_lpm_tree_hold(struct mlxsw_sp_lpm_tree *lpm_tree);
 static void mlxsw_sp_lpm_tree_put(struct mlxsw_sp *mlxsw_sp,
                                  struct mlxsw_sp_lpm_tree *lpm_tree);
        return ERR_PTR(err);
 }
 
-void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif)
+static void mlxsw_sp_rif_destroy(struct mlxsw_sp_rif *rif)
 {
        const struct mlxsw_sp_rif_ops *ops = rif->ops;
        struct mlxsw_sp *mlxsw_sp = rif->mlxsw_sp;
                params->system_port = mlxsw_sp_port->local_port;
 }
 
+static struct mlxsw_sp_rif_subport *
+mlxsw_sp_rif_subport_rif(const struct mlxsw_sp_rif *rif)
+{
+       return container_of(rif, struct mlxsw_sp_rif_subport, common);
+}
+
+static struct mlxsw_sp_rif *
+mlxsw_sp_rif_subport_get(struct mlxsw_sp *mlxsw_sp,
+                        const struct mlxsw_sp_rif_params *params,
+                        struct netlink_ext_ack *extack)
+{
+       struct mlxsw_sp_rif_subport *rif_subport;
+       struct mlxsw_sp_rif *rif;
+
+       rif = mlxsw_sp_rif_find_by_dev(mlxsw_sp, params->dev);
+       if (!rif)
+               return mlxsw_sp_rif_create(mlxsw_sp, params, extack);
+
+       rif_subport = mlxsw_sp_rif_subport_rif(rif);
+       refcount_inc(&rif_subport->ref_count);
+       return rif;
+}
+
+static void mlxsw_sp_rif_subport_put(struct mlxsw_sp_rif *rif)
+{
+       struct mlxsw_sp_rif_subport *rif_subport;
+
+       rif_subport = mlxsw_sp_rif_subport_rif(rif);
+       if (!refcount_dec_and_test(&rif_subport->ref_count))
+               return;
+
+       mlxsw_sp_rif_destroy(rif);
+}
+
 static int
 mlxsw_sp_port_vlan_router_join(struct mlxsw_sp_port_vlan *mlxsw_sp_port_vlan,
                               struct net_device *l3_dev,
 {
        struct mlxsw_sp_port *mlxsw_sp_port = mlxsw_sp_port_vlan->mlxsw_sp_port;
        struct mlxsw_sp *mlxsw_sp = mlxsw_sp_port->mlxsw_sp;
+       struct mlxsw_sp_rif_params params = {
+               .dev = l3_dev,
+       };
        u16 vid = mlxsw_sp_port_vlan->vid;
        struct mlxsw_sp_rif *rif;
        struct mlxsw_sp_fid *fid;
        int err;
 
-       rif = mlxsw_sp_rif_find_by_dev(mlxsw_sp, l3_dev);
-       if (!rif) {
-               struct mlxsw_sp_rif_params params = {
-                       .dev = l3_dev,
-               };
-
-               mlxsw_sp_rif_subport_params_init(¶ms, mlxsw_sp_port_vlan);
-               rif = mlxsw_sp_rif_create(mlxsw_sp, ¶ms, extack);
-               if (IS_ERR(rif))
-                       return PTR_ERR(rif);
-       }
+       mlxsw_sp_rif_subport_params_init(¶ms, mlxsw_sp_port_vlan);
+       rif = mlxsw_sp_rif_subport_get(mlxsw_sp, ¶ms, extack);
+       if (IS_ERR(rif))
+               return PTR_ERR(rif);
 
        /* FID was already created, just take a reference */
        fid = rif->ops->fid_get(rif, extack);
        mlxsw_sp_fid_port_vid_unmap(fid, mlxsw_sp_port, vid);
 err_fid_port_vid_map:
        mlxsw_sp_fid_put(fid);
+       mlxsw_sp_rif_subport_put(rif);
        return err;
 }
 
 {
        struct mlxsw_sp_port *mlxsw_sp_port = mlxsw_sp_port_vlan->mlxsw_sp_port;
        struct mlxsw_sp_fid *fid = mlxsw_sp_port_vlan->fid;
+       struct mlxsw_sp_rif *rif = mlxsw_sp_fid_rif(fid);
        u16 vid = mlxsw_sp_port_vlan->vid;
 
        if (WARN_ON(mlxsw_sp_fid_type(fid) != MLXSW_SP_FID_TYPE_RFID))
        mlxsw_sp_port_vid_stp_set(mlxsw_sp_port, vid, BR_STATE_BLOCKING);
        mlxsw_sp_port_vid_learning_set(mlxsw_sp_port, vid, true);
        mlxsw_sp_fid_port_vid_unmap(fid, mlxsw_sp_port, vid);
-       /* If router port holds the last reference on the rFID, then the
-        * associated Sub-port RIF will be destroyed.
-        */
        mlxsw_sp_fid_put(fid);
+       mlxsw_sp_rif_subport_put(rif);
 }
 
 static int mlxsw_sp_inetaddr_port_vlan_event(struct net_device *l3_dev,
                                             __mlxsw_sp_rif_macvlan_flush, rif);
 }
 
-static struct mlxsw_sp_rif_subport *
-mlxsw_sp_rif_subport_rif(const struct mlxsw_sp_rif *rif)
-{
-       return container_of(rif, struct mlxsw_sp_rif_subport, common);
-}
-
 static void mlxsw_sp_rif_subport_setup(struct mlxsw_sp_rif *rif,
                                       const struct mlxsw_sp_rif_params *params)
 {
        struct mlxsw_sp_rif_subport *rif_subport;
 
        rif_subport = mlxsw_sp_rif_subport_rif(rif);
+       refcount_set(&rif_subport->ref_count, 1);
        rif_subport->vid = params->vid;
        rif_subport->lag = params->lag;
        if (params->lag)