#include "bpf_helpers.h"
 #include "bpf_endian.h"
 
+struct socket_cookie {
+       __u64 cookie_key;
+       __u32 cookie_value;
+};
+
 struct bpf_map_def SEC("maps") socket_cookies = {
-       .type = BPF_MAP_TYPE_HASH,
-       .key_size = sizeof(__u64),
-       .value_size = sizeof(__u32),
-       .max_entries = 1 << 8,
+       .type = BPF_MAP_TYPE_SK_STORAGE,
+       .key_size = sizeof(int),
+       .value_size = sizeof(struct socket_cookie),
+       .map_flags = BPF_F_NO_PREALLOC,
 };
 
+BPF_ANNOTATE_KV_PAIR(socket_cookies, int, struct socket_cookie);
+
 SEC("cgroup/connect6")
 int set_cookie(struct bpf_sock_addr *ctx)
 {
-       __u32 cookie_value = 0xFF;
-       __u64 cookie_key;
+       struct socket_cookie *p;
 
        if (ctx->family != AF_INET6 || ctx->user_family != AF_INET6)
                return 1;
 
-       cookie_key = bpf_get_socket_cookie(ctx);
-       if (bpf_map_update_elem(&socket_cookies, &cookie_key, &cookie_value, 0))
-               return 0;
+       p = bpf_sk_storage_get(&socket_cookies, ctx->sk, 0,
+                              BPF_SK_STORAGE_GET_F_CREATE);
+       if (!p)
+               return 1;
+
+       p->cookie_value = 0xFF;
+       p->cookie_key = bpf_get_socket_cookie(ctx);
 
        return 1;
 }
 SEC("sockops")
 int update_cookie(struct bpf_sock_ops *ctx)
 {
-       __u32 new_cookie_value;
-       __u32 *cookie_value;
-       __u64 cookie_key;
+       struct bpf_sock *sk;
+       struct socket_cookie *p;
 
        if (ctx->family != AF_INET6)
                return 1;
        if (ctx->op != BPF_SOCK_OPS_TCP_CONNECT_CB)
                return 1;
 
-       cookie_key = bpf_get_socket_cookie(ctx);
+       if (!ctx->sk)
+               return 1;
+
+       p = bpf_sk_storage_get(&socket_cookies, ctx->sk, 0, 0);
+       if (!p)
+               return 1;
 
-       cookie_value = bpf_map_lookup_elem(&socket_cookies, &cookie_key);
-       if (!cookie_value)
+       if (p->cookie_key != bpf_get_socket_cookie(ctx))
                return 1;
 
-       new_cookie_value = (ctx->local_port << 8) | *cookie_value;
-       bpf_map_update_elem(&socket_cookies, &cookie_key, &new_cookie_value, 0);
+       p->cookie_value = (ctx->local_port << 8) | p->cookie_value;
 
        return 1;
 }
 
 #define CG_PATH                        "/foo"
 #define SOCKET_COOKIE_PROG     "./socket_cookie_prog.o"
 
+struct socket_cookie {
+       __u64 cookie_key;
+       __u32 cookie_value;
+};
+
 static int start_server(void)
 {
        struct sockaddr_in6 addr;
        __u32 cookie_expected_value;
        struct sockaddr_in6 addr;
        socklen_t len = sizeof(addr);
-       __u32 cookie_value;
-       __u64 cookie_key;
+       struct socket_cookie val;
        int err = 0;
        int map_fd;
 
 
        map_fd = bpf_map__fd(map);
 
-       err = bpf_map_get_next_key(map_fd, NULL, &cookie_key);
-       if (err) {
-               log_err("Can't get cookie key from map");
-               goto out;
-       }
-
-       err = bpf_map_lookup_elem(map_fd, &cookie_key, &cookie_value);
-       if (err) {
-               log_err("Can't get cookie value from map");
-               goto out;
-       }
+       err = bpf_map_lookup_elem(map_fd, &client_fd, &val);
 
        err = getsockname(client_fd, (struct sockaddr *)&addr, &len);
        if (err) {
        }
 
        cookie_expected_value = (ntohs(addr.sin6_port) << 8) | 0xFF;
-       if (cookie_value != cookie_expected_value) {
-               log_err("Unexpected value in map: %x != %x", cookie_value,
+       if (val.cookie_value != cookie_expected_value) {
+               log_err("Unexpected value in map: %x != %x", val.cookie_value,
                        cookie_expected_value);
                goto err;
        }