u32 shift = GET_SHIFT(hwdata->conf);
        const u32 clk_src_266 = 2;
        u32 msk, val, bitmask;
+       unsigned long flags;
        int ret;
 
        /*
         */
        bitmask = (GENMASK(GET_WIDTH(hwdata->conf) - 1, 0) << shift) << 16;
        msk = off ? CPG_CLKSTATUS_SELSDHI1_STS : CPG_CLKSTATUS_SELSDHI0_STS;
+       spin_lock_irqsave(&priv->rmw_lock, flags);
        if (index != clk_src_266) {
                writel(bitmask | ((clk_src_266 + 1) << shift), priv->base + off);
 
-               ret = readl_poll_timeout(priv->base + CPG_CLKSTATUS, val,
-                                        !(val & msk), 100,
-                                        CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US);
-               if (ret) {
-                       dev_err(priv->dev, "failed to switch clk source\n");
-                       return ret;
-               }
+               ret = readl_poll_timeout_atomic(priv->base + CPG_CLKSTATUS, val,
+                                               !(val & msk), 10,
+                                               CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US);
+               if (ret)
+                       goto unlock;
        }
 
        writel(bitmask | ((index + 1) << shift), priv->base + off);
 
-       ret = readl_poll_timeout(priv->base + CPG_CLKSTATUS, val,
-                                !(val & msk), 100,
-                                CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US);
+       ret = readl_poll_timeout_atomic(priv->base + CPG_CLKSTATUS, val,
+                                       !(val & msk), 10,
+                                       CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US);
+unlock:
+       spin_unlock_irqrestore(&priv->rmw_lock, flags);
+
        if (ret)
                dev_err(priv->dev, "failed to switch clk source\n");
 
 
 #define CPG_CLKSTATUS_SELSDHI0_STS     BIT(28)
 #define CPG_CLKSTATUS_SELSDHI1_STS     BIT(29)
 
-#define CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US  20000
+#define CPG_SDHI_CLK_SWITCH_STATUS_TIMEOUT_US  200
 
 /* n = 0/1/2 for PLL1/4/6 */
 #define CPG_SAMPLL_CLK1(n)     (0x04 + (16 * n))