complete(&cm_id_priv->comp);
 }
 
-static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
-                       struct ib_mad_send_buf **msg)
+static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
 {
        struct ib_mad_agent *mad_agent;
        struct ib_mad_send_buf *m;
        m->retries = cm_id_priv->max_cm_retries;
 
        refcount_inc(&cm_id_priv->refcount);
+       spin_unlock_irqrestore(&cm.state_lock, flags2);
        m->context[0] = cm_id_priv;
-       *msg = m;
+       return m;
 
 out:
        spin_unlock_irqrestore(&cm.state_lock, flags2);
-       return ret;
+       return ERR_PTR(ret);
+}
+
+static struct ib_mad_send_buf *
+cm_alloc_priv_msg(struct cm_id_private *cm_id_priv)
+{
+       struct ib_mad_send_buf *msg;
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return msg;
+       cm_id_priv->msg = msg;
+       return msg;
+}
+
+static void cm_free_priv_msg(struct ib_mad_send_buf *msg)
+{
+       struct cm_id_private *cm_id_priv = msg->context[0];
+
+       lockdep_assert_held(&cm_id_priv->lock);
+
+       if (!WARN_ON(cm_id_priv->msg != msg))
+               cm_id_priv->msg = NULL;
+
+       if (msg->ah)
+               rdma_destroy_ah(msg->ah, 0);
+       cm_deref_id(cm_id_priv);
+       ib_free_send_mad(msg);
 }
 
 static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port,
                   struct ib_cm_req_param *param)
 {
        struct cm_id_private *cm_id_priv;
+       struct ib_mad_send_buf *msg;
        struct cm_req_msg *req_msg;
        unsigned long flags;
        int ret;
        cm_id_priv->pkey = param->primary_path->pkey;
        cm_id_priv->qp_type = param->qp_type;
 
-       ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
-       if (ret)
-               goto out;
+       spin_lock_irqsave(&cm_id_priv->lock, flags);
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
+       }
 
-       req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
+       req_msg = (struct cm_req_msg *)msg->mad;
        cm_format_req(req_msg, cm_id_priv, param);
        cm_id_priv->tid = req_msg->hdr.tid;
-       cm_id_priv->msg->timeout_ms = cm_id_priv->timeout_ms;
-       cm_id_priv->msg->context[1] = (void *) (unsigned long) IB_CM_REQ_SENT;
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_REQ_SENT;
 
        cm_id_priv->local_qpn = cpu_to_be32(IBA_GET(CM_REQ_LOCAL_QPN, req_msg));
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REQ_STARTING_PSN, req_msg));
 
        trace_icm_send_req(&cm_id_priv->id);
-       spin_lock_irqsave(&cm_id_priv->lock, flags);
-       ret = ib_post_send_mad(cm_id_priv->msg, NULL);
-       if (ret) {
-               cm_free_msg(cm_id_priv->msg);
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               goto out;
-       }
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        BUG_ON(cm_id->state != IB_CM_IDLE);
        cm_id->state = IB_CM_REQ_SENT;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
 out:
        return ret;
 }
                goto out;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto out;
+       }
 
        rep_msg = (struct cm_rep_msg *) msg->mad;
        cm_format_rep(rep_msg, cm_id_priv, param);
 
        trace_icm_send_rep(cm_id);
        ret = ib_post_send_mad(msg, NULL);
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               return ret;
-       }
+       if (ret)
+               goto out_free;
 
        cm_id->state = IB_CM_REP_SENT;
-       cm_id_priv->msg = msg;
        cm_id_priv->initiator_depth = param->initiator_depth;
        cm_id_priv->responder_resources = param->responder_resources;
        cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REP_STARTING_PSN, rep_msg));
                  "IBTA declares QPN to be 24 bits, but it is 0x%X\n",
                  param->qp_num);
        cm_id_priv->local_qpn = cpu_to_be32(param->qp_num & 0xFFFFFF);
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       return 0;
 
-out:   spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+out_free:
+       cm_free_priv_msg(msg);
+out:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_rep);
                goto error;
        }
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
                goto error;
+       }
 
        cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv,
                      private_data, private_data_len);
            cm_id_priv->id.lap_state == IB_CM_MRA_LAP_RCVD)
                ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret) {
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
                cm_enter_timewait(cm_id_priv);
-               return ret;
+               return PTR_ERR(msg);
        }
 
        cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
        ret = ib_post_send_mad(msg, NULL);
        if (ret) {
                cm_enter_timewait(cm_id_priv);
-               cm_free_msg(msg);
+               cm_free_priv_msg(msg);
                return ret;
        }
 
        cm_id_priv->id.state = IB_CM_DREQ_SENT;
-       cm_id_priv->msg = msg;
        return 0;
 }
 
        cm_set_private_data(cm_id_priv, private_data, private_data_len);
        cm_enter_timewait(cm_id_priv);
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
                       private_data, private_data_len);
        case IB_CM_REP_RCVD:
        case IB_CM_MRA_REP_SENT:
                cm_reset_to_idle(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
        case IB_CM_REP_SENT:
        case IB_CM_MRA_REP_RCVD:
                cm_enter_timewait(cm_id_priv);
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       return ret;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg))
+                       return PTR_ERR(msg);
                cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
                              ari, ari_length, private_data, private_data_len,
                              state);
        default:
                trace_icm_send_mra_unknown_err(&cm_id_priv->id);
                ret = -EINVAL;
-               goto error1;
+               goto error_unlock;
        }
 
        if (!(service_timeout & IB_CM_MRA_FLAG_DELAY)) {
-               ret = cm_alloc_msg(cm_id_priv, &msg);
-               if (ret)
-                       goto error1;
+               msg = cm_alloc_msg(cm_id_priv);
+               if (IS_ERR(msg)) {
+                       ret = PTR_ERR(msg);
+                       goto error_unlock;
+               }
 
                cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
                              msg_response, service_timeout,
                trace_icm_send_mra(cm_id);
                ret = ib_post_send_mad(msg, NULL);
                if (ret)
-                       goto error2;
+                       goto error_free_msg;
        }
 
        cm_id->state = cm_state;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return 0;
 
-error1:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
-       return ret;
-
-error2:        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-       kfree(data);
+error_free_msg:
        cm_free_msg(msg);
+error_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
+       kfree(data);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_mra);
                                 &cm_id_priv->av,
                                 cm_id_priv);
        if (ret)
-               goto out;
+               return ret;
 
        cm_id->service_id = param->service_id;
        cm_id->service_mask = ~cpu_to_be64(0);
        cm_id_priv->timeout_ms = param->timeout_ms;
        cm_id_priv->max_cm_retries = param->max_cm_retries;
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               goto out;
-
-       cm_format_sidr_req((struct cm_sidr_req_msg *) msg->mad, cm_id_priv,
-                          param);
-       msg->timeout_ms = cm_id_priv->timeout_ms;
-       msg->context[1] = (void *) (unsigned long) IB_CM_SIDR_REQ_SENT;
 
        spin_lock_irqsave(&cm_id_priv->lock, flags);
-       if (cm_id->state == IB_CM_IDLE) {
-               trace_icm_send_sidr_req(&cm_id_priv->id);
-               ret = ib_post_send_mad(msg, NULL);
-       } else {
+       if (cm_id->state != IB_CM_IDLE) {
                ret = -EINVAL;
+               goto out_unlock;
        }
 
-       if (ret) {
-               spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-               cm_free_msg(msg);
-               goto out;
+       msg = cm_alloc_priv_msg(cm_id_priv);
+       if (IS_ERR(msg)) {
+               ret = PTR_ERR(msg);
+               goto out_unlock;
        }
+
+       cm_format_sidr_req((struct cm_sidr_req_msg *)msg->mad, cm_id_priv,
+                          param);
+       msg->timeout_ms = cm_id_priv->timeout_ms;
+       msg->context[1] = (void *)(unsigned long)IB_CM_SIDR_REQ_SENT;
+
+       trace_icm_send_sidr_req(&cm_id_priv->id);
+       ret = ib_post_send_mad(msg, NULL);
+       if (ret)
+               goto out_free;
        cm_id->state = IB_CM_SIDR_REQ_SENT;
-       cm_id_priv->msg = msg;
        spin_unlock_irqrestore(&cm_id_priv->lock, flags);
-out:
+       return 0;
+out_free:
+       cm_free_priv_msg(msg);
+out_unlock:
+       spin_unlock_irqrestore(&cm_id_priv->lock, flags);
        return ret;
 }
 EXPORT_SYMBOL(ib_send_cm_sidr_req);
        if (cm_id_priv->id.state != IB_CM_SIDR_REQ_RCVD)
                return -EINVAL;
 
-       ret = cm_alloc_msg(cm_id_priv, &msg);
-       if (ret)
-               return ret;
+       msg = cm_alloc_msg(cm_id_priv);
+       if (IS_ERR(msg))
+               return PTR_ERR(msg);
 
        cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv,
                           param);