#include <linux/platform_device.h>
 #include <linux/usb/typec_altmode.h>
 #include <linux/usb/typec_dp.h>
+#include <linux/usb/typec_mux.h>
 #include <linux/usb/typec_retimer.h>
 
 #define DRV_NAME "cros-typec-switch"
 /* Handles and other relevant data required for each port's switches. */
 struct cros_typec_port {
        int port_num;
+       struct typec_mux_dev *mode_switch;
        struct typec_retimer *retimer;
        struct cros_typec_switch_data *sdata;
 };
        return -ETIMEDOUT;
 }
 
+static int cros_typec_mode_switch_set(struct typec_mux_dev *mode_switch,
+                                     struct typec_mux_state *state)
+{
+       struct cros_typec_port *port = typec_mux_get_drvdata(mode_switch);
+
+       /* Mode switches have index 0. */
+       return cros_typec_configure_mux(port->sdata, port->port_num, 0, state->mode, state->alt);
+}
+
 static int cros_typec_retimer_set(struct typec_retimer *retimer, struct typec_retimer_state *state)
 {
        struct cros_typec_port *port = typec_retimer_get_drvdata(retimer);
                if (!sdata->ports[i])
                        continue;
                typec_retimer_unregister(sdata->ports[i]->retimer);
+               typec_mux_unregister(sdata->ports[i]->mode_switch);
        }
 }
 
+static int cros_typec_register_mode_switch(struct cros_typec_port *port,
+                                          struct fwnode_handle *fwnode)
+{
+       struct typec_mux_desc mode_switch_desc = {
+               .fwnode = fwnode,
+               .drvdata = port,
+               .name = fwnode_get_name(fwnode),
+               .set = cros_typec_mode_switch_set,
+       };
+
+       port->mode_switch = typec_mux_register(port->sdata->dev, &mode_switch_desc);
+       if (IS_ERR(port->mode_switch))
+               return PTR_ERR(port->mode_switch);
+
+       return 0;
+}
+
 static int cros_typec_register_retimer(struct cros_typec_port *port, struct fwnode_handle *fwnode)
 {
        struct typec_retimer_desc retimer_desc = {
                }
 
                dev_dbg(dev, "Retimer switch registered for index %llu\n", index);
+
+               if (!device_property_present(fwnode->dev, "mode-switch"))
+                       continue;
+
+               ret = cros_typec_register_mode_switch(port, fwnode);
+               if (ret) {
+                       dev_err(dev, "Mode switch register failed\n");
+                       goto err_switch;
+               }
+
+               dev_dbg(dev, "Mode switch registered for index %llu\n", index);
        }
 
        return 0;