selftests/bpf: Make sock configurable for each test case
authorJordan Rife <jrife@google.com>
Mon, 29 Apr 2024 21:45:22 +0000 (16:45 -0500)
committerMartin KaFai Lau <martin.lau@kernel.org>
Thu, 2 May 2024 22:23:31 +0000 (15:23 -0700)
In order to reuse the same test code for both socket system calls (e.g.
connect(), bind(), etc.) and kernel socket functions (e.g.
kernel_connect(), kernel_bind(), etc.), this patch introduces the "ops"
field to sock_addr_test. This field allows each test cases to configure
the set of functions used in the test case to create, manipulate, and
tear down a socket.

Signed-off-by: Jordan Rife <jrife@google.com>
Link: https://lore.kernel.org/r/20240429214529.2644801-6-jrife@google.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
tools/testing/selftests/bpf/prog_tests/sock_addr.c

index bf1ad18262d892e3a9b72a15f0fb759684d76c4e..8f678fa413f5651c469702d889ed83e05af32295 100644 (file)
@@ -54,12 +54,64 @@ enum sock_addr_test_type {
 typedef void *(*load_fn)(int cgroup_fd);
 typedef void (*destroy_fn)(void *skel);
 
+struct sock_ops {
+       int (*connect_to_addr)(int type, const struct sockaddr_storage *addr,
+                              socklen_t addrlen,
+                              const struct network_helper_opts *opts);
+       int (*start_server)(int family, int type, const char *addr_str,
+                           __u16 port, int timeout_ms);
+       int (*socket)(int famil, int type, int protocol);
+       int (*bind)(int fd, struct sockaddr *addr, socklen_t addrlen);
+       int (*getsockname)(int fd, struct sockaddr *addr, socklen_t *addrlen);
+       int (*getpeername)(int fd, struct sockaddr *addr, socklen_t *addrlen);
+       int (*sendmsg)(int fd, struct sockaddr *addr, socklen_t addrlen,
+                      char *msg, int msglen);
+       int (*close)(int fd);
+};
+
+static int user_sendmsg(int fd, struct sockaddr *addr, socklen_t addrlen,
+                       char *msg, int msglen)
+{
+       struct msghdr hdr;
+       struct iovec iov;
+
+       memset(&iov, 0, sizeof(iov));
+       iov.iov_base = msg;
+       iov.iov_len = msglen;
+
+       memset(&hdr, 0, sizeof(hdr));
+       hdr.msg_name = (void *)addr;
+       hdr.msg_namelen = addrlen;
+       hdr.msg_iov = &iov;
+       hdr.msg_iovlen = 1;
+
+       return sendmsg(fd, &hdr, 0);
+}
+
+static int user_bind(int fd, struct sockaddr *addr, socklen_t addrlen)
+{
+       return bind(fd, (const struct sockaddr *)addr, addrlen);
+}
+
+struct sock_ops user_ops = {
+       .connect_to_addr = connect_to_addr,
+       .start_server = start_server,
+       .socket = socket,
+       .bind = user_bind,
+       .getsockname = getsockname,
+       .getpeername = getpeername,
+       .sendmsg = user_sendmsg,
+       .close = close,
+};
+
 struct sock_addr_test {
        enum sock_addr_test_type type;
        const char *name;
        /* BPF prog properties */
        load_fn loadfn;
        destroy_fn destroyfn;
+       /* Socket operations */
+       struct sock_ops *ops;
        /* Socket properties */
        int socket_family;
        int socket_type;
@@ -113,6 +165,7 @@ static struct sock_addr_test tests[] = {
                "bind4: bind (stream)",
                bind4_prog_load,
                bind4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_STREAM,
                SERV4_IP,
@@ -125,6 +178,7 @@ static struct sock_addr_test tests[] = {
                "bind4: bind (dgram)",
                bind4_prog_load,
                bind4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_DGRAM,
                SERV4_IP,
@@ -137,6 +191,7 @@ static struct sock_addr_test tests[] = {
                "bind6: bind (stream)",
                bind6_prog_load,
                bind6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_STREAM,
                SERV6_IP,
@@ -149,6 +204,7 @@ static struct sock_addr_test tests[] = {
                "bind6: bind (dgram)",
                bind6_prog_load,
                bind6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_DGRAM,
                SERV6_IP,
@@ -163,6 +219,7 @@ static struct sock_addr_test tests[] = {
                "connect4: connect (stream)",
                connect4_prog_load,
                connect4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_STREAM,
                SERV4_IP,
@@ -176,6 +233,7 @@ static struct sock_addr_test tests[] = {
                "connect4: connect (dgram)",
                connect4_prog_load,
                connect4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_DGRAM,
                SERV4_IP,
@@ -189,6 +247,7 @@ static struct sock_addr_test tests[] = {
                "connect6: connect (stream)",
                connect6_prog_load,
                connect6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_STREAM,
                SERV6_IP,
@@ -202,6 +261,7 @@ static struct sock_addr_test tests[] = {
                "connect6: connect (dgram)",
                connect6_prog_load,
                connect6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_DGRAM,
                SERV6_IP,
@@ -215,6 +275,7 @@ static struct sock_addr_test tests[] = {
                "connect_unix: connect (stream)",
                connect_unix_prog_load,
                connect_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_STREAM,
                SERVUN_ADDRESS,
@@ -230,6 +291,7 @@ static struct sock_addr_test tests[] = {
                "sendmsg4: sendmsg (dgram)",
                sendmsg4_prog_load,
                sendmsg4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_DGRAM,
                SERV4_IP,
@@ -243,6 +305,7 @@ static struct sock_addr_test tests[] = {
                "sendmsg6: sendmsg (dgram)",
                sendmsg6_prog_load,
                sendmsg6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_DGRAM,
                SERV6_IP,
@@ -256,6 +319,7 @@ static struct sock_addr_test tests[] = {
                "sendmsg_unix: sendmsg (dgram)",
                sendmsg_unix_prog_load,
                sendmsg_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_DGRAM,
                SERVUN_ADDRESS,
@@ -271,6 +335,7 @@ static struct sock_addr_test tests[] = {
                "recvmsg4: recvfrom (dgram)",
                recvmsg4_prog_load,
                recvmsg4_prog_destroy,
+               &user_ops,
                AF_INET,
                SOCK_DGRAM,
                SERV4_REWRITE_IP,
@@ -284,6 +349,7 @@ static struct sock_addr_test tests[] = {
                "recvmsg6: recvfrom (dgram)",
                recvmsg6_prog_load,
                recvmsg6_prog_destroy,
+               &user_ops,
                AF_INET6,
                SOCK_DGRAM,
                SERV6_REWRITE_IP,
@@ -297,6 +363,7 @@ static struct sock_addr_test tests[] = {
                "recvmsg_unix: recvfrom (dgram)",
                recvmsg_unix_prog_load,
                recvmsg_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_DGRAM,
                SERVUN_REWRITE_ADDRESS,
@@ -310,6 +377,7 @@ static struct sock_addr_test tests[] = {
                "recvmsg_unix: recvfrom (stream)",
                recvmsg_unix_prog_load,
                recvmsg_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_STREAM,
                SERVUN_REWRITE_ADDRESS,
@@ -325,6 +393,7 @@ static struct sock_addr_test tests[] = {
                "getsockname_unix",
                getsockname_unix_prog_load,
                getsockname_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_STREAM,
                SERVUN_ADDRESS,
@@ -340,6 +409,7 @@ static struct sock_addr_test tests[] = {
                "getpeername_unix",
                getpeername_unix_prog_load,
                getpeername_unix_prog_destroy,
+               &user_ops,
                AF_UNIX,
                SOCK_STREAM,
                SERVUN_ADDRESS,
@@ -400,26 +470,15 @@ static int cmp_sock_addr(info_fn fn, int sock1,
        return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port);
 }
 
-static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2,
-                         socklen_t addr2_len, bool cmp_port)
-{
-       return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, cmp_port);
-}
-
-static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2,
-                        socklen_t addr2_len, bool cmp_port)
-{
-       return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, cmp_port);
-}
-
 static void test_bind(struct sock_addr_test *test)
 {
        struct sockaddr_storage expected_addr;
        socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
        int serv = -1, client = -1, err;
 
-       serv = start_server(test->socket_family, test->socket_type,
-                           test->requested_addr, test->requested_port, 0);
+       serv = test->ops->start_server(test->socket_family, test->socket_type,
+                                      test->requested_addr,
+                                      test->requested_port, 0);
        if (!ASSERT_GE(serv, 0, "start_server"))
                goto cleanup;
 
@@ -429,7 +488,8 @@ static void test_bind(struct sock_addr_test *test)
        if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                goto cleanup;
 
-       err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
+       err = cmp_sock_addr(test->ops->getsockname, serv, &expected_addr,
+                           expected_addr_len, true);
        if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
                goto cleanup;
 
@@ -442,7 +502,7 @@ cleanup:
        if (client != -1)
                close(client);
        if (serv != -1)
-               close(serv);
+               test->ops->close(serv);
 }
 
 static void test_connect(struct sock_addr_test *test)
@@ -463,7 +523,8 @@ static void test_connect(struct sock_addr_test *test)
        if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                goto cleanup;
 
-       client = connect_to_addr(test->socket_type, &addr, addr_len, NULL);
+       client = test->ops->connect_to_addr(test->socket_type, &addr, addr_len,
+                                           NULL);
        if (!ASSERT_GE(client, 0, "connect_to_addr"))
                goto cleanup;
 
@@ -479,18 +540,21 @@ static void test_connect(struct sock_addr_test *test)
                        goto cleanup;
        }
 
-       err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
+       err = cmp_sock_addr(test->ops->getpeername, client, &expected_addr,
+                           expected_addr_len, true);
        if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
                goto cleanup;
 
        if (test->expected_src_addr) {
-               err = cmp_local_addr(client, &expected_src_addr, expected_src_addr_len, false);
+               err = cmp_sock_addr(test->ops->getsockname, client,
+                                   &expected_src_addr, expected_src_addr_len,
+                                   false);
                if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
                        goto cleanup;
        }
 cleanup:
        if (client != -1)
-               close(client);
+               test->ops->close(client);
        if (serv != -1)
                close(serv);
 }
@@ -500,8 +564,6 @@ static void test_xmsg(struct sock_addr_test *test)
        struct sockaddr_storage addr, src_addr;
        socklen_t addr_len = sizeof(struct sockaddr_storage),
                  src_addr_len = sizeof(struct sockaddr_storage);
-       struct msghdr hdr;
-       struct iovec iov;
        char data = 'a';
        int serv = -1, client = -1, err;
 
@@ -514,7 +576,7 @@ static void test_xmsg(struct sock_addr_test *test)
        if (!ASSERT_GE(serv, 0, "start_server"))
                goto cleanup;
 
-       client = socket(test->socket_family, test->socket_type, 0);
+       client = test->ops->socket(test->socket_family, test->socket_type, 0);
        if (!ASSERT_GE(client, 0, "socket"))
                goto cleanup;
 
@@ -524,7 +586,8 @@ static void test_xmsg(struct sock_addr_test *test)
                if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                        goto cleanup;
 
-               err = bind(client, (const struct sockaddr *) &src_addr, src_addr_len);
+               err = test->ops->bind(client, (struct sockaddr *)&src_addr,
+                                     src_addr_len);
                if (!ASSERT_OK(err, "bind"))
                        goto cleanup;
        }
@@ -535,17 +598,8 @@ static void test_xmsg(struct sock_addr_test *test)
                goto cleanup;
 
        if (test->socket_type == SOCK_DGRAM) {
-               memset(&iov, 0, sizeof(iov));
-               iov.iov_base = &data;
-               iov.iov_len = sizeof(data);
-
-               memset(&hdr, 0, sizeof(hdr));
-               hdr.msg_name = (void *)&addr;
-               hdr.msg_namelen = addr_len;
-               hdr.msg_iov = &iov;
-               hdr.msg_iovlen = 1;
-
-               err = sendmsg(client, &hdr, 0);
+               err = test->ops->sendmsg(client, (struct sockaddr *)&addr,
+                                        addr_len, &data, sizeof(data));
                if (!ASSERT_EQ(err, sizeof(data), "sendmsg"))
                        goto cleanup;
        } else {
@@ -596,7 +650,7 @@ static void test_xmsg(struct sock_addr_test *test)
 
 cleanup:
        if (client != -1)
-               close(client);
+               test->ops->close(client);
        if (serv != -1)
                close(serv);
 }
@@ -607,7 +661,7 @@ static void test_getsockname(struct sock_addr_test *test)
        socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
        int serv = -1, err;
 
-       serv = start_server(test->socket_family, test->socket_type,
+       serv = test->ops->start_server(test->socket_family, test->socket_type,
                            test->requested_addr, test->requested_port, 0);
        if (!ASSERT_GE(serv, 0, "start_server"))
                goto cleanup;
@@ -618,13 +672,13 @@ static void test_getsockname(struct sock_addr_test *test)
        if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                goto cleanup;
 
-       err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
+       err = cmp_sock_addr(test->ops->getsockname, serv, &expected_addr, expected_addr_len, true);
        if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
                goto cleanup;
 
 cleanup:
        if (serv != -1)
-               close(serv);
+               test->ops->close(serv);
 }
 
 static void test_getpeername(struct sock_addr_test *test)
@@ -644,7 +698,8 @@ static void test_getpeername(struct sock_addr_test *test)
        if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                goto cleanup;
 
-       client = connect_to_addr(test->socket_type, &addr, addr_len, NULL);
+       client = test->ops->connect_to_addr(test->socket_type, &addr, addr_len,
+                                           NULL);
        if (!ASSERT_GE(client, 0, "connect_to_addr"))
                goto cleanup;
 
@@ -653,13 +708,14 @@ static void test_getpeername(struct sock_addr_test *test)
        if (!ASSERT_EQ(err, 0, "make_sockaddr"))
                goto cleanup;
 
-       err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
+       err = cmp_sock_addr(test->ops->getpeername, client, &expected_addr,
+                           expected_addr_len, true);
        if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
                goto cleanup;
 
 cleanup:
        if (client != -1)
-               close(client);
+               test->ops->close(client);
        if (serv != -1)
                close(serv);
 }