nvmet-tcp: control messages for recvmsg()
author Hannes Reinecke <hare@suse.de>
Thu, 24 Aug 2023 14:39:24 +0000 (16:39 +0200)
committer Keith Busch <kbusch@kernel.org>
Wed, 11 Oct 2023 17:29:59 +0000 (10:29 -0700)
kTLS requires control messages for recvmsg() to relay any out-of-band
TLS messages (e.g. TLS alerts) to the caller.

Signed-off-by: Hannes Reinecke <hare@suse.de>
Reviewed-by: Sagi Grimberg <sagi@grimberg.me>
Signed-off-by: Keith Busch <kbusch@kernel.org>
drivers/nvme/target/tcp.c

index 58a4a737be77cc47f428354de8778bba46b49df4..c9bca7d6160529fef10d0f97a82ab63336f5059f 100644 (file)
@@ -14,6 +14,7 @@
 #include <net/sock.h>
 #include <net/tcp.h>
 #include <net/tls.h>
+#include <net/tls_prot.h>
 #include <net/handshake.h>
 #include <linux/inet.h>
 #include <linux/llist.h>
@@ -118,6 +119,7 @@ struct nvmet_tcp_cmd {
        u32                             pdu_len;
        u32                             pdu_recv;
        int                             sg_idx;
+       char                            recv_cbuf[CMSG_LEN(sizeof(char))];
        struct msghdr                   recv_msg;
        struct bio_vec                  *iov;
        u32                             flags;
@@ -1121,20 +1123,65 @@ static inline bool nvmet_tcp_pdu_valid(u8 type)
        return false;
 }
 
+static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue,
+               struct msghdr *msg, char *cbuf) /* classify the TLS record described by cbuf: returns 0 to accept it, -ENOTCONN or -EAGAIN otherwise */
+{
+       struct cmsghdr *cmsg = (struct cmsghdr *)cbuf; /* cbuf is interpreted as a control-message header */
+       u8 ctype, level, description;
+       int ret = 0; /* default result: record accepted */
+
+       ctype = tls_get_record_type(queue->sock->sk, cmsg);
+       switch (ctype) {
+       case 0: /* accepted as-is; presumably means no TLS cmsg was delivered - TODO confirm against tls_get_record_type() */
+               break;
+       case TLS_RECORD_TYPE_DATA: /* ordinary data record: accept */
+               break;
+       case TLS_RECORD_TYPE_ALERT:
+               tls_alert_recv(queue->sock->sk, msg, &level, &description); /* NOTE(review): callers pass msg after recvmsg() has already run on it - confirm the alert bytes are still readable from it */
+               if (level == TLS_ALERT_LEVEL_FATAL) {
+                       pr_err("queue %d: TLS Alert desc %u\n",
+                              queue->idx, description);
+                       ret = -ENOTCONN; /* fatal alert: reported to the caller as a lost connection */
+               } else {
+                       pr_warn("queue %d: TLS Alert desc %u\n",
+                              queue->idx, description);
+                       ret = -EAGAIN; /* non-fatal alert: logged at warning level only */
+               }
+               break;
+       default:
+               /* discard this record type */
+               pr_err("queue %d: TLS record %d unhandled\n",
+                      queue->idx, ctype);
+               ret = -EAGAIN; /* same result as a non-fatal alert */
+               break;
+       }
+       return ret;
+}
+
 static int nvmet_tcp_try_recv_pdu(struct nvmet_tcp_queue *queue)
 {
        struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr;
-       int len;
+       int len, ret;
        struct kvec iov;
+       char cbuf[CMSG_LEN(sizeof(char))] = {};
        struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
 
 recv:
        iov.iov_base = (void *)&queue->pdu + queue->offset;
        iov.iov_len = queue->left;
+       if (queue->tls_pskid) {
+               msg.msg_control = cbuf;
+               msg.msg_controllen = sizeof(cbuf);
+       }
        len = kernel_recvmsg(queue->sock, &msg, &iov, 1,
                        iov.iov_len, msg.msg_flags);
        if (unlikely(len < 0))
                return len;
+       if (queue->tls_pskid) {
+               ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf);
+               if (ret < 0)
+                       return ret;
+       }
 
        queue->offset += len;
        queue->left -= len;
@@ -1187,16 +1234,22 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd)
 static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue)
 {
        struct nvmet_tcp_cmd  *cmd = queue->cmd;
-       int ret;
+       int len, ret;
 
        while (msg_data_left(&cmd->recv_msg)) {
-               ret = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg,
+               len = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg,
                        cmd->recv_msg.msg_flags);
-               if (ret <= 0)
-                       return ret;
+               if (len <= 0)
+                       return len;
+               if (queue->tls_pskid) {
+                       ret = nvmet_tcp_tls_record_ok(cmd->queue,
+                                       &cmd->recv_msg, cmd->recv_cbuf);
+                       if (ret < 0)
+                               return ret;
+               }
 
-               cmd->pdu_recv += ret;
-               cmd->rbytes_done += ret;
+               cmd->pdu_recv += len;
+               cmd->rbytes_done += len;
        }
 
        if (queue->data_digest) {
@@ -1214,20 +1267,30 @@ static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue)
 static int nvmet_tcp_try_recv_ddgst(struct nvmet_tcp_queue *queue)
 {
        struct nvmet_tcp_cmd *cmd = queue->cmd;
-       int ret;
+       int ret, len;
+       char cbuf[CMSG_LEN(sizeof(char))] = {};
        struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
        struct kvec iov = {
                .iov_base = (void *)&cmd->recv_ddgst + queue->offset,
                .iov_len = queue->left
        };
 
-       ret = kernel_recvmsg(queue->sock, &msg, &iov, 1,
+       if (queue->tls_pskid) {
+               msg.msg_control = cbuf;
+               msg.msg_controllen = sizeof(cbuf);
+       }
+       len = kernel_recvmsg(queue->sock, &msg, &iov, 1,
                        iov.iov_len, msg.msg_flags);
-       if (unlikely(ret < 0))
-               return ret;
+       if (unlikely(len < 0))
+               return len;
+       if (queue->tls_pskid) {
+               ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf);
+               if (ret < 0)
+                       return ret;
+       }
 
-       queue->offset += ret;
-       queue->left -= ret;
+       queue->offset += len;
+       queue->left -= len;
        if (queue->left)
                return -EAGAIN;
 
@@ -1407,6 +1470,10 @@ static int nvmet_tcp_alloc_cmd(struct nvmet_tcp_queue *queue,
        if (!c->r2t_pdu)
                goto out_free_data;
 
+       if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) {
+               c->recv_msg.msg_control = c->recv_cbuf;
+               c->recv_msg.msg_controllen = sizeof(c->recv_cbuf);
+       }
        c->recv_msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
 
        list_add_tail(&c->entry, &queue->free_list);