aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tools/testing/selftests/net/tcp_ao
diff options
context:
space:
mode:
Diffstat (limited to 'tools/testing/selftests/net/tcp_ao')
-rw-r--r--tools/testing/selftests/net/tcp_ao/.gitignore2
-rw-r--r--tools/testing/selftests/net/tcp_ao/Makefile56
-rw-r--r--tools/testing/selftests/net/tcp_ao/bench-lookups.c360
-rw-r--r--tools/testing/selftests/net/tcp_ao/config10
-rw-r--r--tools/testing/selftests/net/tcp_ao/connect-deny.c264
-rw-r--r--tools/testing/selftests/net/tcp_ao/connect.c90
l---------tools/testing/selftests/net/tcp_ao/icmps-accept.c1
-rw-r--r--tools/testing/selftests/net/tcp_ao/icmps-discard.c449
-rw-r--r--tools/testing/selftests/net/tcp_ao/key-management.c1186
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/aolib.h605
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/kconfig.c148
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/netlink.c413
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/proc.c273
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/repair.c254
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/setup.c361
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/sock.c596
-rw-r--r--tools/testing/selftests/net/tcp_ao/lib/utils.c30
-rw-r--r--tools/testing/selftests/net/tcp_ao/restore.c236
-rw-r--r--tools/testing/selftests/net/tcp_ao/rst.c457
-rw-r--r--tools/testing/selftests/net/tcp_ao/self-connect.c197
-rw-r--r--tools/testing/selftests/net/tcp_ao/seq-ext.c245
-rw-r--r--tools/testing/selftests/net/tcp_ao/setsockopt-closed.c835
-rw-r--r--tools/testing/selftests/net/tcp_ao/settings1
-rw-r--r--tools/testing/selftests/net/tcp_ao/unsigned-md5.c741
24 files changed, 7810 insertions, 0 deletions
diff --git a/tools/testing/selftests/net/tcp_ao/.gitignore b/tools/testing/selftests/net/tcp_ao/.gitignore
new file mode 100644
index 000000000000..e8bb81b715b7
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/.gitignore
@@ -0,0 +1,2 @@
+*_ipv4
+*_ipv6
diff --git a/tools/testing/selftests/net/tcp_ao/Makefile b/tools/testing/selftests/net/tcp_ao/Makefile
new file mode 100644
index 000000000000..522d991e310e
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/Makefile
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: GPL-2.0
+TEST_BOTH_AF := bench-lookups
+TEST_BOTH_AF += connect
+TEST_BOTH_AF += connect-deny
+TEST_BOTH_AF += icmps-accept icmps-discard
+TEST_BOTH_AF += key-management
+TEST_BOTH_AF += restore
+TEST_BOTH_AF += rst
+TEST_BOTH_AF += self-connect
+TEST_BOTH_AF += seq-ext
+TEST_BOTH_AF += setsockopt-closed
+TEST_BOTH_AF += unsigned-md5
+
+TEST_IPV4_PROGS := $(TEST_BOTH_AF:%=%_ipv4)
+TEST_IPV6_PROGS := $(TEST_BOTH_AF:%=%_ipv6)
+
+TEST_GEN_PROGS := $(TEST_IPV4_PROGS) $(TEST_IPV6_PROGS)
+
+top_srcdir := ../../../../..
+include ../../lib.mk
+
+HOSTAR ?= ar
+
+LIBDIR := $(OUTPUT)/lib
+LIB := $(LIBDIR)/libaotst.a
+LDLIBS += $(LIB) -pthread
+LIBDEPS := lib/aolib.h Makefile
+
+CFLAGS := -Wall -O2 -g -D_GNU_SOURCE -fno-strict-aliasing
+CFLAGS += $(KHDR_INCLUDES)
+CFLAGS += -iquote ./lib/ -I ../../../../include/
+
+# Library
+LIBSRC := kconfig.c netlink.c proc.c repair.c setup.c sock.c utils.c
+LIBOBJ := $(LIBSRC:%.c=$(LIBDIR)/%.o)
+EXTRA_CLEAN += $(LIBOBJ) $(LIB)
+
+$(LIB): $(LIBOBJ)
+ $(HOSTAR) rcs $@ $^
+
+$(LIBDIR)/%.o: ./lib/%.c $(LIBDEPS)
+ mkdir -p $(LIBDIR)
+ $(CC) $< $(CFLAGS) $(CPPFLAGS) -o $@ -c
+
+$(TEST_GEN_PROGS): $(LIB)
+
+$(OUTPUT)/%_ipv4: %.c
+ $(LINK.c) $^ $(LDLIBS) -o $@
+
+$(OUTPUT)/%_ipv6: %.c
+ $(LINK.c) -DIPV6_TEST $^ $(LDLIBS) -o $@
+
+$(OUTPUT)/icmps-accept_ipv4: CFLAGS+= -DTEST_ICMPS_ACCEPT
+$(OUTPUT)/icmps-accept_ipv6: CFLAGS+= -DTEST_ICMPS_ACCEPT
+$(OUTPUT)/bench-lookups_ipv4: LDLIBS+= -lm
+$(OUTPUT)/bench-lookups_ipv6: LDLIBS+= -lm
diff --git a/tools/testing/selftests/net/tcp_ao/bench-lookups.c b/tools/testing/selftests/net/tcp_ao/bench-lookups.c
new file mode 100644
index 000000000000..a1e6e007c291
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/bench-lookups.c
@@ -0,0 +1,360 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <arpa/inet.h>
+#include <inttypes.h>
+#include <math.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <time.h>
+
+#include "../../../../include/linux/bits.h"
+#include "../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+#define BENCH_NR_ITERS 100 /* number of times to run gathering statistics */
+
+static void gen_test_ips(union tcp_addr *ips, size_t ips_nr, bool use_rand)
+{
+ union tcp_addr net = {};
+ size_t i, j;
+
+ if (inet_pton(TEST_FAMILY, TEST_NETWORK, &net) != 1)
+ test_error("Can't convert ip address %s", TEST_NETWORK);
+
+ if (!use_rand) {
+ for (i = 0; i < ips_nr; i++)
+ ips[i] = gen_tcp_addr(net, 2 * i + 1);
+ return;
+ }
+ for (i = 0; i < ips_nr; i++) {
+ size_t r = (size_t)random() | 0x1;
+
+ ips[i] = gen_tcp_addr(net, r);
+
+ for (j = i - 1; j > 0 && i > 0; j--) {
+ if (!memcmp(&ips[i], &ips[j], sizeof(union tcp_addr))) {
+ i--; /* collision */
+ break;
+ }
+ }
+ }
+}
+
+static void test_add_routes(union tcp_addr *ips, size_t ips_nr)
+{
+ size_t i;
+
+ for (i = 0; i < ips_nr; i++) {
+ union tcp_addr *p = (union tcp_addr *)&ips[i];
+ int err;
+
+ err = ip_route_add(veth_name, TEST_FAMILY, this_ip_addr, *p);
+ if (err && err != -EEXIST)
+ test_error("Failed to add route");
+ }
+}
+
+static void server_apply_keys(int lsk, union tcp_addr *ips, size_t ips_nr)
+{
+ size_t i;
+
+ for (i = 0; i < ips_nr; i++) {
+ union tcp_addr *p = (union tcp_addr *)&ips[i];
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, *p, -1, 100, 100))
+ test_error("setsockopt(TCP_AO)");
+ }
+}
+
+static const size_t nr_keys[] = { 512, 1024, 2048, 4096, 8192 };
+static union tcp_addr *test_ips;
+
+struct bench_stats {
+ uint64_t min;
+ uint64_t max;
+ uint64_t nr;
+ double mean;
+ double s2;
+};
+
+static struct bench_tests {
+ struct bench_stats delete_last_key;
+ struct bench_stats add_key;
+ struct bench_stats delete_rand_key;
+ struct bench_stats connect_last_key;
+ struct bench_stats connect_rand_key;
+ struct bench_stats delete_async;
+} bench_results[ARRAY_SIZE(nr_keys)];
+
+#define NSEC_PER_SEC 1000000000ULL
+
+static void measure_call(struct bench_stats *st,
+ void (*f)(int, void *), int sk, void *arg)
+{
+ struct timespec start = {}, end = {};
+ double delta;
+ uint64_t nsec;
+
+ if (clock_gettime(CLOCK_MONOTONIC, &start))
+ test_error("clock_gettime()");
+
+ f(sk, arg);
+
+ if (clock_gettime(CLOCK_MONOTONIC, &end))
+ test_error("clock_gettime()");
+
+ nsec = (end.tv_sec - start.tv_sec) * NSEC_PER_SEC;
+ if (end.tv_nsec >= start.tv_nsec)
+ nsec += end.tv_nsec - start.tv_nsec;
+ else
+ nsec -= start.tv_nsec - end.tv_nsec;
+
+ if (st->nr == 0) {
+ st->min = st->max = nsec;
+ } else {
+ if (st->min > nsec)
+ st->min = nsec;
+ if (st->max < nsec)
+ st->max = nsec;
+ }
+
+ /* Welford-Knuth algorithm */
+ st->nr++;
+ delta = (double)nsec - st->mean;
+ st->mean += delta / st->nr;
+ st->s2 += delta * ((double)nsec - st->mean);
+}
+
+static void delete_mkt(int sk, void *arg)
+{
+ struct tcp_ao_del *ao = arg;
+
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, ao, sizeof(*ao)))
+ test_error("setsockopt(TCP_AO_DEL_KEY)");
+}
+
+static void add_back_mkt(int sk, void *arg)
+{
+ union tcp_addr *p = arg;
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, *p, -1, 100, 100))
+ test_error("setsockopt(TCP_AO)");
+}
+
+static void bench_delete(int lsk, struct bench_stats *add,
+ struct bench_stats *del,
+ union tcp_addr *ips, size_t ips_nr,
+ bool rand_order, bool async)
+{
+ struct tcp_ao_del ao_del = {};
+ union tcp_addr *p;
+ size_t i;
+
+ ao_del.sndid = 100;
+ ao_del.rcvid = 100;
+ ao_del.del_async = !!async;
+ ao_del.prefix = DEFAULT_TEST_PREFIX;
+
+ /* Remove the first added */
+ p = (union tcp_addr *)&ips[0];
+ tcp_addr_to_sockaddr_in(&ao_del.addr, p, 0);
+
+ for (i = 0; i < BENCH_NR_ITERS; i++) {
+ measure_call(del, delete_mkt, lsk, (void *)&ao_del);
+
+ /* Restore it back */
+ measure_call(add, add_back_mkt, lsk, (void *)p);
+
+ /*
+ * Slowest for FILO-linked-list:
+ * on (i) iteration removing ips[i] element. When it gets
+ * added to the list back - it becomes first to fetch, so
+ * on (i + 1) iteration go to ips[i + 1] element.
+ */
+ if (rand_order)
+ p = (union tcp_addr *)&ips[rand() % ips_nr];
+ else
+ p = (union tcp_addr *)&ips[i % ips_nr];
+ tcp_addr_to_sockaddr_in(&ao_del.addr, p, 0);
+ }
+}
+
+static void bench_connect_srv(int lsk, union tcp_addr *ips, size_t ips_nr)
+{
+ size_t i;
+
+ for (i = 0; i < BENCH_NR_ITERS; i++) {
+ int sk;
+
+ synchronize_threads();
+
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ close(sk);
+ }
+}
+
+static void test_print_stats(const char *desc, size_t nr, struct bench_stats *bs)
+{
+ test_ok("%-20s\t%zu keys: min=%" PRIu64 "ms max=%" PRIu64 "ms mean=%gms stddev=%g",
+ desc, nr, bs->min / 1000000, bs->max / 1000000,
+ bs->mean / 1000000, sqrt((bs->mean / 1000000) / bs->nr));
+}
+
+static void *server_fn(void *arg)
+{
+ size_t i;
+
+ for (i = 0; i < ARRAY_SIZE(nr_keys); i++) {
+ struct bench_tests *bt = &bench_results[i];
+ int lsk;
+
+ test_ips = malloc(nr_keys[i] * sizeof(union tcp_addr));
+ if (!test_ips)
+ test_error("malloc()");
+
+ lsk = test_listen_socket(this_ip_addr, test_server_port + i, 1);
+
+ gen_test_ips(test_ips, nr_keys[i], false);
+ test_add_routes(test_ips, nr_keys[i]);
+ test_set_optmem(KERNEL_TCP_AO_KEY_SZ_ROUND_UP * nr_keys[i]);
+ server_apply_keys(lsk, test_ips, nr_keys[i]);
+
+ synchronize_threads();
+ bench_connect_srv(lsk, test_ips, nr_keys[i]);
+ bench_connect_srv(lsk, test_ips, nr_keys[i]);
+
+ /* The worst case for FILO-list */
+ bench_delete(lsk, &bt->add_key, &bt->delete_last_key,
+ test_ips, nr_keys[i], false, false);
+ test_print_stats("Add a new key",
+ nr_keys[i], &bt->add_key);
+ test_print_stats("Delete: worst case",
+ nr_keys[i], &bt->delete_last_key);
+
+ bench_delete(lsk, &bt->add_key, &bt->delete_rand_key,
+ test_ips, nr_keys[i], true, false);
+ test_print_stats("Delete: random-search",
+ nr_keys[i], &bt->delete_rand_key);
+
+ bench_delete(lsk, &bt->add_key, &bt->delete_async,
+ test_ips, nr_keys[i], false, true);
+ test_print_stats("Delete: async", nr_keys[i], &bt->delete_async);
+
+ free(test_ips);
+ close(lsk);
+ }
+
+ return NULL;
+}
+
+static void connect_client(int sk, void *arg)
+{
+ size_t *p = arg;
+
+ if (test_connect_socket(sk, this_ip_dest, test_server_port + *p) <= 0)
+ test_error("failed to connect()");
+}
+
+static void client_addr_setup(int sk, union tcp_addr taddr)
+{
+#ifdef IPV6_TEST
+ struct sockaddr_in6 addr = {
+ .sin6_family = AF_INET6,
+ .sin6_port = 0,
+ .sin6_addr = taddr.a6,
+ };
+#else
+ struct sockaddr_in addr = {
+ .sin_family = AF_INET,
+ .sin_port = 0,
+ .sin_addr = taddr.a4,
+ };
+#endif
+ int ret;
+
+ ret = ip_addr_add(veth_name, TEST_FAMILY, taddr, TEST_PREFIX);
+ if (ret && ret != -EEXIST)
+ test_error("Failed to add ip address");
+ ret = ip_route_add(veth_name, TEST_FAMILY, taddr, this_ip_dest);
+ if (ret && ret != -EEXIST)
+ test_error("Failed to add route");
+
+ if (bind(sk, &addr, sizeof(addr)))
+ test_error("bind()");
+}
+
+static void bench_connect_client(size_t port_off, struct bench_tests *bt,
+ union tcp_addr *ips, size_t ips_nr, bool rand_order)
+{
+ struct bench_stats *con;
+ union tcp_addr *p;
+ size_t i;
+
+ if (rand_order)
+ con = &bt->connect_rand_key;
+ else
+ con = &bt->connect_last_key;
+
+ p = (union tcp_addr *)&ips[0];
+
+ for (i = 0; i < BENCH_NR_ITERS; i++) {
+ int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sk < 0)
+ test_error("socket()");
+
+ client_addr_setup(sk, *p);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
+ -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads();
+
+ measure_call(con, connect_client, sk, (void *)&port_off);
+
+ close(sk);
+
+ /*
+ * Slowest for FILO-linked-list:
+ * on (i) iteration removing ips[i] element. When it gets
+ * added to the list back - it becomes first to fetch, so
+ * on (i + 1) iteration go to ips[i + 1] element.
+ */
+ if (rand_order)
+ p = (union tcp_addr *)&ips[rand() % ips_nr];
+ else
+ p = (union tcp_addr *)&ips[i % ips_nr];
+ }
+}
+
+static void *client_fn(void *arg)
+{
+ size_t i;
+
+ for (i = 0; i < ARRAY_SIZE(nr_keys); i++) {
+ struct bench_tests *bt = &bench_results[i];
+
+ synchronize_threads();
+ bench_connect_client(i, bt, test_ips, nr_keys[i], false);
+ test_print_stats("Connect: worst case",
+ nr_keys[i], &bt->connect_last_key);
+
+ bench_connect_client(i, bt, test_ips, nr_keys[i], false);
+ test_print_stats("Connect: random-search",
+ nr_keys[i], &bt->connect_last_key);
+ }
+ synchronize_threads();
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(30, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/config b/tools/testing/selftests/net/tcp_ao/config
new file mode 100644
index 000000000000..d3277a9de987
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/config
@@ -0,0 +1,10 @@
+CONFIG_CRYPTO_HMAC=y
+CONFIG_CRYPTO_RMD160=y
+CONFIG_CRYPTO_SHA1=y
+CONFIG_IPV6_MULTIPLE_TABLES=y
+CONFIG_IPV6=y
+CONFIG_NET_L3_MASTER_DEV=y
+CONFIG_NET_VRF=y
+CONFIG_TCP_AO=y
+CONFIG_TCP_MD5SIG=y
+CONFIG_VETH=m
diff --git a/tools/testing/selftests/net/tcp_ao/connect-deny.c b/tools/testing/selftests/net/tcp_ao/connect-deny.c
new file mode 100644
index 000000000000..185a2f6e5ff3
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/connect-deny.c
@@ -0,0 +1,264 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "aolib.h"
+
+#define fault(type) (inj == FAULT_ ## type)
+
+static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen,
+ union tcp_addr in_addr, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_add tmp = {};
+ int err;
+
+ if (prefix > DEFAULT_TEST_PREFIX)
+ prefix = DEFAULT_TEST_PREFIX;
+
+ err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false,
+ prefix, 0, sndid, rcvid, maclen,
+ 0, strlen(key), key);
+ if (err)
+ return err;
+
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
+ if (err < 0)
+ return -errno;
+
+ return test_verify_socket_key(sk, &tmp);
+}
+
+static void try_accept(const char *tst_name, unsigned int port, const char *pwd,
+ union tcp_addr addr, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid, uint8_t maclen,
+ const char *cnt_name, test_cnt cnt_expected,
+ fault_t inj)
+{
+ struct tcp_ao_counters ao_cnt1, ao_cnt2;
+ uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
+ int lsk, err, sk = 0;
+ time_t timeout;
+
+ lsk = test_listen_socket(this_ip_addr, port, 1);
+
+ if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ if (cnt_name)
+ before_cnt = netstat_get_one(cnt_name, NULL);
+ if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt1))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* preparations done */
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ err = test_wait_fd(lsk, timeout, 0);
+ if (err == -ETIMEDOUT) {
+ if (!fault(TIMEOUT))
+ test_fail("timed out for accept()");
+ } else if (err < 0) {
+ test_error("test_wait_fd()");
+ } else {
+ if (fault(TIMEOUT))
+ test_fail("ready to accept");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0) {
+ test_error("accept()");
+ } else {
+ if (fault(TIMEOUT))
+ test_fail("%s: accepted", tst_name);
+ }
+ }
+
+ if (pwd && test_get_tcp_ao_counters(lsk, &ao_cnt2))
+ test_error("test_get_tcp_ao_counters()");
+
+ close(lsk);
+ if (pwd)
+ test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
+
+ if (!cnt_name)
+ goto out;
+
+ after_cnt = netstat_get_one(cnt_name, NULL);
+
+ if (after_cnt <= before_cnt) {
+ test_fail("%s: %s counter did not increase: %zu <= %zu",
+ tst_name, cnt_name, after_cnt, before_cnt);
+ } else {
+ test_ok("%s: counter %s increased %zu => %zu",
+ tst_name, cnt_name, before_cnt, after_cnt);
+ }
+
+out:
+ synchronize_threads(); /* close() */
+ if (sk > 0)
+ close(sk);
+}
+
+static void *server_fn(void *arg)
+{
+ union tcp_addr wrong_addr, network_addr;
+ unsigned int port = test_server_port;
+
+ if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
+ test_error("Can't convert ip address %s", TEST_WRONG_IP);
+
+ try_accept("Non-AO server + AO client", port++, NULL,
+ this_ip_dest, -1, 100, 100, 0,
+ "TCPAOKeyNotFound", 0, FAULT_TIMEOUT);
+
+ try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0,
+ "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT);
+
+ try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD",
+ this_ip_dest, -1, 100, 100, 0,
+ "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
+
+ try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 101, 0,
+ "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
+
+ try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 101, 100, 0,
+ "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT);
+
+ try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 8,
+ "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT);
+
+ try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
+ wrong_addr, -1, 100, 100, 0,
+ "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT);
+
+ try_accept("Client: Wrong addr", port++, NULL,
+ this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_TIMEOUT);
+
+ try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 200, 100, 0,
+ "TCPAOGood", TEST_CNT_GOOD, 0);
+
+ if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
+ test_error("Can't convert ip address %s", TEST_NETWORK);
+
+ try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
+ network_addr, 16, 100, 100, 0,
+ "TCPAOGood", TEST_CNT_GOOD, 0);
+
+ try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0,
+ "TCPAOGood", TEST_CNT_GOOD, 0);
+
+ /* client exits */
+ synchronize_threads();
+ return NULL;
+}
+
+static void try_connect(const char *tst_name, unsigned int port,
+ const char *pwd, union tcp_addr addr, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid,
+ test_cnt cnt_expected, fault_t inj)
+{
+ struct tcp_ao_counters ao_cnt1, ao_cnt2;
+ time_t timeout;
+ int sk, ret;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ if (pwd && test_get_tcp_ao_counters(sk, &ao_cnt1))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* preparations done */
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ ret = _test_connect_socket(sk, this_ip_dest, port, timeout);
+
+ if (ret < 0) {
+ if (fault(KEYREJECT) && ret == -EKEYREJECTED) {
+ test_ok("%s: connect() was prevented", tst_name);
+ } else if (ret == -ETIMEDOUT && fault(TIMEOUT)) {
+ test_ok("%s", tst_name);
+ } else if (ret == -ECONNREFUSED &&
+ (fault(TIMEOUT) || fault(KEYREJECT))) {
+ test_ok("%s: refused to connect", tst_name);
+ } else {
+ test_error("%s: connect() returned %d", tst_name, ret);
+ }
+ goto out;
+ }
+
+ if (fault(TIMEOUT) || fault(KEYREJECT))
+ test_fail("%s: connected", tst_name);
+ else
+ test_ok("%s: connected", tst_name);
+ if (pwd && ret > 0) {
+ if (test_get_tcp_ao_counters(sk, &ao_cnt2))
+ test_error("test_get_tcp_ao_counters()");
+ test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
+ }
+out:
+ synchronize_threads(); /* close() */
+
+ if (ret > 0)
+ close(sk);
+}
+
+static void *client_fn(void *arg)
+{
+ union tcp_addr wrong_addr, network_addr;
+ unsigned int port = test_server_port;
+
+ if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
+ test_error("Can't convert ip address %s", TEST_WRONG_IP);
+
+ try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("AO server + Non-AO client", port++, NULL,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT);
+
+ try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD,
+ wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT);
+
+ try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0);
+
+ if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1)
+ test_error("Can't convert ip address %s", TEST_NETWORK);
+
+ try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0);
+
+ try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD,
+ network_addr, 16, 100, 100, TEST_CNT_GOOD, 0);
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(21, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/connect.c b/tools/testing/selftests/net/tcp_ao/connect.c
new file mode 100644
index 000000000000..81653b47f303
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/connect.c
@@ -0,0 +1,90 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "aolib.h"
+
+static void *server_fn(void *arg)
+{
+ int sk, lsk;
+ ssize_t bytes;
+
+ lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ synchronize_threads();
+
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ synchronize_threads();
+
+ bytes = test_server_run(sk, 0, 0);
+
+ test_fail("server served: %zd", bytes);
+ return NULL;
+}
+
+static void *client_fn(void *arg)
+{
+ int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ uint64_t before_aogood, after_aogood;
+ const size_t nr_packets = 20;
+ struct netstat *ns_before, *ns_after;
+ struct tcp_ao_counters ao1, ao2;
+
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads();
+ if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
+ test_error("failed to connect()");
+ synchronize_threads();
+
+ ns_before = netstat_read();
+ before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
+ if (test_get_tcp_ao_counters(sk, &ao1))
+ test_error("test_get_tcp_ao_counters()");
+
+ if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
+ test_fail("verify failed");
+ return NULL;
+ }
+
+ ns_after = netstat_read();
+ after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ netstat_print_diff(ns_before, ns_after);
+ netstat_free(ns_before);
+ netstat_free(ns_after);
+
+ if (nr_packets > (after_aogood - before_aogood)) {
+ test_fail("TCPAOGood counter mismatch: %zu > (%zu - %zu)",
+ nr_packets, after_aogood, before_aogood);
+ return NULL;
+ }
+ if (test_tcp_ao_counters_cmp("connect", &ao1, &ao2, TEST_CNT_GOOD))
+ return NULL;
+
+ test_ok("connect TCPAOGood %" PRIu64 "/%" PRIu64 "/%" PRIu64 " => %" PRIu64 "/%" PRIu64 "/%" PRIu64 ", sent %" PRIu64,
+ before_aogood, ao1.ao_info_pkt_good,
+ ao1.key_cnts[0].pkt_good,
+ after_aogood, ao2.ao_info_pkt_good,
+ ao2.key_cnts[0].pkt_good,
+ nr_packets);
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(1, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/icmps-accept.c b/tools/testing/selftests/net/tcp_ao/icmps-accept.c
new file mode 120000
index 000000000000..0a5bb85eb260
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/icmps-accept.c
@@ -0,0 +1 @@
+icmps-discard.c \ No newline at end of file
diff --git a/tools/testing/selftests/net/tcp_ao/icmps-discard.c b/tools/testing/selftests/net/tcp_ao/icmps-discard.c
new file mode 100644
index 000000000000..d69bcba3c929
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/icmps-discard.c
@@ -0,0 +1,449 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Selftest that verifies that incomping ICMPs are ignored,
+ * the TCP connection stays alive, no hard or soft errors get reported
+ * to the usespace and the counter for ignored ICMPs is updated.
+ *
+ * RFC5925, 7.8:
+ * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4
+ * messages of Type 3 (destination unreachable), Codes 2-4 (protocol
+ * unreachable, port unreachable, and fragmentation needed -- ’hard
+ * errors’), and ICMPv6 Type 1 (destination unreachable), Code 1
+ * (administratively prohibited) and Code 4 (port unreachable) intended
+ * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN-
+ * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs.
+ *
+ * Author: Dmitry Safonov <dima@arista.com>
+ */
+#include <inttypes.h>
+#include <linux/icmp.h>
+#include <linux/icmpv6.h>
+#include <linux/ipv6.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <sys/socket.h>
+#include "aolib.h"
+#include "../../../../include/linux/compiler.h"
+
+const size_t packets_nr = 20;
+const size_t packet_size = 100;
+const char *tcpao_icmps = "TCPAODroppedIcmps";
+
+#ifdef IPV6_TEST
+const char *dst_unreach = "Icmp6InDestUnreachs";
+const int sk_ip_level = SOL_IPV6;
+const int sk_recverr = IPV6_RECVERR;
+#else
+const char *dst_unreach = "InDestUnreachs";
+const int sk_ip_level = SOL_IP;
+const int sk_recverr = IP_RECVERR;
+#endif
+
+/* Server is expected to fail with hard error if ::accept_icmp is set */
+#ifdef TEST_ICMPS_ACCEPT
+# define test_icmps_fail test_ok
+# define test_icmps_ok test_fail
+#else
+# define test_icmps_fail test_fail
+# define test_icmps_ok test_ok
+#endif
+
+static void serve_interfered(int sk)
+{
+ ssize_t test_quota = packet_size * packets_nr * 10;
+ uint64_t dest_unreach_a, dest_unreach_b;
+ uint64_t icmp_ignored_a, icmp_ignored_b;
+ struct tcp_ao_counters ao_cnt1, ao_cnt2;
+ bool counter_not_found;
+ struct netstat *ns_after, *ns_before;
+ ssize_t bytes;
+
+ ns_before = netstat_read();
+ dest_unreach_a = netstat_get(ns_before, dst_unreach, NULL);
+ icmp_ignored_a = netstat_get(ns_before, tcpao_icmps, NULL);
+ if (test_get_tcp_ao_counters(sk, &ao_cnt1))
+ test_error("test_get_tcp_ao_counters()");
+ bytes = test_server_run(sk, test_quota, 0);
+ ns_after = netstat_read();
+ netstat_print_diff(ns_before, ns_after);
+ dest_unreach_b = netstat_get(ns_after, dst_unreach, NULL);
+ icmp_ignored_b = netstat_get(ns_after, tcpao_icmps,
+ &counter_not_found);
+ if (test_get_tcp_ao_counters(sk, &ao_cnt2))
+ test_error("test_get_tcp_ao_counters()");
+
+ netstat_free(ns_before);
+ netstat_free(ns_after);
+
+ if (dest_unreach_a >= dest_unreach_b) {
+ test_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
+ dst_unreach, dest_unreach_a, dest_unreach_b);
+ return;
+ }
+ test_ok("%s delivered %" PRIu64,
+ dst_unreach, dest_unreach_b - dest_unreach_a);
+ if (bytes < 0)
+ test_icmps_fail("Server failed with %zd: %s", bytes, strerrordesc_np(-bytes));
+ else
+ test_icmps_ok("Server survived %zd bytes of traffic", test_quota);
+ if (counter_not_found) {
+ test_fail("Not found %s counter", tcpao_icmps);
+ return;
+ }
+#ifdef TEST_ICMPS_ACCEPT
+ test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD);
+#else
+ test_tcp_ao_counters_cmp(NULL, &ao_cnt1, &ao_cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP);
+#endif
+ if (icmp_ignored_a >= icmp_ignored_b) {
+ test_icmps_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64,
+ tcpao_icmps, icmp_ignored_a, icmp_ignored_b);
+ return;
+ }
+ test_icmps_ok("ICMPs ignored %" PRIu64, icmp_ignored_b - icmp_ignored_a);
+}
+
+static void *server_fn(void *arg)
+{
+ int val, sk, lsk;
+ bool accept_icmps = false;
+
+ lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
+
+#ifdef TEST_ICMPS_ACCEPT
+ accept_icmps = true;
+#endif
+
+ if (test_set_ao_flags(lsk, false, accept_icmps))
+ test_error("setsockopt(TCP_AO_INFO)");
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ synchronize_threads();
+
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ /* Fail on hard ip errors, such as dest unreachable (RFC1122) */
+ val = 1;
+ if (setsockopt(sk, sk_ip_level, sk_recverr, &val, sizeof(val)))
+ test_error("setsockopt()");
+
+ synchronize_threads();
+
+ serve_interfered(sk);
+ return NULL;
+}
+
+static size_t packets_sent;
+static size_t icmps_sent;
+
+static uint32_t checksum4_nofold(void *data, size_t len, uint32_t sum)
+{
+ uint16_t *words = data;
+ size_t i;
+
+ for (i = 0; i < len / sizeof(uint16_t); i++)
+ sum += words[i];
+ if (len & 1)
+ sum += ((char *)data)[len - 1];
+ return sum;
+}
+
+static uint16_t checksum4_fold(void *data, size_t len, uint32_t sum)
+{
+ sum = checksum4_nofold(data, len, sum);
+ while (sum > 0xFFFF)
+ sum = (sum & 0xFFFF) + (sum >> 16);
+ return ~sum;
+}
+
+static void set_ip4hdr(struct iphdr *iph, size_t packet_len, int proto,
+ struct sockaddr_in *src, struct sockaddr_in *dst)
+{
+ iph->version = 4;
+ iph->ihl = 5;
+ iph->tos = 0;
+ iph->tot_len = htons(packet_len);
+ iph->ttl = 2;
+ iph->protocol = proto;
+ iph->saddr = src->sin_addr.s_addr;
+ iph->daddr = dst->sin_addr.s_addr;
+ iph->check = checksum4_fold((void *)iph, iph->ihl << 1, 0);
+}
+
+static void icmp_interfere4(uint8_t type, uint8_t code, uint32_t rcv_nxt,
+ struct sockaddr_in *src, struct sockaddr_in *dst)
+{
+ int sk = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
+ struct {
+ struct iphdr iph;
+ struct icmphdr icmph;
+ struct iphdr iphe;
+ struct {
+ uint16_t sport;
+ uint16_t dport;
+ uint32_t seq;
+ } tcph;
+ } packet = {};
+ size_t packet_len;
+ ssize_t bytes;
+
+ if (sk < 0)
+ test_error("socket(AF_INET, SOCK_RAW, IPPROTO_RAW)");
+
+ packet_len = sizeof(packet);
+ set_ip4hdr(&packet.iph, packet_len, IPPROTO_ICMP, src, dst);
+
+ packet.icmph.type = type;
+ packet.icmph.code = code;
+ if (code == ICMP_FRAG_NEEDED) {
+ randomize_buffer(&packet.icmph.un.frag.mtu,
+ sizeof(packet.icmph.un.frag.mtu));
+ }
+
+ packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
+ set_ip4hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
+
+ packet.tcph.sport = dst->sin_port;
+ packet.tcph.dport = src->sin_port;
+ packet.tcph.seq = htonl(rcv_nxt);
+
+ packet_len = sizeof(packet) - sizeof(packet.iph);
+ packet.icmph.checksum = checksum4_fold((void *)&packet.icmph,
+ packet_len, 0);
+
+ bytes = sendto(sk, &packet, sizeof(packet), 0,
+ (struct sockaddr *)dst, sizeof(*dst));
+ if (bytes != sizeof(packet))
+ test_error("send(): %zd", bytes);
+ icmps_sent++;
+
+ close(sk);
+}
+
+static void set_ip6hdr(struct ipv6hdr *iph, size_t packet_len, int proto,
+ struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
+{
+ iph->version = 6;
+ iph->payload_len = htons(packet_len);
+ iph->nexthdr = proto;
+ iph->hop_limit = 2;
+ iph->saddr = src->sin6_addr;
+ iph->daddr = dst->sin6_addr;
+}
+
+static inline uint16_t csum_fold(uint32_t csum)
+{
+ uint32_t sum = csum;
+
+ sum = (sum & 0xffff) + (sum >> 16);
+ sum = (sum & 0xffff) + (sum >> 16);
+ return (uint16_t)~sum;
+}
+
+static inline uint32_t csum_add(uint32_t csum, uint32_t addend)
+{
+ uint32_t res = csum;
+
+ res += addend;
+ return res + (res < addend);
+}
+
+noinline uint32_t checksum6_nofold(void *data, size_t len, uint32_t sum)
+{
+ uint16_t *words = data;
+ size_t i;
+
+ for (i = 0; i < len / sizeof(uint16_t); i++)
+ sum = csum_add(sum, words[i]);
+ if (len & 1)
+ sum = csum_add(sum, ((char *)data)[len - 1]);
+ return sum;
+}
+
+noinline uint16_t icmp6_checksum(struct sockaddr_in6 *src,
+ struct sockaddr_in6 *dst,
+ void *ptr, size_t len, uint8_t proto)
+{
+ struct {
+ struct in6_addr saddr;
+ struct in6_addr daddr;
+ uint32_t payload_len;
+ uint8_t zero[3];
+ uint8_t nexthdr;
+ } pseudo_header = {};
+ uint32_t sum;
+
+ pseudo_header.saddr = src->sin6_addr;
+ pseudo_header.daddr = dst->sin6_addr;
+ pseudo_header.payload_len = htonl(len);
+ pseudo_header.nexthdr = proto;
+
+ sum = checksum6_nofold(&pseudo_header, sizeof(pseudo_header), 0);
+ sum = checksum6_nofold(ptr, len, sum);
+
+ return csum_fold(sum);
+}
+
+static void icmp6_interfere(int type, int code, uint32_t rcv_nxt,
+ struct sockaddr_in6 *src, struct sockaddr_in6 *dst)
+{
+ int sk = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW);
+ struct sockaddr_in6 dst_raw = *dst;
+ struct {
+ struct ipv6hdr iph;
+ struct icmp6hdr icmph;
+ struct ipv6hdr iphe;
+ struct {
+ uint16_t sport;
+ uint16_t dport;
+ uint32_t seq;
+ } tcph;
+ } packet = {};
+ size_t packet_len;
+ ssize_t bytes;
+
+
+ if (sk < 0)
+ test_error("socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)");
+
+ packet_len = sizeof(packet) - sizeof(packet.iph);
+ set_ip6hdr(&packet.iph, packet_len, IPPROTO_ICMPV6, src, dst);
+
+ packet.icmph.icmp6_type = type;
+ packet.icmph.icmp6_code = code;
+
+ packet_len = sizeof(packet.iphe) + sizeof(packet.tcph);
+ set_ip6hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src);
+
+ packet.tcph.sport = dst->sin6_port;
+ packet.tcph.dport = src->sin6_port;
+ packet.tcph.seq = htonl(rcv_nxt);
+
+ packet_len = sizeof(packet) - sizeof(packet.iph);
+
+ packet.icmph.icmp6_cksum = icmp6_checksum(src, dst,
+ (void *)&packet.icmph, packet_len, IPPROTO_ICMPV6);
+
+ dst_raw.sin6_port = htons(IPPROTO_RAW);
+ bytes = sendto(sk, &packet, sizeof(packet), 0,
+ (struct sockaddr *)&dst_raw, sizeof(dst_raw));
+ if (bytes != sizeof(packet))
+ test_error("send(): %zd", bytes);
+ icmps_sent++;
+
+ close(sk);
+}
+
+static uint32_t get_rcv_nxt(int sk)
+{
+ int val = TCP_REPAIR_ON;
+ uint32_t ret;
+ socklen_t sz = sizeof(ret);
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+ test_error("setsockopt(TCP_REPAIR)");
+ val = TCP_RECV_QUEUE;
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &val, sizeof(val)))
+ test_error("setsockopt(TCP_REPAIR_QUEUE)");
+ if (getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &ret, &sz))
+ test_error("getsockopt(TCP_QUEUE_SEQ)");
+ val = TCP_REPAIR_OFF_NO_WP;
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+ test_error("setsockopt(TCP_REPAIR)");
+ return ret;
+}
+
+static void icmp_interfere(const size_t nr, uint32_t rcv_nxt, void *src, void *dst)
+{
+ struct sockaddr_in *saddr4 = src;
+ struct sockaddr_in *daddr4 = dst;
+ struct sockaddr_in6 *saddr6 = src;
+ struct sockaddr_in6 *daddr6 = dst;
+ size_t i;
+
+ if (saddr4->sin_family != daddr4->sin_family)
+ test_error("Different address families");
+
+ for (i = 0; i < nr; i++) {
+ if (saddr4->sin_family == AF_INET) {
+ icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PROT_UNREACH,
+ rcv_nxt, saddr4, daddr4);
+ icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PORT_UNREACH,
+ rcv_nxt, saddr4, daddr4);
+ icmp_interfere4(ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED,
+ rcv_nxt, saddr4, daddr4);
+ icmps_sent += 3;
+ } else if (saddr4->sin_family == AF_INET6) {
+ icmp6_interfere(ICMPV6_DEST_UNREACH,
+ ICMPV6_ADM_PROHIBITED,
+ rcv_nxt, saddr6, daddr6);
+ icmp6_interfere(ICMPV6_DEST_UNREACH,
+ ICMPV6_PORT_UNREACH,
+ rcv_nxt, saddr6, daddr6);
+ icmps_sent += 2;
+ } else {
+ test_error("Not ip address family");
+ }
+ }
+}
+
+static void send_interfered(int sk)
+{
+ const unsigned int timeout = TEST_TIMEOUT_SEC;
+ struct sockaddr_in6 src, dst;
+ socklen_t addr_sz;
+
+ addr_sz = sizeof(src);
+ if (getsockname(sk, &src, &addr_sz))
+ test_error("getsockname()");
+ addr_sz = sizeof(dst);
+ if (getpeername(sk, &dst, &addr_sz))
+ test_error("getpeername()");
+
+ while (1) {
+ uint32_t rcv_nxt;
+
+ if (test_client_verify(sk, packet_size, packets_nr, timeout)) {
+ test_fail("client: connection is broken");
+ return;
+ }
+ packets_sent += packets_nr;
+ rcv_nxt = get_rcv_nxt(sk);
+ icmp_interfere(packets_nr, rcv_nxt, (void *)&src, (void *)&dst);
+ }
+}
+
+static void *client_fn(void *arg)
+{
+ int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads();
+ if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
+ test_error("failed to connect()");
+ synchronize_threads();
+
+ send_interfered(sk);
+
+ /* Not expecting client to quit */
+ test_fail("client disconnected");
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(3, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/key-management.c b/tools/testing/selftests/net/tcp_ao/key-management.c
new file mode 100644
index 000000000000..24e62120b792
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/key-management.c
@@ -0,0 +1,1186 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+const size_t nr_packets = 20;
+const size_t msg_len = 100;
+const size_t quota = nr_packets * msg_len;
+union tcp_addr wrong_addr;
+#define SECOND_PASSWORD "at all times sincere friends of freedom have been rare"
+#define fault(type) (inj == FAULT_ ## type)
+
+static const int test_vrf_ifindex = 200;
+static const uint8_t test_vrf_tabid = 42;
+static void setup_vrfs(void)
+{
+ int err;
+
+ if (!kernel_config_has(KCONFIG_NET_VRF))
+ return;
+
+ err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1);
+ if (err)
+ test_error("Failed to add a VRF: %d", err);
+
+ err = link_set_up("ksft-vrf");
+ if (err)
+ test_error("Failed to bring up a VRF");
+
+ err = ip_route_add_vrf(veth_name, TEST_FAMILY,
+ this_ip_addr, this_ip_dest, test_vrf_tabid);
+ if (err)
+ test_error("Failed to add a route to VRF");
+}
+
+
+static int prepare_sk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
+{
+ int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
+ DEFAULT_TEST_PREFIX, 100, 100))
+ test_error("test_add_key()");
+
+ if (addr && test_add_key(sk, SECOND_PASSWORD, *addr,
+ DEFAULT_TEST_PREFIX, sndid, rcvid))
+ test_error("test_add_key()");
+
+ return sk;
+}
+
+static int prepare_lsk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid)
+{
+ int sk = prepare_sk(addr, sndid, rcvid);
+
+ if (listen(sk, 10))
+ test_error("listen()");
+
+ return sk;
+}
+
+static int test_del_key(int sk, uint8_t sndid, uint8_t rcvid, bool async,
+ int current_key, int rnext_key)
+{
+ struct tcp_ao_info_opt ao_info = {};
+ struct tcp_ao_getsockopt key = {};
+ struct tcp_ao_del del = {};
+ sockaddr_af sockaddr;
+ int err;
+
+ tcp_addr_to_sockaddr_in(&del.addr, &this_ip_dest, 0);
+ del.prefix = DEFAULT_TEST_PREFIX;
+ del.sndid = sndid;
+ del.rcvid = rcvid;
+
+ if (current_key >= 0) {
+ del.set_current = 1;
+ del.current_key = (uint8_t)current_key;
+ }
+ if (rnext_key >= 0) {
+ del.set_rnext = 1;
+ del.rnext = (uint8_t)rnext_key;
+ }
+
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, &del, sizeof(del));
+ if (err < 0)
+ return -errno;
+
+ if (async)
+ return 0;
+
+ tcp_addr_to_sockaddr_in(&sockaddr, &this_ip_dest, 0);
+ err = test_get_one_ao(sk, &key, &sockaddr, sizeof(sockaddr),
+ DEFAULT_TEST_PREFIX, sndid, rcvid);
+ if (!err)
+ return -EEXIST;
+ if (err != -E2BIG)
+ test_error("getsockopt()");
+ if (current_key < 0 && rnext_key < 0)
+ return 0;
+ if (test_get_ao_info(sk, &ao_info))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+ if (current_key >= 0 && ao_info.current_key != (uint8_t)current_key)
+ return -ENOTRECOVERABLE;
+ if (rnext_key >= 0 && ao_info.rnext != (uint8_t)rnext_key)
+ return -ENOTRECOVERABLE;
+ return 0;
+}
+
+static void try_delete_key(char *tst_name, int sk, uint8_t sndid, uint8_t rcvid,
+ bool async, int current_key, int rnext_key,
+ fault_t inj)
+{
+ int err;
+
+ err = test_del_key(sk, sndid, rcvid, async, current_key, rnext_key);
+ if ((err == -EBUSY && fault(BUSY)) || (err == -EINVAL && fault(CURRNEXT))) {
+ test_ok("%s: key deletion was prevented", tst_name);
+ return;
+ }
+ if (err && fault(FIXME)) {
+ test_xfail("%s: failed to delete the key %u:%u %d",
+ tst_name, sndid, rcvid, err);
+ return;
+ }
+ if (!err) {
+ if (fault(BUSY) || fault(CURRNEXT)) {
+ test_fail("%s: the key was deleted %u:%u %d", tst_name,
+ sndid, rcvid, err);
+ } else {
+ test_ok("%s: the key was deleted", tst_name);
+ }
+ return;
+ }
+ test_fail("%s: can't delete the key %u:%u %d", tst_name, sndid, rcvid, err);
+}
+
+static int test_set_key(int sk, int current_keyid, int rnext_keyid)
+{
+ struct tcp_ao_info_opt ao_info = {};
+ int err;
+
+ if (current_keyid >= 0) {
+ ao_info.set_current = 1;
+ ao_info.current_key = (uint8_t)current_keyid;
+ }
+ if (rnext_keyid >= 0) {
+ ao_info.set_rnext = 1;
+ ao_info.rnext = (uint8_t)rnext_keyid;
+ }
+
+ err = test_set_ao_info(sk, &ao_info);
+ if (err)
+ return err;
+ if (test_get_ao_info(sk, &ao_info))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+ if (current_keyid >= 0 && ao_info.current_key != (uint8_t)current_keyid)
+ return -ENOTRECOVERABLE;
+ if (rnext_keyid >= 0 && ao_info.rnext != (uint8_t)rnext_keyid)
+ return -ENOTRECOVERABLE;
+ return 0;
+}
+
+static int test_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix,
+ bool set_current, bool set_rnext,
+ uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_add tmp = {};
+ int err;
+
+ err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr,
+ set_current, set_rnext,
+ prefix, 0, sndid, rcvid, 0, keyflags,
+ strlen(key), key);
+ if (err)
+ return err;
+
+
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
+ if (err < 0)
+ return -errno;
+
+ return test_verify_socket_key(sk, &tmp);
+}
+
+static int __try_add_current_rnext_key(int sk, const char *key, uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix,
+ bool set_current, bool set_rnext,
+ uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_info_opt ao_info = {};
+ int err;
+
+ err = test_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
+ set_current, set_rnext, sndid, rcvid);
+ if (err)
+ return err;
+
+ if (test_get_ao_info(sk, &ao_info))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+ if (set_current && ao_info.current_key != sndid)
+ return -ENOTRECOVERABLE;
+ if (set_rnext && ao_info.rnext != rcvid)
+ return -ENOTRECOVERABLE;
+ return 0;
+}
+
+static void try_add_current_rnext_key(char *tst_name, int sk, const char *key,
+ uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix,
+ bool set_current, bool set_rnext,
+ uint8_t sndid, uint8_t rcvid, fault_t inj)
+{
+ int err;
+
+ err = __try_add_current_rnext_key(sk, key, keyflags, in_addr, prefix,
+ set_current, set_rnext, sndid, rcvid);
+ if (!err && !fault(CURRNEXT)) {
+ test_ok("%s", tst_name);
+ return;
+ }
+ if (err == -EINVAL && fault(CURRNEXT)) {
+ test_ok("%s", tst_name);
+ return;
+ }
+ test_fail("%s", tst_name);
+}
+
+static void check_closed_socket(void)
+{
+ int sk;
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ try_delete_key("closed socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
+ try_delete_key("closed socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
+ close(sk);
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ if (test_set_key(sk, 100, 200))
+ test_error("failed to set current/rnext keys");
+ try_delete_key("closed socket, delete current key", sk, 100, 100, 0, -1, -1, FAULT_BUSY);
+ try_delete_key("closed socket, delete rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
+ close(sk);
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ if (test_add_key(sk, "Glory to heros!", this_ip_dest,
+ DEFAULT_TEST_PREFIX, 10, 11))
+ test_error("test_add_key()");
+ if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
+ DEFAULT_TEST_PREFIX, 12, 13))
+ test_error("test_add_key()");
+ try_delete_key("closed socket, delete a key + set current/rnext", sk, 100, 100, 0, 10, 13, 0);
+ try_delete_key("closed socket, force-delete current key", sk, 10, 11, 0, 200, -1, 0);
+ try_delete_key("closed socket, force-delete rnext key", sk, 12, 13, 0, -1, 200, 0);
+ try_delete_key("closed socket, delete current+rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY);
+ close(sk);
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ if (test_set_key(sk, 100, 200))
+ test_error("failed to set current/rnext keys");
+ try_add_current_rnext_key("closed socket, add + change current key",
+ sk, "Laaaa! Lalala-la-la-lalala...", 0,
+ this_ip_dest, DEFAULT_TEST_PREFIX,
+ true, false, 10, 20, 0);
+ try_add_current_rnext_key("closed socket, add + change rnext key",
+ sk, "Laaaa! Lalala-la-la-lalala...", 0,
+ this_ip_dest, DEFAULT_TEST_PREFIX,
+ false, true, 20, 10, 0);
+ close(sk);
+}
+
+static void assert_no_current_rnext(const char *tst_msg, int sk)
+{
+ struct tcp_ao_info_opt ao_info = {};
+
+ if (test_get_ao_info(sk, &ao_info))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+
+ errno = 0;
+ if (ao_info.set_current || ao_info.set_rnext) {
+ test_xfail("%s: the socket has current/rnext keys: %d:%d",
+ tst_msg,
+ (ao_info.set_current) ? ao_info.current_key : -1,
+ (ao_info.set_rnext) ? ao_info.rnext : -1);
+ } else {
+ test_ok("%s: the socket has no current/rnext keys", tst_msg);
+ }
+}
+
+static void assert_no_tcp_repair(void)
+{
+ struct tcp_ao_repair ao_img = {};
+ socklen_t len = sizeof(ao_img);
+ int sk, err;
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ test_enable_repair(sk);
+ if (listen(sk, 10))
+ test_error("listen()");
+ errno = 0;
+ err = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, &len);
+ if (err && errno == EPERM)
+ test_ok("listen socket, getsockopt(TCP_AO_REPAIR) is restricted");
+ else
+ test_fail("listen socket, getsockopt(TCP_AO_REPAIR) works");
+ errno = 0;
+ err = setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, sizeof(ao_img));
+ if (err && errno == EPERM)
+ test_ok("listen socket, setsockopt(TCP_AO_REPAIR) is restricted");
+ else
+ test_fail("listen socket, setsockopt(TCP_AO_REPAIR) works");
+ close(sk);
+}
+
+static void check_listen_socket(void)
+{
+ int sk, err;
+
+ sk = prepare_lsk(&this_ip_dest, 200, 200);
+ try_delete_key("listen socket, delete a key", sk, 200, 200, 0, -1, -1, 0);
+ try_delete_key("listen socket, delete all keys", sk, 100, 100, 0, -1, -1, 0);
+ close(sk);
+
+ sk = prepare_lsk(&this_ip_dest, 200, 200);
+ err = test_set_key(sk, 100, -1);
+ if (err == -EINVAL)
+ test_ok("listen socket, setting current key not allowed");
+ else
+ test_fail("listen socket, set current key");
+ err = test_set_key(sk, -1, 200);
+ if (err == -EINVAL)
+ test_ok("listen socket, setting rnext key not allowed");
+ else
+ test_fail("listen socket, set rnext key");
+ close(sk);
+
+ sk = prepare_sk(&this_ip_dest, 200, 200);
+ if (test_set_key(sk, 100, 200))
+ test_error("failed to set current/rnext keys");
+ if (listen(sk, 10))
+ test_error("listen()");
+ assert_no_current_rnext("listen() after current/rnext keys set", sk);
+ try_delete_key("listen socket, delete current key from before listen()", sk, 100, 100, 0, -1, -1, FAULT_FIXME);
+ try_delete_key("listen socket, delete rnext key from before listen()", sk, 200, 200, 0, -1, -1, FAULT_FIXME);
+ close(sk);
+
+ assert_no_tcp_repair();
+
+ sk = prepare_lsk(&this_ip_dest, 200, 200);
+ if (test_add_key(sk, "Glory to heros!", this_ip_dest,
+ DEFAULT_TEST_PREFIX, 10, 11))
+ test_error("test_add_key()");
+ if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest,
+ DEFAULT_TEST_PREFIX, 12, 13))
+ test_error("test_add_key()");
+ try_delete_key("listen socket, delete a key + set current/rnext", sk,
+ 100, 100, 0, 10, 13, FAULT_CURRNEXT);
+ try_delete_key("listen socket, force-delete current key", sk,
+ 10, 11, 0, 200, -1, FAULT_CURRNEXT);
+ try_delete_key("listen socket, force-delete rnext key", sk,
+ 12, 13, 0, -1, 200, FAULT_CURRNEXT);
+ try_delete_key("listen socket, delete a key", sk,
+ 200, 200, 0, -1, -1, 0);
+ close(sk);
+
+ sk = prepare_lsk(&this_ip_dest, 200, 200);
+ try_add_current_rnext_key("listen socket, add + change current key",
+ sk, "Laaaa! Lalala-la-la-lalala...", 0,
+ this_ip_dest, DEFAULT_TEST_PREFIX,
+ true, false, 10, 20, FAULT_CURRNEXT);
+ try_add_current_rnext_key("listen socket, add + change rnext key",
+ sk, "Laaaa! Lalala-la-la-lalala...", 0,
+ this_ip_dest, DEFAULT_TEST_PREFIX,
+ false, true, 20, 10, FAULT_CURRNEXT);
+ close(sk);
+}
+
+static const char *fips_fpath = "/proc/sys/crypto/fips_enabled";
+static bool is_fips_enabled(void)
+{
+ static int fips_checked = -1;
+ FILE *fenabled;
+ int enabled;
+
+ if (fips_checked >= 0)
+ return !!fips_checked;
+ if (access(fips_fpath, R_OK)) {
+ if (errno != ENOENT)
+ test_error("Can't open %s", fips_fpath);
+ fips_checked = 0;
+ return false;
+ }
+ fenabled = fopen(fips_fpath, "r");
+ if (!fenabled)
+ test_error("Can't open %s", fips_fpath);
+ if (fscanf(fenabled, "%d", &enabled) != 1)
+ test_error("Can't read from %s", fips_fpath);
+ fclose(fenabled);
+ fips_checked = !!enabled;
+ return !!fips_checked;
+}
+
+struct test_key {
+ char password[TCP_AO_MAXKEYLEN];
+ const char *alg;
+ unsigned int len;
+ uint8_t client_keyid;
+ uint8_t server_keyid;
+ uint8_t maclen;
+ uint8_t matches_client : 1,
+ matches_server : 1,
+ matches_vrf : 1,
+ is_current : 1,
+ is_rnext : 1,
+ used_on_server_tx : 1,
+ used_on_client_tx : 1,
+ skip_counters_checks : 1;
+};
+
+struct key_collection {
+ unsigned int nr_keys;
+ struct test_key *keys;
+};
+
+static struct key_collection collection;
+
+#define TEST_MAX_MACLEN 16
+const char *test_algos[] = {
+ "cmac(aes128)",
+ "hmac(sha1)", "hmac(sha512)", "hmac(sha384)", "hmac(sha256)",
+ "hmac(sha224)", "hmac(sha3-512)",
+ /* only if !CONFIG_FIPS */
+#define TEST_NON_FIPS_ALGOS 2
+ "hmac(rmd160)", "hmac(md5)"
+};
+const unsigned int test_maclens[] = { 1, 4, 12, 16 };
+#define MACLEN_SHIFT 2
+#define ALGOS_SHIFT 4
+
+static unsigned int make_mask(unsigned int shift, unsigned int prev_shift)
+{
+ unsigned int ret = BIT(shift) - 1;
+
+ return ret << prev_shift;
+}
+
+static void init_key_in_collection(unsigned int index, bool randomized)
+{
+ struct test_key *key = &collection.keys[index];
+ unsigned int algos_nr, algos_index;
+
+ /* Same for randomized and non-randomized test flows */
+ key->client_keyid = index;
+ key->server_keyid = 127 + index;
+ key->matches_client = 1;
+ key->matches_server = 1;
+ key->matches_vrf = 1;
+ /* not really even random, but good enough for a test */
+ key->len = rand() % (TCP_AO_MAXKEYLEN - TEST_TCP_AO_MINKEYLEN);
+ key->len += TEST_TCP_AO_MINKEYLEN;
+ randomize_buffer(key->password, key->len);
+
+ if (randomized) {
+ key->maclen = (rand() % TEST_MAX_MACLEN) + 1;
+ algos_index = rand();
+ } else {
+ unsigned int shift = MACLEN_SHIFT;
+
+ key->maclen = test_maclens[index & make_mask(shift, 0)];
+ algos_index = index & make_mask(ALGOS_SHIFT, shift);
+ }
+ algos_nr = ARRAY_SIZE(test_algos);
+ if (is_fips_enabled())
+ algos_nr -= TEST_NON_FIPS_ALGOS;
+ key->alg = test_algos[algos_index % algos_nr];
+}
+
+static int init_default_key_collection(unsigned int nr_keys, bool randomized)
+{
+ size_t key_sz = sizeof(collection.keys[0]);
+
+ if (!nr_keys) {
+ free(collection.keys);
+ collection.keys = NULL;
+ return 0;
+ }
+
+ /*
+ * All keys have uniq sndid/rcvid and sndid != rcvid in order to
+ * check for any bugs/issues for different keyids, visible to both
+ * peers. Keyid == 254 is unused.
+ */
+ if (nr_keys > 127)
+ test_error("Test requires too many keys, correct the source");
+
+ collection.keys = reallocarray(collection.keys, nr_keys, key_sz);
+ if (!collection.keys)
+ return -ENOMEM;
+
+ memset(collection.keys, 0, nr_keys * key_sz);
+ collection.nr_keys = nr_keys;
+ while (nr_keys--)
+ init_key_in_collection(nr_keys, randomized);
+
+ return 0;
+}
+
+static void test_key_error(const char *msg, struct test_key *key)
+{
+ test_error("%s: key: { %s, %u:%u, %u, %u:%u:%u:%u:%u (%u)}",
+ msg, key->alg, key->client_keyid, key->server_keyid,
+ key->maclen, key->matches_client, key->matches_server,
+ key->matches_vrf, key->is_current, key->is_rnext, key->len);
+}
+
+static int test_add_key_cr(int sk, const char *pwd, unsigned int pwd_len,
+ union tcp_addr addr, uint8_t vrf,
+ uint8_t sndid, uint8_t rcvid,
+ uint8_t maclen, const char *alg,
+ bool set_current, bool set_rnext)
+{
+ struct tcp_ao_add tmp = {};
+ uint8_t keyflags = 0;
+ int err;
+
+ if (!alg)
+ alg = DEFAULT_TEST_ALGO;
+
+ if (vrf)
+ keyflags |= TCP_AO_KEYF_IFINDEX;
+ err = test_prepare_key(&tmp, alg, addr, set_current, set_rnext,
+ DEFAULT_TEST_PREFIX, vrf, sndid, rcvid, maclen,
+ keyflags, pwd_len, pwd);
+ if (err)
+ return err;
+
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
+ if (err < 0)
+ return -errno;
+
+ return test_verify_socket_key(sk, &tmp);
+}
+
+static void verify_current_rnext(const char *tst, int sk,
+ int current_keyid, int rnext_keyid)
+{
+ struct tcp_ao_info_opt ao_info = {};
+
+ if (test_get_ao_info(sk, &ao_info))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+
+ errno = 0;
+ if (current_keyid >= 0) {
+ if (!ao_info.set_current)
+ test_fail("%s: the socket doesn't have current key", tst);
+ else if (ao_info.current_key != current_keyid)
+ test_fail("%s: current key is not the expected one %d != %u",
+ tst, current_keyid, ao_info.current_key);
+ else
+ test_ok("%s: current key %u as expected",
+ tst, ao_info.current_key);
+ }
+ if (rnext_keyid >= 0) {
+ if (!ao_info.set_rnext)
+ test_fail("%s: the socket doesn't have rnext key", tst);
+ else if (ao_info.rnext != rnext_keyid)
+ test_fail("%s: rnext key is not the expected one %d != %u",
+ tst, rnext_keyid, ao_info.rnext);
+ else
+ test_ok("%s: rnext key %u as expected", tst, ao_info.rnext);
+ }
+}
+
+
+static int key_collection_socket(bool server, unsigned int port)
+{
+ unsigned int i;
+ int sk;
+
+ if (server)
+ sk = test_listen_socket(this_ip_addr, port, 1);
+ else
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ for (i = 0; i < collection.nr_keys; i++) {
+ struct test_key *key = &collection.keys[i];
+ union tcp_addr *addr = &wrong_addr;
+ uint8_t sndid, rcvid, vrf;
+ bool set_current = false, set_rnext = false;
+
+ if (key->matches_vrf)
+ vrf = 0;
+ else
+ vrf = test_vrf_ifindex;
+ if (server) {
+ if (key->matches_client)
+ addr = &this_ip_dest;
+ sndid = key->server_keyid;
+ rcvid = key->client_keyid;
+ } else {
+ if (key->matches_server)
+ addr = &this_ip_dest;
+ sndid = key->client_keyid;
+ rcvid = key->server_keyid;
+ key->used_on_client_tx = set_current = key->is_current;
+ key->used_on_server_tx = set_rnext = key->is_rnext;
+ }
+
+ if (test_add_key_cr(sk, key->password, key->len,
+ *addr, vrf, sndid, rcvid, key->maclen,
+ key->alg, set_current, set_rnext))
+ test_key_error("setsockopt(TCP_AO_ADD_KEY)", key);
+#ifdef DEBUG
+ test_print("%s [%u/%u] key: { %s, %u:%u, %u, %u:%u:%u:%u (%u)}",
+ server ? "server" : "client", i, collection.nr_keys,
+ key->alg, rcvid, sndid, key->maclen,
+ key->matches_client, key->matches_server,
+ key->is_current, key->is_rnext, key->len);
+#endif
+ }
+ return sk;
+}
+
+static void verify_counters(const char *tst_name, bool is_listen_sk, bool server,
+ struct tcp_ao_counters *a, struct tcp_ao_counters *b)
+{
+ unsigned int i;
+
+ __test_tcp_ao_counters_cmp(tst_name, a, b, TEST_CNT_GOOD);
+
+ for (i = 0; i < collection.nr_keys; i++) {
+ struct test_key *key = &collection.keys[i];
+ uint8_t sndid, rcvid;
+ bool rx_cnt_expected;
+
+ if (key->skip_counters_checks)
+ continue;
+ if (server) {
+ sndid = key->server_keyid;
+ rcvid = key->client_keyid;
+ rx_cnt_expected = key->used_on_client_tx;
+ } else {
+ sndid = key->client_keyid;
+ rcvid = key->server_keyid;
+ rx_cnt_expected = key->used_on_server_tx;
+ }
+
+ test_tcp_ao_key_counters_cmp(tst_name, a, b,
+ rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0,
+ sndid, rcvid);
+ }
+ test_tcp_ao_counters_free(a);
+ test_tcp_ao_counters_free(b);
+ test_ok("%s: passed counters checks", tst_name);
+}
+
+static struct tcp_ao_getsockopt *lookup_key(struct tcp_ao_getsockopt *buf,
+ size_t len, int sndid, int rcvid)
+{
+ size_t i;
+
+ for (i = 0; i < len; i++) {
+ if (sndid >= 0 && buf[i].sndid != sndid)
+ continue;
+ if (rcvid >= 0 && buf[i].rcvid != rcvid)
+ continue;
+ return &buf[i];
+ }
+ return NULL;
+}
+
+static void verify_keys(const char *tst_name, int sk,
+ bool is_listen_sk, bool server)
+{
+ socklen_t len = sizeof(struct tcp_ao_getsockopt);
+ struct tcp_ao_getsockopt *keys;
+ bool passed_test = true;
+ unsigned int i;
+
+ keys = calloc(collection.nr_keys, len);
+ if (!keys)
+ test_error("calloc()");
+
+ keys->nkeys = collection.nr_keys;
+ keys->get_all = 1;
+
+ if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, keys, &len)) {
+ free(keys);
+ test_error("getsockopt(TCP_AO_GET_KEYS)");
+ }
+
+ for (i = 0; i < collection.nr_keys; i++) {
+ struct test_key *key = &collection.keys[i];
+ struct tcp_ao_getsockopt *dump_key;
+ bool is_kdf_aes_128_cmac = false;
+ bool is_cmac_aes = false;
+ uint8_t sndid, rcvid;
+ bool matches = false;
+
+ if (server) {
+ if (key->matches_client)
+ matches = true;
+ sndid = key->server_keyid;
+ rcvid = key->client_keyid;
+ } else {
+ if (key->matches_server)
+ matches = true;
+ sndid = key->client_keyid;
+ rcvid = key->server_keyid;
+ }
+ if (!key->matches_vrf)
+ matches = false;
+ /* no keys get removed on the original listener socket */
+ if (is_listen_sk)
+ matches = true;
+
+ dump_key = lookup_key(keys, keys->nkeys, sndid, rcvid);
+ if (matches != !!dump_key) {
+ test_fail("%s: key %u:%u %s%s on the socket",
+ tst_name, sndid, rcvid,
+ key->matches_vrf ? "" : "[vrf] ",
+ matches ? "disappeared" : "yet present");
+ passed_test = false;
+ goto out;
+ }
+ if (!dump_key)
+ continue;
+
+ if (!strcmp("cmac(aes128)", key->alg)) {
+ is_kdf_aes_128_cmac = (key->len != 16);
+ is_cmac_aes = true;
+ }
+
+ if (is_cmac_aes) {
+ if (strcmp(dump_key->alg_name, "cmac(aes)")) {
+ test_fail("%s: key %u:%u cmac(aes) has unexpected alg %s",
+ tst_name, sndid, rcvid,
+ dump_key->alg_name);
+ passed_test = false;
+ continue;
+ }
+ } else if (strcmp(dump_key->alg_name, key->alg)) {
+ test_fail("%s: key %u:%u has unexpected alg %s != %s",
+ tst_name, sndid, rcvid,
+ dump_key->alg_name, key->alg);
+ passed_test = false;
+ continue;
+ }
+ if (is_kdf_aes_128_cmac) {
+ if (dump_key->keylen != 16) {
+ test_fail("%s: key %u:%u cmac(aes128) has unexpected len %u",
+ tst_name, sndid, rcvid,
+ dump_key->keylen);
+ continue;
+ }
+ } else if (dump_key->keylen != key->len) {
+ test_fail("%s: key %u:%u changed password len %u != %u",
+ tst_name, sndid, rcvid,
+ dump_key->keylen, key->len);
+ passed_test = false;
+ continue;
+ }
+ if (!is_kdf_aes_128_cmac &&
+ memcmp(dump_key->key, key->password, key->len)) {
+ test_fail("%s: key %u:%u has different password",
+ tst_name, sndid, rcvid);
+ passed_test = false;
+ continue;
+ }
+ if (dump_key->maclen != key->maclen) {
+ test_fail("%s: key %u:%u changed maclen %u != %u",
+ tst_name, sndid, rcvid,
+ dump_key->maclen, key->maclen);
+ passed_test = false;
+ continue;
+ }
+ }
+
+ if (passed_test)
+ test_ok("%s: The socket keys are consistent with the expectations",
+ tst_name);
+out:
+ free(keys);
+}
+
+static int start_server(const char *tst_name, unsigned int port, size_t quota,
+ struct tcp_ao_counters *begin,
+ unsigned int current_index, unsigned int rnext_index)
+{
+ struct tcp_ao_counters lsk_c1, lsk_c2;
+ ssize_t bytes;
+ int sk, lsk;
+
+ synchronize_threads(); /* 1: key collection initialized */
+ lsk = key_collection_socket(true, port);
+ if (test_get_tcp_ao_counters(lsk, &lsk_c1))
+ test_error("test_get_tcp_ao_counters()");
+ synchronize_threads(); /* 2: MKTs added => connect() */
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+ if (test_get_tcp_ao_counters(sk, begin))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* 3: accepted => send data */
+ if (test_get_tcp_ao_counters(lsk, &lsk_c2))
+ test_error("test_get_tcp_ao_counters()");
+ verify_keys(tst_name, lsk, true, true);
+ close(lsk);
+
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota)
+ test_fail("%s: server served: %zd", tst_name, bytes);
+ else
+ test_ok("%s: server alive", tst_name);
+
+ verify_counters(tst_name, true, true, &lsk_c1, &lsk_c2);
+
+ return sk;
+}
+
+static void end_server(const char *tst_name, int sk,
+ struct tcp_ao_counters *begin)
+{
+ struct tcp_ao_counters end;
+
+ if (test_get_tcp_ao_counters(sk, &end))
+ test_error("test_get_tcp_ao_counters()");
+ verify_keys(tst_name, sk, false, true);
+
+ synchronize_threads(); /* 4: verified => closed */
+ close(sk);
+
+ verify_counters(tst_name, false, true, begin, &end);
+ synchronize_threads(); /* 5: counters */
+}
+
+static void try_server_run(const char *tst_name, unsigned int port, size_t quota,
+ unsigned int current_index, unsigned int rnext_index)
+{
+ struct tcp_ao_counters tmp;
+ int sk;
+
+ sk = start_server(tst_name, port, quota, &tmp,
+ current_index, rnext_index);
+ end_server(tst_name, sk, &tmp);
+}
+
+static void server_rotations(const char *tst_name, unsigned int port,
+ size_t quota, unsigned int rotations,
+ unsigned int current_index, unsigned int rnext_index)
+{
+ struct tcp_ao_counters tmp;
+ unsigned int i;
+ int sk;
+
+ sk = start_server(tst_name, port, quota, &tmp,
+ current_index, rnext_index);
+
+ for (i = current_index + 1; rotations > 0; i++, rotations--) {
+ ssize_t bytes;
+
+ if (i >= collection.nr_keys)
+ i = 0;
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota) {
+ test_fail("%s: server served: %zd", tst_name, bytes);
+ return;
+ }
+ verify_current_rnext(tst_name, sk,
+ collection.keys[i].server_keyid, -1);
+ synchronize_threads(); /* verify current/rnext */
+ }
+ end_server(tst_name, sk, &tmp);
+}
+
+static int run_client(const char *tst_name, unsigned int port,
+ unsigned int nr_keys, int current_index, int rnext_index,
+ struct tcp_ao_counters *before,
+ const size_t msg_sz, const size_t msg_nr)
+{
+ int sk;
+
+ synchronize_threads(); /* 1: key collection initialized */
+ sk = key_collection_socket(false, port);
+
+ if (current_index >= 0 || rnext_index >= 0) {
+ int sndid = -1, rcvid = -1;
+
+ if (current_index >= 0)
+ sndid = collection.keys[current_index].client_keyid;
+ if (rnext_index >= 0)
+ rcvid = collection.keys[rnext_index].server_keyid;
+ if (test_set_key(sk, sndid, rcvid))
+ test_error("failed to set current/rnext keys");
+ }
+ if (before && test_get_tcp_ao_counters(sk, before))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* 2: MKTs added => connect() */
+ if (test_connect_socket(sk, this_ip_dest, port++) <= 0)
+ test_error("failed to connect()");
+ if (current_index < 0)
+ current_index = nr_keys - 1;
+ if (rnext_index < 0)
+ rnext_index = nr_keys - 1;
+ collection.keys[current_index].used_on_client_tx = 1;
+ collection.keys[rnext_index].used_on_server_tx = 1;
+
+ synchronize_threads(); /* 3: accepted => send data */
+ if (test_client_verify(sk, msg_sz, msg_nr, TEST_TIMEOUT_SEC)) {
+ test_fail("verify failed");
+ close(sk);
+ if (before)
+ test_tcp_ao_counters_free(before);
+ return -1;
+ }
+
+ return sk;
+}
+
+static int start_client(const char *tst_name, unsigned int port,
+ unsigned int nr_keys, int current_index, int rnext_index,
+ struct tcp_ao_counters *before,
+ const size_t msg_sz, const size_t msg_nr)
+{
+ if (init_default_key_collection(nr_keys, true))
+ test_error("Failed to init the key collection");
+
+ return run_client(tst_name, port, nr_keys, current_index,
+ rnext_index, before, msg_sz, msg_nr);
+}
+
+static void end_client(const char *tst_name, int sk, unsigned int nr_keys,
+ int current_index, int rnext_index,
+ struct tcp_ao_counters *start)
+{
+ struct tcp_ao_counters end;
+
+ /* Some application may become dependent on this kernel choice */
+ if (current_index < 0)
+ current_index = nr_keys - 1;
+ if (rnext_index < 0)
+ rnext_index = nr_keys - 1;
+ verify_current_rnext(tst_name, sk,
+ collection.keys[current_index].client_keyid,
+ collection.keys[rnext_index].server_keyid);
+ if (start && test_get_tcp_ao_counters(sk, &end))
+ test_error("test_get_tcp_ao_counters()");
+ verify_keys(tst_name, sk, false, false);
+ synchronize_threads(); /* 4: verify => closed */
+ close(sk);
+ if (start)
+ verify_counters(tst_name, false, false, start, &end);
+ synchronize_threads(); /* 5: counters */
+}
+
+static void try_unmatched_keys(int sk, int *rnext_index)
+{
+ struct test_key *key;
+ unsigned int i = 0;
+ int err;
+
+ do {
+ key = &collection.keys[i];
+ if (!key->matches_server)
+ break;
+ } while (++i < collection.nr_keys);
+ if (key->matches_server)
+ test_error("all keys on client match the server");
+
+ err = test_add_key_cr(sk, key->password, key->len, wrong_addr,
+ 0, key->client_keyid, key->server_keyid,
+ key->maclen, key->alg, 0, 0);
+ if (!err) {
+ test_fail("Added a key with non-matching ip-address for established sk");
+ return;
+ }
+ if (err == -EINVAL)
+ test_ok("Can't add a key with non-matching ip-address for established sk");
+ else
+ test_error("Failed to add a key");
+
+ err = test_add_key_cr(sk, key->password, key->len, this_ip_dest,
+ test_vrf_ifindex,
+ key->client_keyid, key->server_keyid,
+ key->maclen, key->alg, 0, 0);
+ if (!err) {
+ test_fail("Added a key with non-matching VRF for established sk");
+ return;
+ }
+ if (err == -EINVAL)
+ test_ok("Can't add a key with non-matching VRF for established sk");
+ else
+ test_error("Failed to add a key");
+
+ for (i = 0; i < collection.nr_keys; i++) {
+ key = &collection.keys[i];
+ if (!key->matches_client)
+ break;
+ }
+ if (key->matches_client)
+ test_error("all keys on server match the client");
+ if (test_set_key(sk, -1, key->server_keyid))
+ test_error("Can't change the current key");
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
+ test_fail("verify failed");
+ *rnext_index = i;
+}
+
+static int client_non_matching(const char *tst_name, unsigned int port,
+ unsigned int nr_keys,
+ int current_index, int rnext_index,
+ const size_t msg_sz, const size_t msg_nr)
+{
+ unsigned int i;
+
+ if (init_default_key_collection(nr_keys, true))
+ test_error("Failed to init the key collection");
+
+ for (i = 0; i < nr_keys; i++) {
+ /* key (0, 0) matches */
+ collection.keys[i].matches_client = !!((i + 3) % 4);
+ collection.keys[i].matches_server = !!((i + 2) % 4);
+ if (kernel_config_has(KCONFIG_NET_VRF))
+ collection.keys[i].matches_vrf = !!((i + 1) % 4);
+ }
+
+ return run_client(tst_name, port, nr_keys, current_index,
+ rnext_index, NULL, msg_sz, msg_nr);
+}
+
+static void check_current_back(const char *tst_name, unsigned int port,
+ unsigned int nr_keys,
+ unsigned int current_index, unsigned int rnext_index,
+ unsigned int rotate_to_index)
+{
+ struct tcp_ao_counters tmp;
+ int sk;
+
+ sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
+ &tmp, msg_len, nr_packets);
+ if (sk < 0)
+ return;
+ if (test_set_key(sk, collection.keys[rotate_to_index].client_keyid, -1))
+ test_error("Can't change the current key");
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
+ test_fail("verify failed");
+ /* There is a race here: between setting the current_key with
+ * setsockopt(TCP_AO_INFO) and starting to send some data - there
+ * might have been a segment received with the desired
+ * RNext_key set. In turn that would mean that the first outgoing
+ * segment will have the desired current_key (flipped back).
+ * Which is what the user/test wants. As it's racy, skip checking
+ * the counters, yet check what are the resulting current/rnext
+ * keys on both sides.
+ */
+ collection.keys[rotate_to_index].skip_counters_checks = 1;
+
+ end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
+}
+
+static void roll_over_keys(const char *tst_name, unsigned int port,
+ unsigned int nr_keys, unsigned int rotations,
+ unsigned int current_index, unsigned int rnext_index)
+{
+ struct tcp_ao_counters tmp;
+ unsigned int i;
+ int sk;
+
+ sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
+ &tmp, msg_len, nr_packets);
+ if (sk < 0)
+ return;
+ for (i = rnext_index + 1; rotations > 0; i++, rotations--) {
+ if (i >= collection.nr_keys)
+ i = 0;
+ if (test_set_key(sk, -1, collection.keys[i].server_keyid))
+ test_error("Can't change the Rnext key");
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
+ test_fail("verify failed");
+ close(sk);
+ test_tcp_ao_counters_free(&tmp);
+ return;
+ }
+ verify_current_rnext(tst_name, sk, -1,
+ collection.keys[i].server_keyid);
+ collection.keys[i].used_on_server_tx = 1;
+ synchronize_threads(); /* verify current/rnext */
+ }
+ end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
+}
+
+static void try_client_run(const char *tst_name, unsigned int port,
+ unsigned int nr_keys, int current_index, int rnext_index)
+{
+ struct tcp_ao_counters tmp;
+ int sk;
+
+ sk = start_client(tst_name, port, nr_keys, current_index, rnext_index,
+ &tmp, msg_len, nr_packets);
+ if (sk < 0)
+ return;
+ end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp);
+}
+
+static void try_client_match(const char *tst_name, unsigned int port,
+ unsigned int nr_keys,
+ int current_index, int rnext_index)
+{
+ int sk;
+
+ sk = client_non_matching(tst_name, port, nr_keys, current_index,
+ rnext_index, msg_len, nr_packets);
+ if (sk < 0)
+ return;
+ try_unmatched_keys(sk, &rnext_index);
+ end_client(tst_name, sk, nr_keys, current_index, rnext_index, NULL);
+}
+
+static void *server_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+
+ setup_vrfs();
+ try_server_run("server: Check current/rnext keys unset before connect()",
+ port++, quota, 19, 19);
+ try_server_run("server: Check current/rnext keys set before connect()",
+ port++, quota, 10, 10);
+ try_server_run("server: Check current != rnext keys set before connect()",
+ port++, quota, 5, 10);
+ try_server_run("server: Check current flapping back on peer's RnextKey request",
+ port++, quota * 2, 5, 10);
+ server_rotations("server: Rotate over all different keys", port++,
+ quota, 20, 0, 0);
+ try_server_run("server: Check accept() => established key matching",
+ port++, quota * 2, 0, 0);
+
+ synchronize_threads(); /* don't race to exit: client exits */
+ return NULL;
+}
+
+static void check_established_socket(void)
+{
+ unsigned int port = test_server_port;
+
+ setup_vrfs();
+ try_client_run("client: Check current/rnext keys unset before connect()",
+ port++, 20, -1, -1);
+ try_client_run("client: Check current/rnext keys set before connect()",
+ port++, 20, 10, 10);
+ try_client_run("client: Check current != rnext keys set before connect()",
+ port++, 20, 10, 5);
+ check_current_back("client: Check current flapping back on peer's RnextKey request",
+ port++, 20, 10, 5, 2);
+ roll_over_keys("client: Rotate over all different keys", port++,
+ 20, 20, 0, 0);
+ try_client_match("client: Check connect() => established key matching",
+ port++, 20, 0, 0);
+}
+
+static void *client_fn(void *arg)
+{
+ if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1)
+ test_error("Can't convert ip address %s", TEST_WRONG_IP);
+ check_closed_socket();
+ check_listen_socket();
+ check_established_socket();
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(120, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/aolib.h b/tools/testing/selftests/net/tcp_ao/lib/aolib.h
new file mode 100644
index 000000000000..fbc7f6111815
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/aolib.h
@@ -0,0 +1,605 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * TCP-AO selftest library. Provides helpers to unshare network
+ * namespaces, create veth, assign ip addresses, set routes,
+ * manipulate socket options, read network counter and etc.
+ * Author: Dmitry Safonov <dima@arista.com>
+ */
+#ifndef _AOLIB_H_
+#define _AOLIB_H_
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <linux/snmp.h>
+#include <linux/tcp.h>
+#include <netinet/in.h>
+#include <stdarg.h>
+#include <stdbool.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "../../../../../include/linux/stringify.h"
+#include "../../../../../include/linux/bits.h"
+
+#ifndef SOL_TCP
+/* can't include <netinet/tcp.h> as including <linux/tcp.h> */
+# define SOL_TCP 6 /* TCP level */
+#endif
+
+/* Working around ksft, see the comment in lib/setup.c */
+extern void __test_msg(const char *buf);
+extern void __test_ok(const char *buf);
+extern void __test_fail(const char *buf);
+extern void __test_xfail(const char *buf);
+extern void __test_error(const char *buf);
+extern void __test_skip(const char *buf);
+
+__attribute__((__format__(__printf__, 2, 3)))
+static inline void __test_print(void (*fn)(const char *), const char *fmt, ...)
+{
+#define TEST_MSG_BUFFER_SIZE 4096
+ char buf[TEST_MSG_BUFFER_SIZE];
+ va_list arg;
+
+ va_start(arg, fmt);
+ vsnprintf(buf, sizeof(buf), fmt, arg);
+ va_end(arg);
+ fn(buf);
+}
+
+#define test_print(fmt, ...) \
+ __test_print(__test_msg, "%ld[%s:%u] " fmt "\n", \
+ syscall(SYS_gettid), \
+ __FILE__, __LINE__, ##__VA_ARGS__)
+
+#define test_ok(fmt, ...) \
+ __test_print(__test_ok, fmt "\n", ##__VA_ARGS__)
+#define test_skip(fmt, ...) \
+ __test_print(__test_skip, fmt "\n", ##__VA_ARGS__)
+#define test_xfail(fmt, ...) \
+ __test_print(__test_xfail, fmt "\n", ##__VA_ARGS__)
+
+#define test_fail(fmt, ...) \
+do { \
+ if (errno) \
+ __test_print(__test_fail, fmt ": %m\n", ##__VA_ARGS__); \
+ else \
+ __test_print(__test_fail, fmt "\n", ##__VA_ARGS__); \
+ test_failed(); \
+} while (0)
+
+#define KSFT_FAIL 1
+#define test_error(fmt, ...) \
+do { \
+ if (errno) \
+ __test_print(__test_error, "%ld[%s:%u] " fmt ": %m\n", \
+ syscall(SYS_gettid), __FILE__, __LINE__, \
+ ##__VA_ARGS__); \
+ else \
+ __test_print(__test_error, "%ld[%s:%u] " fmt "\n", \
+ syscall(SYS_gettid), __FILE__, __LINE__, \
+ ##__VA_ARGS__); \
+ exit(KSFT_FAIL); \
+} while (0)
+
+enum test_fault {
+ FAULT_TIMEOUT = 1,
+ FAULT_KEYREJECT,
+ FAULT_PREINSTALL_AO,
+ FAULT_PREINSTALL_MD5,
+ FAULT_POSTINSTALL,
+ FAULT_BUSY,
+ FAULT_CURRNEXT,
+ FAULT_FIXME,
+};
+typedef enum test_fault fault_t;
+
+enum test_needs_kconfig {
+ KCONFIG_NET_NS = 0, /* required */
+ KCONFIG_VETH, /* required */
+ KCONFIG_TCP_AO, /* required */
+ KCONFIG_TCP_MD5, /* optional, for TCP-MD5 features */
+ KCONFIG_NET_VRF, /* optional, for L3/VRF testing */
+ __KCONFIG_LAST__
+};
+extern bool kernel_config_has(enum test_needs_kconfig k);
+extern const char *tests_skip_reason[__KCONFIG_LAST__];
+static inline bool should_skip_test(const char *tst_name,
+ enum test_needs_kconfig k)
+{
+ if (kernel_config_has(k))
+ return false;
+ test_skip("%s: %s", tst_name, tests_skip_reason[k]);
+ return true;
+}
+
+union tcp_addr {
+ struct in_addr a4;
+ struct in6_addr a6;
+};
+
+typedef void *(*thread_fn)(void *);
+extern void test_failed(void);
+extern void __test_init(unsigned int ntests, int family, unsigned int prefix,
+ union tcp_addr addr1, union tcp_addr addr2,
+ thread_fn peer1, thread_fn peer2);
+
+static inline void test_init2(unsigned int ntests,
+ thread_fn peer1, thread_fn peer2,
+ int family, unsigned int prefix,
+ const char *addr1, const char *addr2)
+{
+ union tcp_addr taddr1, taddr2;
+
+ if (inet_pton(family, addr1, &taddr1) != 1)
+ test_error("Can't convert ip address %s", addr1);
+ if (inet_pton(family, addr2, &taddr2) != 1)
+ test_error("Can't convert ip address %s", addr2);
+
+ __test_init(ntests, family, prefix, taddr1, taddr2, peer1, peer2);
+}
+extern void test_add_destructor(void (*d)(void));
+
+/* To adjust optmem socket limit, approximately estimate a number,
+ * that is bigger than sizeof(struct tcp_ao_key).
+ */
+#define KERNEL_TCP_AO_KEY_SZ_ROUND_UP 300
+
+extern void test_set_optmem(size_t value);
+extern size_t test_get_optmem(void);
+
+extern const struct sockaddr_in6 addr_any6;
+extern const struct sockaddr_in addr_any4;
+
+#ifdef IPV6_TEST
+# define __TEST_CLIENT_IP(n) ("2001:db8:" __stringify(n) "::1")
+# define TEST_CLIENT_IP __TEST_CLIENT_IP(1)
+# define TEST_WRONG_IP "2001:db8:253::1"
+# define TEST_SERVER_IP "2001:db8:254::1"
+# define TEST_NETWORK "2001::"
+# define TEST_PREFIX 128
+# define TEST_FAMILY AF_INET6
+# define SOCKADDR_ANY addr_any6
+# define sockaddr_af struct sockaddr_in6
+#else
+# define __TEST_CLIENT_IP(n) ("10.0." __stringify(n) ".1")
+# define TEST_CLIENT_IP __TEST_CLIENT_IP(1)
+# define TEST_WRONG_IP "10.0.253.1"
+# define TEST_SERVER_IP "10.0.254.1"
+# define TEST_NETWORK "10.0.0.0"
+# define TEST_PREFIX 32
+# define TEST_FAMILY AF_INET
+# define SOCKADDR_ANY addr_any4
+# define sockaddr_af struct sockaddr_in
+#endif
+
+static inline union tcp_addr gen_tcp_addr(union tcp_addr net, size_t n)
+{
+ union tcp_addr ret = net;
+
+#ifdef IPV6_TEST
+ ret.a6.s6_addr32[3] = htonl(n & (BIT(32) - 1));
+ ret.a6.s6_addr32[2] = htonl((n >> 32) & (BIT(32) - 1));
+#else
+ ret.a4.s_addr = htonl(ntohl(net.a4.s_addr) + n);
+#endif
+
+ return ret;
+}
+
+static inline void tcp_addr_to_sockaddr_in(void *dest,
+ const union tcp_addr *src,
+ unsigned int port)
+{
+ sockaddr_af *out = dest;
+
+ memset(out, 0, sizeof(*out));
+#ifdef IPV6_TEST
+ out->sin6_family = AF_INET6;
+ out->sin6_port = port;
+ out->sin6_addr = src->a6;
+#else
+ out->sin_family = AF_INET;
+ out->sin_port = port;
+ out->sin_addr = src->a4;
+#endif
+}
+
+static inline void test_init(unsigned int ntests,
+ thread_fn peer1, thread_fn peer2)
+{
+ test_init2(ntests, peer1, peer2, TEST_FAMILY, TEST_PREFIX,
+ TEST_SERVER_IP, TEST_CLIENT_IP);
+}
+extern void synchronize_threads(void);
+extern void switch_ns(int fd);
+
+extern __thread union tcp_addr this_ip_addr;
+extern __thread union tcp_addr this_ip_dest;
+extern int test_family;
+
+extern void randomize_buffer(void *buf, size_t buflen);
+extern int open_netns(void);
+extern int unshare_open_netns(void);
+extern const char veth_name[];
+extern int add_veth(const char *name, int nsfda, int nsfdb);
+extern int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd);
+extern int ip_addr_add(const char *intf, int family,
+ union tcp_addr addr, uint8_t prefix);
+extern int ip_route_add(const char *intf, int family,
+ union tcp_addr src, union tcp_addr dst);
+extern int ip_route_add_vrf(const char *intf, int family,
+ union tcp_addr src, union tcp_addr dst,
+ uint8_t vrf);
+extern int link_set_up(const char *intf);
+
+extern const unsigned int test_server_port;
+extern int test_wait_fd(int sk, time_t sec, bool write);
+extern int __test_connect_socket(int sk, const char *device,
+ void *addr, size_t addr_sz, time_t timeout);
+extern int __test_listen_socket(int backlog, void *addr, size_t addr_sz);
+
+static inline int test_listen_socket(const union tcp_addr taddr,
+ unsigned int port, int backlog)
+{
+ sockaddr_af addr;
+
+ tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
+ return __test_listen_socket(backlog, (void *)&addr, sizeof(addr));
+}
+
+/*
+ * In order for selftests to work under CONFIG_CRYPTO_FIPS=y,
+ * the password should be loger than 14 bytes, see hmac_setkey()
+ */
+#define TEST_TCP_AO_MINKEYLEN 14
+#define DEFAULT_TEST_PASSWORD "In this hour, I do not believe that any darkness will endure."
+
+#ifndef DEFAULT_TEST_ALGO
+#define DEFAULT_TEST_ALGO "cmac(aes128)"
+#endif
+
+#ifdef IPV6_TEST
+#define DEFAULT_TEST_PREFIX 128
+#else
+#define DEFAULT_TEST_PREFIX 32
+#endif
+
+/*
+ * Timeout on syscalls where failure is not expected.
+ * You may want to rise it if the test machine is very busy.
+ */
+#ifndef TEST_TIMEOUT_SEC
+#define TEST_TIMEOUT_SEC 5
+#endif
+
+/*
+ * Timeout on connect() where a failure is expected.
+ * If set to 0 - kernel will try to retransmit SYN number of times, set in
+ * /proc/sys/net/ipv4/tcp_syn_retries
+ * By default set to 1 to make tests pass faster on non-busy machine.
+ */
+#ifndef TEST_RETRANSMIT_SEC
+#define TEST_RETRANSMIT_SEC 1
+#endif
+
+static inline int _test_connect_socket(int sk, const union tcp_addr taddr,
+ unsigned int port, time_t timeout)
+{
+ sockaddr_af addr;
+
+ tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port));
+ return __test_connect_socket(sk, veth_name,
+ (void *)&addr, sizeof(addr), timeout);
+}
+
+static inline int test_connect_socket(int sk, const union tcp_addr taddr,
+ unsigned int port)
+{
+ return _test_connect_socket(sk, taddr, port, TEST_TIMEOUT_SEC);
+}
+
+extern int __test_set_md5(int sk, void *addr, size_t addr_sz,
+ uint8_t prefix, int vrf, const char *password);
+static inline int test_set_md5(int sk, const union tcp_addr in_addr,
+ uint8_t prefix, int vrf, const char *password)
+{
+ sockaddr_af addr;
+
+ if (prefix > DEFAULT_TEST_PREFIX)
+ prefix = DEFAULT_TEST_PREFIX;
+
+ tcp_addr_to_sockaddr_in(&addr, &in_addr, 0);
+ return __test_set_md5(sk, (void *)&addr, sizeof(addr),
+ prefix, vrf, password);
+}
+
+extern int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
+ void *addr, size_t addr_sz, bool set_current, bool set_rnext,
+ uint8_t prefix, uint8_t vrf,
+ uint8_t sndid, uint8_t rcvid, uint8_t maclen,
+ uint8_t keyflags, uint8_t keylen, const char *key);
+
+static inline int test_prepare_key(struct tcp_ao_add *ao,
+ const char *alg, union tcp_addr taddr,
+ bool set_current, bool set_rnext,
+ uint8_t prefix, uint8_t vrf,
+ uint8_t sndid, uint8_t rcvid, uint8_t maclen,
+ uint8_t keyflags, uint8_t keylen, const char *key)
+{
+ sockaddr_af addr;
+
+ tcp_addr_to_sockaddr_in(&addr, &taddr, 0);
+ return test_prepare_key_sockaddr(ao, alg, (void *)&addr, sizeof(addr),
+ set_current, set_rnext, prefix, vrf, sndid, rcvid,
+ maclen, keyflags, keylen, key);
+}
+
+static inline int test_prepare_def_key(struct tcp_ao_add *ao,
+ const char *key, uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix, uint8_t vrf,
+ uint8_t sndid, uint8_t rcvid)
+{
+ if (prefix > DEFAULT_TEST_PREFIX)
+ prefix = DEFAULT_TEST_PREFIX;
+
+ return test_prepare_key(ao, DEFAULT_TEST_ALGO, in_addr, false, false,
+ prefix, vrf, sndid, rcvid, 0, keyflags,
+ strlen(key), key);
+}
+
+extern int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
+ void *addr, size_t addr_sz,
+ uint8_t prefix, uint8_t sndid, uint8_t rcvid);
+extern int test_get_ao_info(int sk, struct tcp_ao_info_opt *out);
+extern int test_set_ao_info(int sk, struct tcp_ao_info_opt *in);
+extern int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
+ const struct tcp_ao_getsockopt *b);
+extern int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
+ const struct tcp_ao_info_opt *b);
+
+static inline int test_verify_socket_key(int sk, struct tcp_ao_add *key)
+{
+ struct tcp_ao_getsockopt key2 = {};
+ int err;
+
+ err = test_get_one_ao(sk, &key2, &key->addr, sizeof(key->addr),
+ key->prefix, key->sndid, key->rcvid);
+ if (err)
+ return err;
+
+ return test_cmp_getsockopt_setsockopt(key, &key2);
+}
+
+static inline int test_add_key_vrf(int sk,
+ const char *key, uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix,
+ uint8_t vrf, uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_add tmp = {};
+ int err;
+
+ err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix,
+ vrf, sndid, rcvid);
+ if (err)
+ return err;
+
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp));
+ if (err < 0)
+ return -errno;
+
+ return test_verify_socket_key(sk, &tmp);
+}
+
+static inline int test_add_key(int sk, const char *key,
+ union tcp_addr in_addr, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid)
+{
+ return test_add_key_vrf(sk, key, 0, in_addr, prefix, 0, sndid, rcvid);
+}
+
+static inline int test_verify_socket_ao(int sk, struct tcp_ao_info_opt *ao)
+{
+ struct tcp_ao_info_opt ao2 = {};
+ int err;
+
+ err = test_get_ao_info(sk, &ao2);
+ if (err)
+ return err;
+
+ return test_cmp_getsockopt_setsockopt_ao(ao, &ao2);
+}
+
+static inline int test_set_ao_flags(int sk, bool ao_required, bool accept_icmps)
+{
+ struct tcp_ao_info_opt ao = {};
+ int err;
+
+ err = test_get_ao_info(sk, &ao);
+ /* Maybe ao_info wasn't allocated yet */
+ if (err && err != -ENOENT)
+ return err;
+
+ ao.ao_required = !!ao_required;
+ ao.accept_icmps = !!accept_icmps;
+ err = test_set_ao_info(sk, &ao);
+ if (err)
+ return err;
+
+ return test_verify_socket_ao(sk, &ao);
+}
+
+extern ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec);
+extern ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
+ const size_t msg_len, time_t timeout_sec);
+extern int test_client_verify(int sk, const size_t msg_len, const size_t nr,
+ time_t timeout_sec);
+
+struct tcp_ao_key_counters {
+ uint8_t sndid;
+ uint8_t rcvid;
+ uint64_t pkt_good;
+ uint64_t pkt_bad;
+};
+
+struct tcp_ao_counters {
+ /* per-netns */
+ uint64_t netns_ao_good;
+ uint64_t netns_ao_bad;
+ uint64_t netns_ao_key_not_found;
+ uint64_t netns_ao_required;
+ uint64_t netns_ao_dropped_icmp;
+ /* per-socket */
+ uint64_t ao_info_pkt_good;
+ uint64_t ao_info_pkt_bad;
+ uint64_t ao_info_pkt_key_not_found;
+ uint64_t ao_info_pkt_ao_required;
+ uint64_t ao_info_pkt_dropped_icmp;
+ /* per-key */
+ size_t nr_keys;
+ struct tcp_ao_key_counters *key_cnts;
+};
+extern int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out);
+
+#define TEST_CNT_KEY_GOOD BIT(0)
+#define TEST_CNT_KEY_BAD BIT(1)
+#define TEST_CNT_SOCK_GOOD BIT(2)
+#define TEST_CNT_SOCK_BAD BIT(3)
+#define TEST_CNT_SOCK_KEY_NOT_FOUND BIT(4)
+#define TEST_CNT_SOCK_AO_REQUIRED BIT(5)
+#define TEST_CNT_SOCK_DROPPED_ICMP BIT(6)
+#define TEST_CNT_NS_GOOD BIT(7)
+#define TEST_CNT_NS_BAD BIT(8)
+#define TEST_CNT_NS_KEY_NOT_FOUND BIT(9)
+#define TEST_CNT_NS_AO_REQUIRED BIT(10)
+#define TEST_CNT_NS_DROPPED_ICMP BIT(11)
+typedef uint16_t test_cnt;
+
+#define TEST_CNT_AO_GOOD (TEST_CNT_SOCK_GOOD | TEST_CNT_NS_GOOD)
+#define TEST_CNT_AO_BAD (TEST_CNT_SOCK_BAD | TEST_CNT_NS_BAD)
+#define TEST_CNT_AO_KEY_NOT_FOUND (TEST_CNT_SOCK_KEY_NOT_FOUND | \
+ TEST_CNT_NS_KEY_NOT_FOUND)
+#define TEST_CNT_AO_REQUIRED (TEST_CNT_SOCK_AO_REQUIRED | \
+ TEST_CNT_NS_AO_REQUIRED)
+#define TEST_CNT_AO_DROPPED_ICMP (TEST_CNT_SOCK_DROPPED_ICMP | \
+ TEST_CNT_NS_DROPPED_ICMP)
+#define TEST_CNT_GOOD (TEST_CNT_KEY_GOOD | TEST_CNT_AO_GOOD)
+#define TEST_CNT_BAD (TEST_CNT_KEY_BAD | TEST_CNT_AO_BAD)
+
+extern int __test_tcp_ao_counters_cmp(const char *tst_name,
+ struct tcp_ao_counters *before, struct tcp_ao_counters *after,
+ test_cnt expected);
+extern int test_tcp_ao_key_counters_cmp(const char *tst_name,
+ struct tcp_ao_counters *before, struct tcp_ao_counters *after,
+ test_cnt expected, int sndid, int rcvid);
+extern void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts);
+/*
+ * Frees buffers allocated in test_get_tcp_ao_counters().
+ * The function doesn't expect new keys or keys removed between calls
+ * to test_get_tcp_ao_counters(). Check key counters manually if they
+ * may change.
+ */
+static inline int test_tcp_ao_counters_cmp(const char *tst_name,
+ struct tcp_ao_counters *before,
+ struct tcp_ao_counters *after,
+ test_cnt expected)
+{
+ int ret;
+
+ ret = __test_tcp_ao_counters_cmp(tst_name, before, after, expected);
+ if (ret)
+ goto out;
+ ret = test_tcp_ao_key_counters_cmp(tst_name, before, after,
+ expected, -1, -1);
+out:
+ test_tcp_ao_counters_free(before);
+ test_tcp_ao_counters_free(after);
+ return ret;
+}
+
+struct netstat;
+extern struct netstat *netstat_read(void);
+extern void netstat_free(struct netstat *ns);
+extern void netstat_print_diff(struct netstat *nsa, struct netstat *nsb);
+extern uint64_t netstat_get(struct netstat *ns,
+ const char *name, bool *not_found);
+
+static inline uint64_t netstat_get_one(const char *name, bool *not_found)
+{
+ struct netstat *ns = netstat_read();
+ uint64_t ret;
+
+ ret = netstat_get(ns, name, not_found);
+
+ netstat_free(ns);
+ return ret;
+}
+
+struct tcp_sock_queue {
+ uint32_t seq;
+ void *buf;
+};
+
+struct tcp_sock_state {
+ struct tcp_info info;
+ struct tcp_repair_window trw;
+ struct tcp_sock_queue out;
+ int outq_len; /* output queue size (not sent + not acked) */
+ int outq_nsd_len; /* output queue size (not sent only) */
+ struct tcp_sock_queue in;
+ int inq_len;
+ int mss;
+ int timestamp;
+};
+
+extern void __test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+ void *addr, size_t addr_size);
+static inline void test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+ sockaddr_af *saddr)
+{
+ __test_sock_checkpoint(sk, state, saddr, sizeof(*saddr));
+}
+extern void test_ao_checkpoint(int sk, struct tcp_ao_repair *state);
+extern void __test_sock_restore(int sk, const char *device,
+ struct tcp_sock_state *state,
+ void *saddr, void *daddr, size_t addr_size);
+static inline void test_sock_restore(int sk, struct tcp_sock_state *state,
+ sockaddr_af *saddr,
+ const union tcp_addr daddr,
+ unsigned int dport)
+{
+ sockaddr_af addr;
+
+ tcp_addr_to_sockaddr_in(&addr, &daddr, htons(dport));
+ __test_sock_restore(sk, veth_name, state, saddr, &addr, sizeof(addr));
+}
+extern void test_ao_restore(int sk, struct tcp_ao_repair *state);
+extern void test_sock_state_free(struct tcp_sock_state *state);
+extern void test_enable_repair(int sk);
+extern void test_disable_repair(int sk);
+extern void test_kill_sk(int sk);
+static inline int test_add_repaired_key(int sk,
+ const char *key, uint8_t keyflags,
+ union tcp_addr in_addr, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_add tmp = {};
+ int err;
+
+ err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix,
+ 0, sndid, rcvid);
+ if (err)
+ return err;
+
+ tmp.set_current = 1;
+ tmp.set_rnext = 1;
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0)
+ return -errno;
+
+ return test_verify_socket_key(sk, &tmp);
+}
+
+#endif /* _AOLIB_H_ */
diff --git a/tools/testing/selftests/net/tcp_ao/lib/kconfig.c b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c
new file mode 100644
index 000000000000..f279ffc3843b
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c
@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Check what features does the kernel support (where the selftest is running).
+ * Somewhat inspired by CRIU kerndat/kdat kernel features detector.
+ */
+#include <pthread.h>
+#include "aolib.h"
+
+struct kconfig_t {
+ int _errno; /* the returned error if not supported */
+ int (*check_kconfig)(int *error);
+};
+
+static int has_net_ns(int *err)
+{
+ if (access("/proc/self/ns/net", F_OK) < 0) {
+ *err = errno;
+ if (errno == ENOENT)
+ return 0;
+ test_print("Unable to access /proc/self/ns/net: %m");
+ return -errno;
+ }
+ return *err = errno = 0;
+}
+
+static int has_veth(int *err)
+{
+ int orig_netns, ns_a, ns_b;
+
+ orig_netns = open_netns();
+ ns_a = unshare_open_netns();
+ ns_b = unshare_open_netns();
+
+ *err = add_veth("check_veth", ns_a, ns_b);
+
+ switch_ns(orig_netns);
+ close(orig_netns);
+ close(ns_a);
+ close(ns_b);
+ return 0;
+}
+
+static int has_tcp_ao(int *err)
+{
+ struct sockaddr_in addr = {
+ .sin_family = test_family,
+ };
+ struct tcp_ao_add tmp = {};
+ const char *password = DEFAULT_TEST_PASSWORD;
+ int sk, ret = 0;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0) {
+ test_print("socket(): %m");
+ return -errno;
+ }
+
+ tmp.sndid = 100;
+ tmp.rcvid = 100;
+ tmp.keylen = strlen(password);
+ memcpy(tmp.key, password, strlen(password));
+ strcpy(tmp.alg_name, "hmac(sha1)");
+ memcpy(&tmp.addr, &addr, sizeof(addr));
+ *err = 0;
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) {
+ *err = errno;
+ if (errno != ENOPROTOOPT)
+ ret = -errno;
+ }
+ close(sk);
+ return ret;
+}
+
+static int has_tcp_md5(int *err)
+{
+ union tcp_addr addr_any = {};
+ int sk, ret = 0;
+
+ sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0) {
+ test_print("socket(): %m");
+ return -errno;
+ }
+
+ /*
+ * Under CONFIG_CRYPTO_FIPS=y it fails with ENOMEM, rather with
+ * anything more descriptive. Oh well.
+ */
+ *err = 0;
+ if (test_set_md5(sk, addr_any, 0, -1, DEFAULT_TEST_PASSWORD)) {
+ *err = errno;
+ if (errno != ENOPROTOOPT && errno == ENOMEM) {
+ test_print("setsockopt(TCP_MD5SIG_EXT): %m");
+ ret = -errno;
+ }
+ }
+ close(sk);
+ return ret;
+}
+
+static int has_vrfs(int *err)
+{
+ int orig_netns, ns_test, ret = 0;
+
+ orig_netns = open_netns();
+ ns_test = unshare_open_netns();
+
+ *err = add_vrf("ksft-check", 55, 101, ns_test);
+ if (*err && *err != -EOPNOTSUPP) {
+ test_print("Failed to add a VRF: %d", *err);
+ ret = *err;
+ }
+
+ switch_ns(orig_netns);
+ close(orig_netns);
+ close(ns_test);
+ return ret;
+}
+
+static pthread_mutex_t kconfig_lock = PTHREAD_MUTEX_INITIALIZER;
+static struct kconfig_t kconfig[__KCONFIG_LAST__] = {
+ { -1, has_net_ns },
+ { -1, has_veth },
+ { -1, has_tcp_ao },
+ { -1, has_tcp_md5 },
+ { -1, has_vrfs },
+};
+
+const char *tests_skip_reason[__KCONFIG_LAST__] = {
+ "Tests require network namespaces support (CONFIG_NET_NS)",
+ "Tests require veth support (CONFIG_VETH)",
+ "Tests require TCP-AO support (CONFIG_TCP_AO)",
+ "setsockopt(TCP_MD5SIG_EXT) is not supported (CONFIG_TCP_MD5)",
+ "VRFs are not supported (CONFIG_NET_VRF)",
+};
+
+bool kernel_config_has(enum test_needs_kconfig k)
+{
+ bool ret;
+
+ pthread_mutex_lock(&kconfig_lock);
+ if (kconfig[k]._errno == -1) {
+ if (kconfig[k].check_kconfig(&kconfig[k]._errno))
+ test_error("Failed to initialize kconfig %u", k);
+ }
+ ret = kconfig[k]._errno == 0;
+ pthread_mutex_unlock(&kconfig_lock);
+ return ret;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/netlink.c b/tools/testing/selftests/net/tcp_ao/lib/netlink.c
new file mode 100644
index 000000000000..7f108493a29a
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/netlink.c
@@ -0,0 +1,413 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Original from tools/testing/selftests/net/ipsec.c */
+#include <linux/netlink.h>
+#include <linux/random.h>
+#include <linux/rtnetlink.h>
+#include <linux/veth.h>
+#include <net/if.h>
+#include <stdint.h>
+#include <string.h>
+#include <sys/socket.h>
+
+#include "aolib.h"
+
+#define MAX_PAYLOAD 2048
+
+static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
+{
+ if (*sock > 0) {
+ seq_nr++;
+ return 0;
+ }
+
+ *sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
+ if (*sock < 0) {
+ test_print("socket(AF_NETLINK)");
+ return -1;
+ }
+
+ randomize_buffer(seq_nr, sizeof(*seq_nr));
+
+ return 0;
+}
+
+static int netlink_check_answer(int sock, bool quite)
+{
+ struct nlmsgerror {
+ struct nlmsghdr hdr;
+ int error;
+ struct nlmsghdr orig_msg;
+ } answer;
+
+ if (recv(sock, &answer, sizeof(answer), 0) < 0) {
+ test_print("recv()");
+ return -1;
+ } else if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
+ test_print("expected NLMSG_ERROR, got %d",
+ (int)answer.hdr.nlmsg_type);
+ return -1;
+ } else if (answer.error) {
+ if (!quite) {
+ test_print("NLMSG_ERROR: %d: %s",
+ answer.error, strerror(-answer.error));
+ }
+ return answer.error;
+ }
+
+ return 0;
+}
+
+static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
+{
+ return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len));
+}
+
+static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
+ unsigned short rta_type, const void *payload, size_t size)
+{
+ /* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
+ struct rtattr *attr = rtattr_hdr(nh);
+ size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);
+
+ if (req_sz < nl_size) {
+ test_print("req buf is too small: %zu < %zu", req_sz, nl_size);
+ return -1;
+ }
+ nh->nlmsg_len = nl_size;
+
+ attr->rta_len = RTA_LENGTH(size);
+ attr->rta_type = rta_type;
+ memcpy(RTA_DATA(attr), payload, size);
+
+ return 0;
+}
+
+static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
+ unsigned short rta_type, const void *payload, size_t size)
+{
+ struct rtattr *ret = rtattr_hdr(nh);
+
+ if (rtattr_pack(nh, req_sz, rta_type, payload, size))
+ return 0;
+
+ return ret;
+}
+
+static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
+ unsigned short rta_type)
+{
+ return _rtattr_begin(nh, req_sz, rta_type, 0, 0);
+}
+
+static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
+{
+ char *nlmsg_end = (char *)nh + nh->nlmsg_len;
+
+ attr->rta_len = nlmsg_end - (char *)attr;
+}
+
+static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
+ const char *peer, int ns)
+{
+ struct ifinfomsg pi;
+ struct rtattr *peer_attr;
+
+ memset(&pi, 0, sizeof(pi));
+ pi.ifi_family = AF_UNSPEC;
+ pi.ifi_change = 0xFFFFFFFF;
+
+ peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi));
+ if (!peer_attr)
+ return -1;
+
+ if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)))
+ return -1;
+
+ if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
+ return -1;
+
+ rtattr_end(nh, peer_attr);
+
+ return 0;
+}
+
+static int __add_veth(int sock, uint32_t seq, const char *name,
+ int ns_a, int ns_b)
+{
+ uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+ struct {
+ struct nlmsghdr nh;
+ struct ifinfomsg info;
+ char attrbuf[MAX_PAYLOAD];
+ } req;
+ static const char veth_type[] = "veth";
+ struct rtattr *link_info, *info_data;
+
+ memset(&req, 0, sizeof(req));
+ req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
+ req.nh.nlmsg_type = RTM_NEWLINK;
+ req.nh.nlmsg_flags = flags;
+ req.nh.nlmsg_seq = seq;
+ req.info.ifi_family = AF_UNSPEC;
+ req.info.ifi_change = 0xFFFFFFFF;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name)))
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
+ return -1;
+
+ link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
+ if (!link_info)
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
+ return -1;
+
+ info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
+ if (!info_data)
+ return -1;
+
+ if (veth_pack_peerb(&req.nh, sizeof(req), name, ns_b))
+ return -1;
+
+ rtattr_end(&req.nh, info_data);
+ rtattr_end(&req.nh, link_info);
+
+ if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+ test_print("send()");
+ return -1;
+ }
+ return netlink_check_answer(sock, false);
+}
+
+int add_veth(const char *name, int nsfda, int nsfdb)
+{
+ int route_sock = -1, ret;
+ uint32_t route_seq;
+
+ if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
+ test_error("Failed to open netlink route socket\n");
+
+ ret = __add_veth(route_sock, route_seq++, name, nsfda, nsfdb);
+ close(route_sock);
+ return ret;
+}
+
+static int __ip_addr_add(int sock, uint32_t seq, const char *intf,
+ int family, union tcp_addr addr, uint8_t prefix)
+{
+ uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+ struct {
+ struct nlmsghdr nh;
+ struct ifaddrmsg info;
+ char attrbuf[MAX_PAYLOAD];
+ } req;
+ size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) :
+ sizeof(struct in6_addr);
+
+ memset(&req, 0, sizeof(req));
+ req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
+ req.nh.nlmsg_type = RTM_NEWADDR;
+ req.nh.nlmsg_flags = flags;
+ req.nh.nlmsg_seq = seq;
+ req.info.ifa_family = family;
+ req.info.ifa_prefixlen = prefix;
+ req.info.ifa_index = if_nametoindex(intf);
+ req.info.ifa_flags = IFA_F_NODAD;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, addr_len))
+ return -1;
+
+ if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+ test_print("send()");
+ return -1;
+ }
+ return netlink_check_answer(sock, true);
+}
+
+int ip_addr_add(const char *intf, int family,
+ union tcp_addr addr, uint8_t prefix)
+{
+ int route_sock = -1, ret;
+ uint32_t route_seq;
+
+ if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
+ test_error("Failed to open netlink route socket\n");
+
+ ret = __ip_addr_add(route_sock, route_seq++, intf,
+ family, addr, prefix);
+
+ close(route_sock);
+ return ret;
+}
+
+static int __ip_route_add(int sock, uint32_t seq, const char *intf, int family,
+ union tcp_addr src, union tcp_addr dst, uint8_t vrf)
+{
+ struct {
+ struct nlmsghdr nh;
+ struct rtmsg rt;
+ char attrbuf[MAX_PAYLOAD];
+ } req;
+ unsigned int index = if_nametoindex(intf);
+ size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) :
+ sizeof(struct in6_addr);
+
+ memset(&req, 0, sizeof(req));
+ req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.rt));
+ req.nh.nlmsg_type = RTM_NEWROUTE;
+ req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
+ req.nh.nlmsg_seq = seq;
+ req.rt.rtm_family = family;
+ req.rt.rtm_dst_len = (family == AF_INET) ? 32 : 128;
+ req.rt.rtm_table = vrf;
+ req.rt.rtm_protocol = RTPROT_BOOT;
+ req.rt.rtm_scope = RT_SCOPE_UNIVERSE;
+ req.rt.rtm_type = RTN_UNICAST;
+
+ if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, addr_len))
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, addr_len))
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
+ return -1;
+
+ if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+ test_print("send()");
+ return -1;
+ }
+
+ return netlink_check_answer(sock, true);
+}
+
+int ip_route_add_vrf(const char *intf, int family,
+ union tcp_addr src, union tcp_addr dst, uint8_t vrf)
+{
+ int route_sock = -1, ret;
+ uint32_t route_seq;
+
+ if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
+ test_error("Failed to open netlink route socket\n");
+
+ ret = __ip_route_add(route_sock, route_seq++, intf,
+ family, src, dst, vrf);
+
+ close(route_sock);
+ return ret;
+}
+
+int ip_route_add(const char *intf, int family,
+ union tcp_addr src, union tcp_addr dst)
+{
+ return ip_route_add_vrf(intf, family, src, dst, RT_TABLE_MAIN);
+}
+
+static int __link_set_up(int sock, uint32_t seq, const char *intf)
+{
+ struct {
+ struct nlmsghdr nh;
+ struct ifinfomsg info;
+ char attrbuf[MAX_PAYLOAD];
+ } req;
+
+ memset(&req, 0, sizeof(req));
+ req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
+ req.nh.nlmsg_type = RTM_NEWLINK;
+ req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
+ req.nh.nlmsg_seq = seq;
+ req.info.ifi_family = AF_UNSPEC;
+ req.info.ifi_change = 0xFFFFFFFF;
+ req.info.ifi_index = if_nametoindex(intf);
+ req.info.ifi_flags = IFF_UP;
+ req.info.ifi_change = IFF_UP;
+
+ if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+ test_print("send()");
+ return -1;
+ }
+ return netlink_check_answer(sock, false);
+}
+
+int link_set_up(const char *intf)
+{
+ int route_sock = -1, ret;
+ uint32_t route_seq;
+
+ if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
+ test_error("Failed to open netlink route socket\n");
+
+ ret = __link_set_up(route_sock, route_seq++, intf);
+
+ close(route_sock);
+ return ret;
+}
+
+static int __add_vrf(int sock, uint32_t seq, const char *name,
+ uint32_t tabid, int ifindex, int nsfd)
+{
+ uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
+ struct {
+ struct nlmsghdr nh;
+ struct ifinfomsg info;
+ char attrbuf[MAX_PAYLOAD];
+ } req;
+ static const char vrf_type[] = "vrf";
+ struct rtattr *link_info, *info_data;
+
+ memset(&req, 0, sizeof(req));
+ req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info));
+ req.nh.nlmsg_type = RTM_NEWLINK;
+ req.nh.nlmsg_flags = flags;
+ req.nh.nlmsg_seq = seq;
+ req.info.ifi_family = AF_UNSPEC;
+ req.info.ifi_change = 0xFFFFFFFF;
+ req.info.ifi_index = ifindex;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name)))
+ return -1;
+
+ if (nsfd >= 0)
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD,
+ &nsfd, sizeof(nsfd)))
+ return -1;
+
+ link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
+ if (!link_info)
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, vrf_type, sizeof(vrf_type)))
+ return -1;
+
+ info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
+ if (!info_data)
+ return -1;
+
+ if (rtattr_pack(&req.nh, sizeof(req), IFLA_VRF_TABLE,
+ &tabid, sizeof(tabid)))
+ return -1;
+
+ rtattr_end(&req.nh, info_data);
+ rtattr_end(&req.nh, link_info);
+
+ if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
+ test_print("send()");
+ return -1;
+ }
+ return netlink_check_answer(sock, true);
+}
+
+int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd)
+{
+ int route_sock = -1, ret;
+ uint32_t route_seq;
+
+ if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
+ test_error("Failed to open netlink route socket\n");
+
+ ret = __add_vrf(route_sock, route_seq++, name, tabid, ifindex, nsfd);
+ close(route_sock);
+ return ret;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/proc.c b/tools/testing/selftests/net/tcp_ao/lib/proc.c
new file mode 100644
index 000000000000..2fb6dd8adba6
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/proc.c
@@ -0,0 +1,273 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <inttypes.h>
+#include <pthread.h>
+#include <stdio.h>
+#include "../../../../../include/linux/compiler.h"
+#include "../../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+struct netstat_counter {
+ uint64_t val;
+ char *name;
+};
+
+struct netstat {
+ char *header_name;
+ struct netstat *next;
+ size_t counters_nr;
+ struct netstat_counter *counters;
+};
+
+static struct netstat *lookup_type(struct netstat *ns,
+ const char *type, size_t len)
+{
+ while (ns != NULL) {
+ size_t cmp = max(len, strlen(ns->header_name));
+
+ if (!strncmp(ns->header_name, type, cmp))
+ return ns;
+ ns = ns->next;
+ }
+ return NULL;
+}
+
+static struct netstat *lookup_get(struct netstat *ns,
+ const char *type, const size_t len)
+{
+ struct netstat *ret;
+
+ ret = lookup_type(ns, type, len);
+ if (ret != NULL)
+ return ret;
+
+ ret = malloc(sizeof(struct netstat));
+ if (!ret)
+ test_error("malloc()");
+
+ ret->header_name = strndup(type, len);
+ if (ret->header_name == NULL)
+ test_error("strndup()");
+ ret->next = ns;
+ ret->counters_nr = 0;
+ ret->counters = NULL;
+
+ return ret;
+}
+
+static struct netstat *lookup_get_column(struct netstat *ns, const char *line)
+{
+ char *column;
+
+ column = strchr(line, ':');
+ if (!column)
+ test_error("can't parse netstat file");
+
+ return lookup_get(ns, line, column - line);
+}
+
+static void netstat_read_type(FILE *fnetstat, struct netstat **dest, char *line)
+{
+ struct netstat *type = lookup_get_column(*dest, line);
+ const char *pos = line;
+ size_t i, nr_elems = 0;
+ char tmp;
+
+ while ((pos = strchr(pos, ' '))) {
+ nr_elems++;
+ pos++;
+ }
+
+ *dest = type;
+ type->counters = reallocarray(type->counters,
+ type->counters_nr + nr_elems,
+ sizeof(struct netstat_counter));
+ if (!type->counters)
+ test_error("reallocarray()");
+
+ pos = strchr(line, ' ') + 1;
+
+ if (fscanf(fnetstat, type->header_name) == EOF)
+ test_error("fscanf(%s)", type->header_name);
+ if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != ':')
+ test_error("Unexpected netstat format (%c)", tmp);
+
+ for (i = type->counters_nr; i < type->counters_nr + nr_elems; i++) {
+ struct netstat_counter *nc = &type->counters[i];
+ const char *new_pos = strchr(pos, ' ');
+ const char *fmt = " %" PRIu64;
+
+ if (new_pos == NULL)
+ new_pos = strchr(pos, '\n');
+
+ nc->name = strndup(pos, new_pos - pos);
+ if (nc->name == NULL)
+ test_error("strndup()");
+
+ if (unlikely(!strcmp(nc->name, "MaxConn")))
+ fmt = " %" PRId64; /* MaxConn is signed, RFC 2012 */
+ if (fscanf(fnetstat, fmt, &nc->val) != 1)
+ test_error("fscanf(%s)", nc->name);
+ pos = new_pos + 1;
+ }
+ type->counters_nr += nr_elems;
+
+ if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != '\n')
+ test_error("Unexpected netstat format");
+}
+
+static const char *snmp6_name = "Snmp6";
+static void snmp6_read(FILE *fnetstat, struct netstat **dest)
+{
+ struct netstat *type = lookup_get(*dest, snmp6_name, strlen(snmp6_name));
+ char *counter_name;
+ size_t i;
+
+ for (i = type->counters_nr;; i++) {
+ struct netstat_counter *nc;
+ uint64_t counter;
+
+ if (fscanf(fnetstat, "%ms", &counter_name) == EOF)
+ break;
+ if (fscanf(fnetstat, "%" PRIu64, &counter) == EOF)
+ test_error("Unexpected snmp6 format");
+ type->counters = reallocarray(type->counters, i + 1,
+ sizeof(struct netstat_counter));
+ if (!type->counters)
+ test_error("reallocarray()");
+ nc = &type->counters[i];
+ nc->name = counter_name;
+ nc->val = counter;
+ }
+ type->counters_nr = i;
+ *dest = type;
+}
+
+struct netstat *netstat_read(void)
+{
+ struct netstat *ret = 0;
+ size_t line_sz = 0;
+ char *line = NULL;
+ FILE *fnetstat;
+
+ /*
+ * Opening thread-self instead of /proc/net/... as the latter
+ * points to /proc/self/net/ which instantiates thread-leader's
+ * net-ns, see:
+ * commit 155134fef2b6 ("Revert "proc: Point /proc/{mounts,net} at..")
+ */
+ errno = 0;
+ fnetstat = fopen("/proc/thread-self/net/netstat", "r");
+ if (fnetstat == NULL)
+ test_error("failed to open /proc/net/netstat");
+
+ while (getline(&line, &line_sz, fnetstat) != -1)
+ netstat_read_type(fnetstat, &ret, line);
+ fclose(fnetstat);
+
+ errno = 0;
+ fnetstat = fopen("/proc/thread-self/net/snmp", "r");
+ if (fnetstat == NULL)
+ test_error("failed to open /proc/net/snmp");
+
+ while (getline(&line, &line_sz, fnetstat) != -1)
+ netstat_read_type(fnetstat, &ret, line);
+ fclose(fnetstat);
+
+ errno = 0;
+ fnetstat = fopen("/proc/thread-self/net/snmp6", "r");
+ if (fnetstat == NULL)
+ test_error("failed to open /proc/net/snmp6");
+
+ snmp6_read(fnetstat, &ret);
+ fclose(fnetstat);
+
+ free(line);
+ return ret;
+}
+
+void netstat_free(struct netstat *ns)
+{
+ while (ns != NULL) {
+ struct netstat *prev = ns;
+ size_t i;
+
+ free(ns->header_name);
+ for (i = 0; i < ns->counters_nr; i++)
+ free(ns->counters[i].name);
+ free(ns->counters);
+ ns = ns->next;
+ free(prev);
+ }
+}
+
+static inline void
+__netstat_print_diff(uint64_t a, struct netstat *nsb, size_t i)
+{
+ if (unlikely(!strcmp(nsb->header_name, "MaxConn"))) {
+ test_print("%8s %25s: %" PRId64 " => %" PRId64,
+ nsb->header_name, nsb->counters[i].name,
+ a, nsb->counters[i].val);
+ return;
+ }
+
+ test_print("%8s %25s: %" PRIu64 " => %" PRIu64, nsb->header_name,
+ nsb->counters[i].name, a, nsb->counters[i].val);
+}
+
+void netstat_print_diff(struct netstat *nsa, struct netstat *nsb)
+{
+ size_t i, j;
+
+ while (nsb != NULL) {
+ if (unlikely(strcmp(nsb->header_name, nsa->header_name))) {
+ for (i = 0; i < nsb->counters_nr; i++)
+ __netstat_print_diff(0, nsb, i);
+ nsb = nsb->next;
+ continue;
+ }
+
+ if (nsb->counters_nr < nsa->counters_nr)
+ test_error("Unexpected: some counters disappeared!");
+
+ for (j = 0, i = 0; i < nsb->counters_nr; i++) {
+ if (strcmp(nsb->counters[i].name, nsa->counters[j].name)) {
+ __netstat_print_diff(0, nsb, i);
+ continue;
+ }
+
+ if (nsa->counters[j].val == nsb->counters[i].val) {
+ j++;
+ continue;
+ }
+
+ __netstat_print_diff(nsa->counters[j].val, nsb, i);
+ j++;
+ }
+ if (j != nsa->counters_nr)
+ test_error("Unexpected: some counters disappeared!");
+
+ nsb = nsb->next;
+ nsa = nsa->next;
+ }
+}
+
+uint64_t netstat_get(struct netstat *ns, const char *name, bool *not_found)
+{
+ if (not_found)
+ *not_found = false;
+
+ while (ns != NULL) {
+ size_t i;
+
+ for (i = 0; i < ns->counters_nr; i++) {
+ if (!strcmp(name, ns->counters[i].name))
+ return ns->counters[i].val;
+ }
+
+ ns = ns->next;
+ }
+
+ if (not_found)
+ *not_found = true;
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/repair.c b/tools/testing/selftests/net/tcp_ao/lib/repair.c
new file mode 100644
index 000000000000..9893b3ba69f5
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/repair.c
@@ -0,0 +1,254 @@
+// SPDX-License-Identifier: GPL-2.0
+/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets
+ * It tests that TCP-AO enabled connection can be restored.
+ * For the proper socket repair see:
+ * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h
+ */
+#include <fcntl.h>
+#include <linux/sockios.h>
+#include <sys/ioctl.h>
+#include "aolib.h"
+
+#ifndef TCPOPT_MAXSEG
+# define TCPOPT_MAXSEG 2
+#endif
+#ifndef TCPOPT_WINDOW
+# define TCPOPT_WINDOW 3
+#endif
+#ifndef TCPOPT_SACK_PERMITTED
+# define TCPOPT_SACK_PERMITTED 4
+#endif
+#ifndef TCPOPT_TIMESTAMP
+# define TCPOPT_TIMESTAMP 8
+#endif
+
+enum {
+ TCP_ESTABLISHED = 1,
+ TCP_SYN_SENT,
+ TCP_SYN_RECV,
+ TCP_FIN_WAIT1,
+ TCP_FIN_WAIT2,
+ TCP_TIME_WAIT,
+ TCP_CLOSE,
+ TCP_CLOSE_WAIT,
+ TCP_LAST_ACK,
+ TCP_LISTEN,
+ TCP_CLOSING, /* Now a valid state */
+ TCP_NEW_SYN_RECV,
+
+ TCP_MAX_STATES /* Leave at the end! */
+};
+
+static void test_sock_checkpoint_queue(int sk, int queue, int qlen,
+ struct tcp_sock_queue *q)
+{
+ socklen_t len;
+ int ret;
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+ test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+ len = sizeof(q->seq);
+ ret = getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &q->seq, &len);
+ if (ret || len != sizeof(q->seq))
+ test_error("getsockopt(TCP_QUEUE_SEQ): %d", (int)len);
+
+ if (!qlen) {
+ q->buf = NULL;
+ return;
+ }
+
+ q->buf = malloc(qlen);
+ if (q->buf == NULL)
+ test_error("malloc()");
+ ret = recv(sk, q->buf, qlen, MSG_PEEK | MSG_DONTWAIT);
+ if (ret != qlen)
+ test_error("recv(%d): %d", qlen, ret);
+}
+
+void __test_sock_checkpoint(int sk, struct tcp_sock_state *state,
+ void *addr, size_t addr_size)
+{
+ socklen_t len = sizeof(state->info);
+ int ret;
+
+ memset(state, 0, sizeof(*state));
+
+ ret = getsockopt(sk, SOL_TCP, TCP_INFO, &state->info, &len);
+ if (ret || len != sizeof(state->info))
+ test_error("getsockopt(TCP_INFO): %d", (int)len);
+
+ len = addr_size;
+ if (getsockname(sk, addr, &len) || len != addr_size)
+ test_error("getsockname(): %d", (int)len);
+
+ len = sizeof(state->trw);
+ ret = getsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, &len);
+ if (ret || len != sizeof(state->trw))
+ test_error("getsockopt(TCP_REPAIR_WINDOW): %d", (int)len);
+
+ if (ioctl(sk, SIOCOUTQ, &state->outq_len))
+ test_error("ioctl(SIOCOUTQ)");
+
+ if (ioctl(sk, SIOCOUTQNSD, &state->outq_nsd_len))
+ test_error("ioctl(SIOCOUTQNSD)");
+ test_sock_checkpoint_queue(sk, TCP_SEND_QUEUE, state->outq_len, &state->out);
+
+ if (ioctl(sk, SIOCINQ, &state->inq_len))
+ test_error("ioctl(SIOCINQ)");
+ test_sock_checkpoint_queue(sk, TCP_RECV_QUEUE, state->inq_len, &state->in);
+
+ if (state->info.tcpi_state == TCP_CLOSE)
+ state->outq_len = state->outq_nsd_len = 0;
+
+ len = sizeof(state->mss);
+ ret = getsockopt(sk, SOL_TCP, TCP_MAXSEG, &state->mss, &len);
+ if (ret || len != sizeof(state->mss))
+ test_error("getsockopt(TCP_MAXSEG): %d", (int)len);
+
+ len = sizeof(state->timestamp);
+ ret = getsockopt(sk, SOL_TCP, TCP_TIMESTAMP, &state->timestamp, &len);
+ if (ret || len != sizeof(state->timestamp))
+ test_error("getsockopt(TCP_TIMESTAMP): %d", (int)len);
+}
+
+void test_ao_checkpoint(int sk, struct tcp_ao_repair *state)
+{
+ socklen_t len = sizeof(*state);
+ int ret;
+
+ memset(state, 0, sizeof(*state));
+
+ ret = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, &len);
+ if (ret || len != sizeof(*state))
+ test_error("getsockopt(TCP_AO_REPAIR): %d", (int)len);
+}
+
+static void test_sock_restore_seq(int sk, int queue, uint32_t seq)
+{
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+ test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+ if (setsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &seq, sizeof(seq)))
+ test_error("setsockopt(TCP_QUEUE_SEQ)");
+}
+
+static void test_sock_restore_queue(int sk, int queue, void *buf, int len)
+{
+ int chunk = len;
+ size_t off = 0;
+
+ if (len == 0)
+ return;
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue)))
+ test_error("setsockopt(TCP_REPAIR_QUEUE)");
+
+ do {
+ int ret;
+
+ ret = send(sk, buf + off, chunk, 0);
+ if (ret <= 0) {
+ if (chunk > 1024) {
+ chunk >>= 1;
+ continue;
+ }
+ test_error("send()");
+ }
+ off += ret;
+ len -= ret;
+ } while (len > 0);
+}
+
+void __test_sock_restore(int sk, const char *device,
+ struct tcp_sock_state *state,
+ void *saddr, void *daddr, size_t addr_size)
+{
+ struct tcp_repair_opt opts[4];
+ unsigned int opt_nr = 0;
+ long flags;
+
+ if (bind(sk, saddr, addr_size))
+ test_error("bind()");
+
+ flags = fcntl(sk, F_GETFL);
+ if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
+ test_error("fcntl()");
+
+ test_sock_restore_seq(sk, TCP_RECV_QUEUE, state->in.seq - state->inq_len);
+ test_sock_restore_seq(sk, TCP_SEND_QUEUE, state->out.seq - state->outq_len);
+
+ if (device != NULL && setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE,
+ device, strlen(device) + 1))
+ test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
+
+ if (connect(sk, daddr, addr_size))
+ test_error("connect()");
+
+ if (state->info.tcpi_options & TCPI_OPT_SACK) {
+ opts[opt_nr].opt_code = TCPOPT_SACK_PERMITTED;
+ opts[opt_nr].opt_val = 0;
+ opt_nr++;
+ }
+ if (state->info.tcpi_options & TCPI_OPT_WSCALE) {
+ opts[opt_nr].opt_code = TCPOPT_WINDOW;
+ opts[opt_nr].opt_val = state->info.tcpi_snd_wscale +
+ (state->info.tcpi_rcv_wscale << 16);
+ opt_nr++;
+ }
+ if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
+ opts[opt_nr].opt_code = TCPOPT_TIMESTAMP;
+ opts[opt_nr].opt_val = 0;
+ opt_nr++;
+ }
+ opts[opt_nr].opt_code = TCPOPT_MAXSEG;
+ opts[opt_nr].opt_val = state->mss;
+ opt_nr++;
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_OPTIONS, opts, opt_nr * sizeof(opts[0])))
+ test_error("setsockopt(TCP_REPAIR_OPTIONS)");
+
+ if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) {
+ if (setsockopt(sk, SOL_TCP, TCP_TIMESTAMP,
+ &state->timestamp, opt_nr * sizeof(opts[0])))
+ test_error("setsockopt(TCP_TIMESTAMP)");
+ }
+ test_sock_restore_queue(sk, TCP_RECV_QUEUE, state->in.buf, state->inq_len);
+ test_sock_restore_queue(sk, TCP_SEND_QUEUE, state->out.buf, state->outq_len);
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, sizeof(state->trw)))
+ test_error("setsockopt(TCP_REPAIR_WINDOW)");
+}
+
+void test_ao_restore(int sk, struct tcp_ao_repair *state)
+{
+ if (setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, sizeof(*state)))
+ test_error("setsockopt(TCP_AO_REPAIR)");
+}
+
+void test_sock_state_free(struct tcp_sock_state *state)
+{
+ free(state->out.buf);
+ free(state->in.buf);
+}
+
+void test_enable_repair(int sk)
+{
+ int val = TCP_REPAIR_ON;
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+ test_error("setsockopt(TCP_REPAIR)");
+}
+
+void test_disable_repair(int sk)
+{
+ int val = TCP_REPAIR_OFF_NO_WP;
+
+ if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val)))
+ test_error("setsockopt(TCP_REPAIR)");
+}
+
+void test_kill_sk(int sk)
+{
+ test_enable_repair(sk);
+ close(sk);
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/setup.c b/tools/testing/selftests/net/tcp_ao/lib/setup.c
new file mode 100644
index 000000000000..92276f916f2f
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/setup.c
@@ -0,0 +1,361 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <fcntl.h>
+#include <pthread.h>
+#include <sched.h>
+#include <signal.h>
+#include "aolib.h"
+
+/*
+ * Can't be included in the header: it defines static variables which
+ * will be unique to every object. Let's include it only once here.
+ */
+#include "../../../kselftest.h"
+
+/* Prevent overriding of one thread's output by another */
+static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER;
+
+void __test_msg(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_print_msg(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+void __test_ok(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_test_result_pass(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+void __test_fail(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_test_result_fail(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+void __test_xfail(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_test_result_xfail(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+void __test_error(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_test_result_error(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+void __test_skip(const char *buf)
+{
+ pthread_mutex_lock(&ksft_print_lock);
+ ksft_test_result_skip(buf);
+ pthread_mutex_unlock(&ksft_print_lock);
+}
+
+static volatile int failed;
+static volatile int skipped;
+
+void test_failed(void)
+{
+ failed = 1;
+}
+
+static void test_exit(void)
+{
+ if (failed) {
+ ksft_exit_fail();
+ } else if (skipped) {
+ /* ksft_exit_skip() is different from ksft_exit_*() */
+ ksft_print_cnts();
+ exit(KSFT_SKIP);
+ } else {
+ ksft_exit_pass();
+ }
+}
+
+struct dlist_t {
+ void (*destruct)(void);
+ struct dlist_t *next;
+};
+static struct dlist_t *destructors_list;
+
+void test_add_destructor(void (*d)(void))
+{
+ struct dlist_t *p;
+
+ p = malloc(sizeof(struct dlist_t));
+ if (p == NULL)
+ test_error("malloc() failed");
+
+ p->next = destructors_list;
+ p->destruct = d;
+ destructors_list = p;
+}
+
+static void test_destructor(void) __attribute__((destructor));
+static void test_destructor(void)
+{
+ while (destructors_list) {
+ struct dlist_t *p = destructors_list->next;
+
+ destructors_list->destruct();
+ free(destructors_list);
+ destructors_list = p;
+ }
+ test_exit();
+}
+
+static void sig_int(int signo)
+{
+ test_error("Caught SIGINT - exiting");
+}
+
+int open_netns(void)
+{
+ const char *netns_path = "/proc/self/ns/net";
+ int fd;
+
+ fd = open(netns_path, O_RDONLY);
+ if (fd < 0)
+ test_error("open(%s)", netns_path);
+ return fd;
+}
+
+int unshare_open_netns(void)
+{
+ if (unshare(CLONE_NEWNET) != 0)
+ test_error("unshare()");
+
+ return open_netns();
+}
+
+void switch_ns(int fd)
+{
+ if (setns(fd, CLONE_NEWNET))
+ test_error("setns()");
+}
+
+int switch_save_ns(int new_ns)
+{
+ int ret = open_netns();
+
+ switch_ns(new_ns);
+ return ret;
+}
+
+static int nsfd_outside = -1;
+static int nsfd_parent = -1;
+static int nsfd_child = -1;
+const char veth_name[] = "ktst-veth";
+
+static void init_namespaces(void)
+{
+ nsfd_outside = open_netns();
+ nsfd_parent = unshare_open_netns();
+ nsfd_child = unshare_open_netns();
+}
+
+static void link_init(const char *veth, int family, uint8_t prefix,
+ union tcp_addr addr, union tcp_addr dest)
+{
+ if (link_set_up(veth))
+ test_error("Failed to set link up");
+ if (ip_addr_add(veth, family, addr, prefix))
+ test_error("Failed to add ip address");
+ if (ip_route_add(veth, family, addr, dest))
+ test_error("Failed to add route");
+}
+
+static unsigned int nr_threads = 1;
+
+static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER;
+static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER;
+static volatile unsigned int stage_threads[2];
+static volatile unsigned int stage_nr;
+
+/* synchronize all threads in the same stage */
+void synchronize_threads(void)
+{
+ unsigned int q = stage_nr;
+
+ pthread_mutex_lock(&sync_lock);
+ stage_threads[q]++;
+ if (stage_threads[q] == nr_threads) {
+ stage_nr ^= 1;
+ stage_threads[stage_nr] = 0;
+ pthread_cond_signal(&sync_cond);
+ }
+ while (stage_threads[q] < nr_threads)
+ pthread_cond_wait(&sync_cond, &sync_lock);
+ pthread_mutex_unlock(&sync_lock);
+}
+
+__thread union tcp_addr this_ip_addr;
+__thread union tcp_addr this_ip_dest;
+int test_family;
+
+struct new_pthread_arg {
+ thread_fn func;
+ union tcp_addr my_ip;
+ union tcp_addr dest_ip;
+};
+static void *new_pthread_entry(void *arg)
+{
+ struct new_pthread_arg *p = arg;
+
+ this_ip_addr = p->my_ip;
+ this_ip_dest = p->dest_ip;
+ p->func(NULL); /* shouldn't return */
+ exit(KSFT_FAIL);
+}
+
+static void __test_skip_all(const char *msg)
+{
+ ksft_set_plan(1);
+ ksft_print_header();
+ skipped = 1;
+ test_skip("%s", msg);
+ exit(KSFT_SKIP);
+}
+
+void __test_init(unsigned int ntests, int family, unsigned int prefix,
+ union tcp_addr addr1, union tcp_addr addr2,
+ thread_fn peer1, thread_fn peer2)
+{
+ struct sigaction sa = {
+ .sa_handler = sig_int,
+ .sa_flags = SA_RESTART,
+ };
+ time_t seed = time(NULL);
+
+ sigemptyset(&sa.sa_mask);
+ if (sigaction(SIGINT, &sa, NULL))
+ test_error("Can't set SIGINT handler");
+
+ test_family = family;
+ if (!kernel_config_has(KCONFIG_NET_NS))
+ __test_skip_all(tests_skip_reason[KCONFIG_NET_NS]);
+ if (!kernel_config_has(KCONFIG_VETH))
+ __test_skip_all(tests_skip_reason[KCONFIG_VETH]);
+ if (!kernel_config_has(KCONFIG_TCP_AO))
+ __test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]);
+
+ ksft_set_plan(ntests);
+ test_print("rand seed %u", (unsigned int)seed);
+ srand(seed);
+
+
+ ksft_print_header();
+ init_namespaces();
+
+ if (add_veth(veth_name, nsfd_parent, nsfd_child))
+ test_error("Failed to add veth");
+
+ switch_ns(nsfd_child);
+ link_init(veth_name, family, prefix, addr2, addr1);
+ if (peer2) {
+ struct new_pthread_arg targ;
+ pthread_t t;
+
+ targ.my_ip = addr2;
+ targ.dest_ip = addr1;
+ targ.func = peer2;
+ nr_threads++;
+ if (pthread_create(&t, NULL, new_pthread_entry, &targ))
+ test_error("Failed to create pthread");
+ }
+ switch_ns(nsfd_parent);
+ link_init(veth_name, family, prefix, addr1, addr2);
+
+ this_ip_addr = addr1;
+ this_ip_dest = addr2;
+ peer1(NULL);
+ if (failed)
+ exit(KSFT_FAIL);
+ else
+ exit(KSFT_PASS);
+}
+
+/* /proc/sys/net/core/optmem_max artifically limits the amount of memory
+ * that can be allocated with sock_kmalloc() on each socket in the system.
+ * It is not virtualized in v6.7, so it has to written outside test
+ * namespaces. To be nice a test will revert optmem back to the old value.
+ * Keeping it simple without any file lock, which means the tests that
+ * need to set/increase optmem value shouldn't run in parallel.
+ * Also, not re-entrant.
+ * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max")
+ * it is per-namespace, keeping logic for non-virtualized optmem_max
+ * for v6.7, which supports TCP-AO.
+ */
+static const char *optmem_file = "/proc/sys/net/core/optmem_max";
+static size_t saved_optmem;
+static int optmem_ns = -1;
+
+static bool is_optmem_namespaced(void)
+{
+ if (optmem_ns == -1) {
+ int old_ns = switch_save_ns(nsfd_child);
+
+ optmem_ns = !access(optmem_file, F_OK);
+ switch_ns(old_ns);
+ }
+ return !!optmem_ns;
+}
+
+size_t test_get_optmem(void)
+{
+ int old_ns = 0;
+ FILE *foptmem;
+ size_t ret;
+
+ if (!is_optmem_namespaced())
+ old_ns = switch_save_ns(nsfd_outside);
+ foptmem = fopen(optmem_file, "r");
+ if (!foptmem)
+ test_error("failed to open %s", optmem_file);
+
+ if (fscanf(foptmem, "%zu", &ret) != 1)
+ test_error("can't read from %s", optmem_file);
+ fclose(foptmem);
+ if (!is_optmem_namespaced())
+ switch_ns(old_ns);
+ return ret;
+}
+
+static void __test_set_optmem(size_t new, size_t *old)
+{
+ int old_ns = 0;
+ FILE *foptmem;
+
+ if (old != NULL)
+ *old = test_get_optmem();
+
+ if (!is_optmem_namespaced())
+ old_ns = switch_save_ns(nsfd_outside);
+ foptmem = fopen(optmem_file, "w");
+ if (!foptmem)
+ test_error("failed to open %s", optmem_file);
+
+ if (fprintf(foptmem, "%zu", new) <= 0)
+ test_error("can't write %zu to %s", new, optmem_file);
+ fclose(foptmem);
+ if (!is_optmem_namespaced())
+ switch_ns(old_ns);
+}
+
+static void test_revert_optmem(void)
+{
+ if (saved_optmem == 0)
+ return;
+
+ __test_set_optmem(saved_optmem, NULL);
+}
+
+void test_set_optmem(size_t value)
+{
+ if (saved_optmem == 0) {
+ __test_set_optmem(value, &saved_optmem);
+ test_add_destructor(test_revert_optmem);
+ } else {
+ __test_set_optmem(value, NULL);
+ }
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/sock.c b/tools/testing/selftests/net/tcp_ao/lib/sock.c
new file mode 100644
index 000000000000..15aeb0963058
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/sock.c
@@ -0,0 +1,596 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <alloca.h>
+#include <fcntl.h>
+#include <inttypes.h>
+#include <string.h>
+#include "../../../../../include/linux/kernel.h"
+#include "../../../../../include/linux/stringify.h"
+#include "aolib.h"
+
+const unsigned int test_server_port = 7010;
+int __test_listen_socket(int backlog, void *addr, size_t addr_sz)
+{
+ int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ long flags;
+
+ if (sk < 0)
+ test_error("socket()");
+
+ err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name,
+ strlen(veth_name) + 1);
+ if (err < 0)
+ test_error("setsockopt(SO_BINDTODEVICE)");
+
+ if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0)
+ test_error("bind()");
+
+ flags = fcntl(sk, F_GETFL);
+ if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
+ test_error("fcntl()");
+
+ if (listen(sk, backlog))
+ test_error("listen()");
+
+ return sk;
+}
+
+int test_wait_fd(int sk, time_t sec, bool write)
+{
+ struct timeval tv = { .tv_sec = sec };
+ struct timeval *ptv = NULL;
+ fd_set fds, efds;
+ int ret;
+ socklen_t slen = sizeof(ret);
+
+ FD_ZERO(&fds);
+ FD_SET(sk, &fds);
+ FD_ZERO(&efds);
+ FD_SET(sk, &efds);
+
+ if (sec)
+ ptv = &tv;
+
+ errno = 0;
+ if (write)
+ ret = select(sk + 1, NULL, &fds, &efds, ptv);
+ else
+ ret = select(sk + 1, &fds, NULL, &efds, ptv);
+ if (ret < 0)
+ return -errno;
+ if (ret == 0) {
+ errno = ETIMEDOUT;
+ return -ETIMEDOUT;
+ }
+
+ if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen))
+ return -errno;
+ if (ret)
+ return -ret;
+ return 0;
+}
+
+int __test_connect_socket(int sk, const char *device,
+ void *addr, size_t addr_sz, time_t timeout)
+{
+ long flags;
+ int err;
+
+ if (device != NULL) {
+ err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device,
+ strlen(device) + 1);
+ if (err < 0)
+ test_error("setsockopt(SO_BINDTODEVICE, %s)", device);
+ }
+
+ if (!timeout) {
+ err = connect(sk, addr, addr_sz);
+ if (err) {
+ err = -errno;
+ goto out;
+ }
+ return 0;
+ }
+
+ flags = fcntl(sk, F_GETFL);
+ if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0))
+ test_error("fcntl()");
+
+ if (connect(sk, addr, addr_sz) < 0) {
+ if (errno != EINPROGRESS) {
+ err = -errno;
+ goto out;
+ }
+ if (timeout < 0)
+ return sk;
+ err = test_wait_fd(sk, timeout, 1);
+ if (err)
+ goto out;
+ }
+ return sk;
+
+out:
+ close(sk);
+ return err;
+}
+
+int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix,
+ int vrf, const char *password)
+{
+ size_t pwd_len = strlen(password);
+ struct tcp_md5sig md5sig = {};
+
+ md5sig.tcpm_keylen = pwd_len;
+ memcpy(md5sig.tcpm_key, password, pwd_len);
+ md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX;
+ md5sig.tcpm_prefixlen = prefix;
+ if (vrf >= 0) {
+ md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
+ md5sig.tcpm_ifindex = (uint8_t)vrf;
+ }
+ memcpy(&md5sig.tcpm_addr, addr, addr_sz);
+
+ errno = 0;
+ return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT,
+ &md5sig, sizeof(md5sig));
+}
+
+
+int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg,
+ void *addr, size_t addr_sz, bool set_current, bool set_rnext,
+ uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid,
+ uint8_t maclen, uint8_t keyflags,
+ uint8_t keylen, const char *key)
+{
+ memset(ao, 0, sizeof(struct tcp_ao_add));
+
+ ao->set_current = !!set_current;
+ ao->set_rnext = !!set_rnext;
+ ao->prefix = prefix;
+ ao->sndid = sndid;
+ ao->rcvid = rcvid;
+ ao->maclen = maclen;
+ ao->keyflags = keyflags;
+ ao->keylen = keylen;
+ ao->ifindex = vrf;
+
+ memcpy(&ao->addr, addr, addr_sz);
+
+ if (strlen(alg) > 64)
+ return -ENOBUFS;
+ strncpy(ao->alg_name, alg, 64);
+
+ memcpy(ao->key, key,
+ (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen);
+ return 0;
+}
+
+static int test_get_ao_keys_nr(int sk)
+{
+ struct tcp_ao_getsockopt tmp = {};
+ socklen_t tmp_sz = sizeof(tmp);
+ int ret;
+
+ tmp.nkeys = 1;
+ tmp.get_all = 1;
+
+ ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
+ if (ret)
+ return -errno;
+ return (int)tmp.nkeys;
+}
+
+int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out,
+ void *addr, size_t addr_sz, uint8_t prefix,
+ uint8_t sndid, uint8_t rcvid)
+{
+ struct tcp_ao_getsockopt tmp = {};
+ socklen_t tmp_sz = sizeof(tmp);
+ int ret;
+
+ memcpy(&tmp.addr, addr, addr_sz);
+ tmp.prefix = prefix;
+ tmp.sndid = sndid;
+ tmp.rcvid = rcvid;
+ tmp.nkeys = 1;
+
+ ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz);
+ if (ret)
+ return ret;
+ if (tmp.nkeys != 1)
+ return -E2BIG;
+ *out = tmp;
+ return 0;
+}
+
+int test_get_ao_info(int sk, struct tcp_ao_info_opt *out)
+{
+ socklen_t sz = sizeof(*out);
+
+ out->reserved = 0;
+ out->reserved2 = 0;
+ if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz))
+ return -errno;
+ if (sz != sizeof(*out))
+ return -EMSGSIZE;
+ return 0;
+}
+
+int test_set_ao_info(int sk, struct tcp_ao_info_opt *in)
+{
+ socklen_t sz = sizeof(*in);
+
+ in->reserved = 0;
+ in->reserved2 = 0;
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz))
+ return -errno;
+ return 0;
+}
+
+int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a,
+ const struct tcp_ao_getsockopt *b)
+{
+ bool is_kdf_aes_128_cmac = false;
+ bool is_cmac_aes = false;
+
+ if (!strcmp("cmac(aes128)", a->alg_name)) {
+ is_kdf_aes_128_cmac = (a->keylen != 16);
+ is_cmac_aes = true;
+ }
+
+#define __cmp_ao(member) \
+do { \
+ if (b->member != a->member) { \
+ test_fail("getsockopt(): " __stringify(member) " %u != %u", \
+ b->member, a->member); \
+ return -1; \
+ } \
+} while(0)
+ __cmp_ao(sndid);
+ __cmp_ao(rcvid);
+ __cmp_ao(prefix);
+ __cmp_ao(keyflags);
+ __cmp_ao(ifindex);
+ if (a->maclen) {
+ __cmp_ao(maclen);
+ } else if (b->maclen != 12) {
+ test_fail("getsockopt(): expected default maclen 12, but it's %u",
+ b->maclen);
+ return -1;
+ }
+ if (!is_kdf_aes_128_cmac) {
+ __cmp_ao(keylen);
+ } else if (b->keylen != 16) {
+ test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u",
+ b->keylen);
+ return -1;
+ }
+#undef __cmp_ao
+ if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) {
+ test_fail("getsockopt(): returned key is different `%s' != `%s'",
+ b->key, a->key);
+ return -1;
+ }
+ if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) {
+ test_fail("getsockopt(): returned address is different");
+ return -1;
+ }
+ if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) {
+ test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name);
+ return -1;
+ }
+ if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) {
+ test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name);
+ return -1;
+ }
+ /* For a established key rotation test don't add a key with
+ * set_current = 1, as it's likely to change by peer's request;
+ * rather use setsockopt(TCP_AO_INFO)
+ */
+ if (a->set_current != b->is_current) {
+ test_fail("getsockopt(): returned key is not Current_key");
+ return -1;
+ }
+ if (a->set_rnext != b->is_rnext) {
+ test_fail("getsockopt(): returned key is not RNext_key");
+ return -1;
+ }
+
+ return 0;
+}
+
+int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a,
+ const struct tcp_ao_info_opt *b)
+{
+ /* No check for ::current_key, as it may change by the peer */
+ if (a->ao_required != b->ao_required) {
+ test_fail("getsockopt(): returned ao doesn't have ao_required");
+ return -1;
+ }
+ if (a->accept_icmps != b->accept_icmps) {
+ test_fail("getsockopt(): returned ao doesn't accept ICMPs");
+ return -1;
+ }
+ if (a->set_rnext && a->rnext != b->rnext) {
+ test_fail("getsockopt(): RNext KeyID has changed");
+ return -1;
+ }
+#define __cmp_cnt(member) \
+do { \
+ if (b->member != a->member) { \
+ test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \
+ b->member, a->member); \
+ return -1; \
+ } \
+} while(0)
+ if (a->set_counters) {
+ __cmp_cnt(pkt_good);
+ __cmp_cnt(pkt_bad);
+ __cmp_cnt(pkt_key_not_found);
+ __cmp_cnt(pkt_ao_required);
+ __cmp_cnt(pkt_dropped_icmp);
+ }
+#undef __cmp_cnt
+ return 0;
+}
+
+int test_get_tcp_ao_counters(int sk, struct tcp_ao_counters *out)
+{
+ struct tcp_ao_getsockopt *key_dump;
+ socklen_t key_dump_sz = sizeof(*key_dump);
+ struct tcp_ao_info_opt info = {};
+ bool c1, c2, c3, c4, c5;
+ struct netstat *ns;
+ int err, nr_keys;
+
+ memset(out, 0, sizeof(*out));
+
+ /* per-netns */
+ ns = netstat_read();
+ out->netns_ao_good = netstat_get(ns, "TCPAOGood", &c1);
+ out->netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2);
+ out->netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3);
+ out->netns_ao_required = netstat_get(ns, "TCPAORequired", &c4);
+ out->netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5);
+ netstat_free(ns);
+ if (c1 || c2 || c3 || c4 || c5)
+ return -EOPNOTSUPP;
+
+ err = test_get_ao_info(sk, &info);
+ if (err)
+ return err;
+
+ /* per-socket */
+ out->ao_info_pkt_good = info.pkt_good;
+ out->ao_info_pkt_bad = info.pkt_bad;
+ out->ao_info_pkt_key_not_found = info.pkt_key_not_found;
+ out->ao_info_pkt_ao_required = info.pkt_ao_required;
+ out->ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp;
+
+ /* per-key */
+ nr_keys = test_get_ao_keys_nr(sk);
+ if (nr_keys < 0)
+ return nr_keys;
+ if (nr_keys == 0)
+ test_error("test_get_ao_keys_nr() == 0");
+ out->nr_keys = (size_t)nr_keys;
+ key_dump = calloc(nr_keys, key_dump_sz);
+ if (!key_dump)
+ return -errno;
+
+ key_dump[0].nkeys = nr_keys;
+ key_dump[0].get_all = 1;
+ key_dump[0].get_all = 1;
+ err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS,
+ key_dump, &key_dump_sz);
+ if (err) {
+ free(key_dump);
+ return -errno;
+ }
+
+ out->key_cnts = calloc(nr_keys, sizeof(out->key_cnts[0]));
+ if (!out->key_cnts) {
+ free(key_dump);
+ return -errno;
+ }
+
+ while (nr_keys--) {
+ out->key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid;
+ out->key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid;
+ out->key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good;
+ out->key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad;
+ }
+ free(key_dump);
+
+ return 0;
+}
+
+int __test_tcp_ao_counters_cmp(const char *tst_name,
+ struct tcp_ao_counters *before,
+ struct tcp_ao_counters *after,
+ test_cnt expected)
+{
+#define __cmp_ao(cnt, expecting_inc) \
+do { \
+ if (before->cnt > after->cnt) { \
+ test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \
+ tst_name ?: "", before->cnt, after->cnt); \
+ return -1; \
+ } \
+ if ((before->cnt != after->cnt) != (expecting_inc)) { \
+ test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \
+ tst_name ?: "", (expecting_inc) ? "" : "not ", \
+ before->cnt, after->cnt); \
+ return -1; \
+ } \
+} while(0)
+
+ errno = 0;
+ /* per-netns */
+ __cmp_ao(netns_ao_good, !!(expected & TEST_CNT_NS_GOOD));
+ __cmp_ao(netns_ao_bad, !!(expected & TEST_CNT_NS_BAD));
+ __cmp_ao(netns_ao_key_not_found,
+ !!(expected & TEST_CNT_NS_KEY_NOT_FOUND));
+ __cmp_ao(netns_ao_required, !!(expected & TEST_CNT_NS_AO_REQUIRED));
+ __cmp_ao(netns_ao_dropped_icmp,
+ !!(expected & TEST_CNT_NS_DROPPED_ICMP));
+ /* per-socket */
+ __cmp_ao(ao_info_pkt_good, !!(expected & TEST_CNT_SOCK_GOOD));
+ __cmp_ao(ao_info_pkt_bad, !!(expected & TEST_CNT_SOCK_BAD));
+ __cmp_ao(ao_info_pkt_key_not_found,
+ !!(expected & TEST_CNT_SOCK_KEY_NOT_FOUND));
+ __cmp_ao(ao_info_pkt_ao_required, !!(expected & TEST_CNT_SOCK_AO_REQUIRED));
+ __cmp_ao(ao_info_pkt_dropped_icmp,
+ !!(expected & TEST_CNT_SOCK_DROPPED_ICMP));
+ return 0;
+#undef __cmp_ao
+}
+
+int test_tcp_ao_key_counters_cmp(const char *tst_name,
+ struct tcp_ao_counters *before,
+ struct tcp_ao_counters *after,
+ test_cnt expected,
+ int sndid, int rcvid)
+{
+ size_t i;
+#define __cmp_ao(i, cnt, expecting_inc) \
+do { \
+ if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \
+ test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \
+ tst_name ?: "", before->key_cnts[i].cnt, \
+ after->key_cnts[i].cnt, \
+ before->key_cnts[i].sndid, \
+ before->key_cnts[i].rcvid); \
+ return -1; \
+ } \
+ if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != (expecting_inc)) { \
+ test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \
+ tst_name ?: "", (expecting_inc) ? "" : "not ",\
+ before->key_cnts[i].cnt, \
+ after->key_cnts[i].cnt, \
+ before->key_cnts[i].sndid, \
+ before->key_cnts[i].rcvid); \
+ return -1; \
+ } \
+} while(0)
+
+ if (before->nr_keys != after->nr_keys) {
+ test_fail("%s: Keys changed on the socket %zu != %zu",
+ tst_name, before->nr_keys, after->nr_keys);
+ return -1;
+ }
+
+ /* per-key */
+ i = before->nr_keys;
+ while (i--) {
+ if (sndid >= 0 && before->key_cnts[i].sndid != sndid)
+ continue;
+ if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid)
+ continue;
+ __cmp_ao(i, pkt_good, !!(expected & TEST_CNT_KEY_GOOD));
+ __cmp_ao(i, pkt_bad, !!(expected & TEST_CNT_KEY_BAD));
+ }
+ return 0;
+#undef __cmp_ao
+}
+
+void test_tcp_ao_counters_free(struct tcp_ao_counters *cnts)
+{
+ free(cnts->key_cnts);
+}
+
+#define TEST_BUF_SIZE 4096
+ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec)
+{
+ ssize_t total = 0;
+
+ do {
+ char buf[TEST_BUF_SIZE];
+ ssize_t bytes, sent;
+ int ret;
+
+ ret = test_wait_fd(sk, timeout_sec, 0);
+ if (ret)
+ return ret;
+
+ bytes = recv(sk, buf, sizeof(buf), 0);
+
+ if (bytes < 0)
+ test_error("recv(): %zd", bytes);
+ if (bytes == 0)
+ break;
+
+ ret = test_wait_fd(sk, timeout_sec, 1);
+ if (ret)
+ return ret;
+
+ sent = send(sk, buf, bytes, 0);
+ if (sent == 0)
+ break;
+ if (sent != bytes)
+ test_error("send()");
+ total += bytes;
+ } while (!quota || total < quota);
+
+ return total;
+}
+
+ssize_t test_client_loop(int sk, char *buf, size_t buf_sz,
+ const size_t msg_len, time_t timeout_sec)
+{
+ char msg[msg_len];
+ int nodelay = 1;
+ size_t i;
+
+ if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay)))
+ test_error("setsockopt(TCP_NODELAY)");
+
+ for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) {
+ size_t sent, bytes = min(msg_len, buf_sz - i);
+ int ret;
+
+ ret = test_wait_fd(sk, timeout_sec, 1);
+ if (ret)
+ return ret;
+
+ sent = send(sk, buf + i, bytes, 0);
+ if (sent == 0)
+ break;
+ if (sent != bytes)
+ test_error("send()");
+
+ bytes = 0;
+ do {
+ ssize_t got;
+
+ ret = test_wait_fd(sk, timeout_sec, 0);
+ if (ret)
+ return ret;
+
+ got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0);
+ if (got <= 0)
+ return i;
+ bytes += got;
+ } while (bytes < sent);
+ if (bytes > sent)
+ test_error("recv(): %zd > %zd", bytes, sent);
+ if (memcmp(buf + i, msg, bytes) != 0) {
+ test_fail("received message differs");
+ return -1;
+ }
+ }
+ return i;
+}
+
+int test_client_verify(int sk, const size_t msg_len, const size_t nr,
+ time_t timeout_sec)
+{
+ size_t buf_sz = msg_len * nr;
+ char *buf = alloca(buf_sz);
+ ssize_t ret;
+
+ randomize_buffer(buf, buf_sz);
+ ret = test_client_loop(sk, buf, buf_sz, msg_len, timeout_sec);
+ if (ret < 0)
+ return (int)ret;
+ return ret != buf_sz ? -1 : 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/lib/utils.c b/tools/testing/selftests/net/tcp_ao/lib/utils.c
new file mode 100644
index 000000000000..372daca525f5
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/lib/utils.c
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: GPL-2.0
+#include "aolib.h"
+#include <string.h>
+
+void randomize_buffer(void *buf, size_t buflen)
+{
+ int *p = (int *)buf;
+ size_t words = buflen / sizeof(int);
+ size_t leftover = buflen % sizeof(int);
+
+ if (!buflen)
+ return;
+
+ while (words--)
+ *p++ = rand();
+
+ if (leftover) {
+ int tmp = rand();
+
+ memcpy(buf + buflen - leftover, &tmp, leftover);
+ }
+}
+
+const struct sockaddr_in6 addr_any6 = {
+ .sin6_family = AF_INET6,
+};
+
+const struct sockaddr_in addr_any4 = {
+ .sin_family = AF_INET,
+};
diff --git a/tools/testing/selftests/net/tcp_ao/restore.c b/tools/testing/selftests/net/tcp_ao/restore.c
new file mode 100644
index 000000000000..8fdc808df325
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/restore.c
@@ -0,0 +1,236 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets
+ * It tests that TCP-AO enabled connection can be restored.
+ * For the proper socket repair see:
+ * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h
+ */
+#include <inttypes.h>
+#include "aolib.h"
+
+const size_t nr_packets = 20;
+const size_t msg_len = 100;
+const size_t quota = nr_packets * msg_len;
+#define fault(type) (inj == FAULT_ ## type)
+
+static void try_server_run(const char *tst_name, unsigned int port,
+ fault_t inj, test_cnt cnt_expected)
+{
+ const char *cnt_name = "TCPAOGood";
+ struct tcp_ao_counters ao1, ao2;
+ uint64_t before_cnt, after_cnt;
+ int sk, lsk;
+ time_t timeout;
+ ssize_t bytes;
+
+ if (fault(TIMEOUT))
+ cnt_name = "TCPAOBad";
+ lsk = test_listen_socket(this_ip_addr, port, 1);
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ synchronize_threads(); /* 1: MKT added => connect() */
+
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ close(lsk);
+
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota) {
+ test_fail("%s: server served: %zd", tst_name, bytes);
+ goto out;
+ }
+
+ before_cnt = netstat_get_one(cnt_name, NULL);
+ if (test_get_tcp_ao_counters(sk, &ao1))
+ test_error("test_get_tcp_ao_counters()");
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ bytes = test_server_run(sk, quota, timeout);
+ if (fault(TIMEOUT)) {
+ if (bytes > 0)
+ test_fail("%s: server served: %zd", tst_name, bytes);
+ else
+ test_ok("%s: server couldn't serve", tst_name);
+ } else {
+ if (bytes != quota)
+ test_fail("%s: server served: %zd", tst_name, bytes);
+ else
+ test_ok("%s: server alive", tst_name);
+ }
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ after_cnt = netstat_get_one(cnt_name, NULL);
+
+ test_tcp_ao_counters_cmp(tst_name, &ao1, &ao2, cnt_expected);
+
+ if (after_cnt <= before_cnt) {
+ test_fail("%s: %s counter did not increase: %zu <= %zu",
+ tst_name, cnt_name, after_cnt, before_cnt);
+ } else {
+ test_ok("%s: counter %s increased %zu => %zu",
+ tst_name, cnt_name, before_cnt, after_cnt);
+ }
+
+ /*
+ * Before close() as that will send FIN and move the peer in TCP_CLOSE
+ * and that will prevent reading AO counters from the peer's socket.
+ */
+ synchronize_threads(); /* 3: verified => closed */
+out:
+ close(sk);
+}
+
+static void *server_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+
+ try_server_run("TCP-AO migrate to another socket", port++,
+ 0, TEST_CNT_GOOD);
+ try_server_run("TCP-AO with wrong send ISN", port++,
+ FAULT_TIMEOUT, TEST_CNT_BAD);
+ try_server_run("TCP-AO with wrong receive ISN", port++,
+ FAULT_TIMEOUT, TEST_CNT_BAD);
+ try_server_run("TCP-AO with wrong send SEQ ext number", port++,
+ FAULT_TIMEOUT, TEST_CNT_BAD);
+ try_server_run("TCP-AO with wrong receive SEQ ext number", port++,
+ FAULT_TIMEOUT, TEST_CNT_NS_BAD | TEST_CNT_GOOD);
+
+ synchronize_threads(); /* don't race to exit: client exits */
+ return NULL;
+}
+
+static void test_get_sk_checkpoint(unsigned int server_port, sockaddr_af *saddr,
+ struct tcp_sock_state *img,
+ struct tcp_ao_repair *ao_img)
+{
+ int sk;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* 1: MKT added => connect() */
+ if (test_connect_socket(sk, this_ip_dest, server_port) <= 0)
+ test_error("failed to connect()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
+ test_fail("pre-migrate verify failed");
+
+ test_enable_repair(sk);
+ test_sock_checkpoint(sk, img, saddr);
+ test_ao_checkpoint(sk, ao_img);
+ test_kill_sk(sk);
+}
+
+static void test_sk_restore(const char *tst_name, unsigned int server_port,
+ sockaddr_af *saddr, struct tcp_sock_state *img,
+ struct tcp_ao_repair *ao_img,
+ fault_t inj, test_cnt cnt_expected)
+{
+ const char *cnt_name = "TCPAOGood";
+ struct tcp_ao_counters ao1, ao2;
+ uint64_t before_cnt, after_cnt;
+ time_t timeout;
+ int sk;
+
+ if (fault(TIMEOUT))
+ cnt_name = "TCPAOBad";
+
+ before_cnt = netstat_get_one(cnt_name, NULL);
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ test_enable_repair(sk);
+ test_sock_restore(sk, img, saddr, this_ip_dest, server_port);
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ test_ao_restore(sk, ao_img);
+
+ if (test_get_tcp_ao_counters(sk, &ao1))
+ test_error("test_get_tcp_ao_counters()");
+
+ test_disable_repair(sk);
+ test_sock_state_free(img);
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ if (test_client_verify(sk, msg_len, nr_packets, timeout)) {
+ if (fault(TIMEOUT))
+ test_ok("%s: post-migrate connection is broken", tst_name);
+ else
+ test_fail("%s: post-migrate connection is working", tst_name);
+ } else {
+ if (fault(TIMEOUT))
+ test_fail("%s: post-migrate connection still working", tst_name);
+ else
+ test_ok("%s: post-migrate connection is alive", tst_name);
+ }
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ after_cnt = netstat_get_one(cnt_name, NULL);
+
+ test_tcp_ao_counters_cmp(tst_name, &ao1, &ao2, cnt_expected);
+
+ if (after_cnt <= before_cnt) {
+ test_fail("%s: %s counter did not increase: %zu <= %zu",
+ tst_name, cnt_name, after_cnt, before_cnt);
+ } else {
+ test_ok("%s: counter %s increased %zu => %zu",
+ tst_name, cnt_name, before_cnt, after_cnt);
+ }
+ synchronize_threads(); /* 3: verified => closed */
+ close(sk);
+}
+
+static void *client_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+ struct tcp_sock_state tcp_img;
+ struct tcp_ao_repair ao_img;
+ sockaddr_af saddr;
+
+ test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
+ test_sk_restore("TCP-AO migrate to another socket", port++,
+ &saddr, &tcp_img, &ao_img, 0, TEST_CNT_GOOD);
+
+ test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
+ ao_img.snt_isn += 1;
+ test_sk_restore("TCP-AO with wrong send ISN", port++,
+ &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD);
+
+ test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
+ ao_img.rcv_isn += 1;
+ test_sk_restore("TCP-AO with wrong receive ISN", port++,
+ &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD);
+
+ test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
+ ao_img.snd_sne += 1;
+ test_sk_restore("TCP-AO with wrong send SEQ ext number", port++,
+ &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT,
+ TEST_CNT_NS_BAD | TEST_CNT_GOOD);
+
+ test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
+ ao_img.rcv_sne += 1;
+ test_sk_restore("TCP-AO with wrong receive SEQ ext number", port++,
+ &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT,
+ TEST_CNT_NS_GOOD | TEST_CNT_BAD);
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(20, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/rst.c b/tools/testing/selftests/net/tcp_ao/rst.c
new file mode 100644
index 000000000000..7df8b8700e39
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/rst.c
@@ -0,0 +1,457 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * The test checks that both active and passive reset have correct TCP-AO
+ * signature. An "active" reset (abort) here is procured from closing
+ * listen() socket with non-accepted connections in the queue:
+ * inet_csk_listen_stop() => inet_child_forget() =>
+ * => tcp_disconnect() => tcp_send_active_reset()
+ *
+ * The passive reset is quite hard to get on established TCP connections.
+ * It could be procured from non-established states, but the synchronization
+ * part from userspace in order to reliably get RST seems uneasy.
+ * So, instead it's procured by corrupting SEQ number on TIMED-WAIT state.
+ *
+ * It's important to test both passive and active RST as they go through
+ * different code-paths:
+ * - tcp_send_active_reset() makes no-data skb, sends it with tcp_transmit_skb()
+ * - tcp_v*_send_reset() create their reply skbs and send them with
+ * ip_send_unicast_reply()
+ *
+ * In both cases TCP-AO signatures have to be correct, which is verified by
+ * (1) checking that the TCP-AO connection was reset and (2) TCP-AO counters.
+ *
+ * Author: Dmitry Safonov <dima@arista.com>
+ */
+#include <inttypes.h>
+#include "../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+const size_t quota = 1000;
+const size_t packet_sz = 100;
+/*
+ * Backlog == 0 means 1 connection in queue, see:
+ * commit 64a146513f8f ("[NET]: Revert incorrect accept queue...")
+ */
+const unsigned int backlog;
+
+static void netstats_check(struct netstat *before, struct netstat *after,
+ char *msg)
+{
+ uint64_t before_cnt, after_cnt;
+
+ before_cnt = netstat_get(before, "TCPAORequired", NULL);
+ after_cnt = netstat_get(after, "TCPAORequired", NULL);
+ if (after_cnt > before_cnt)
+ test_fail("Segments without AO sign (%s): %" PRIu64 " => %" PRIu64,
+ msg, before_cnt, after_cnt);
+ else
+ test_ok("No segments without AO sign (%s)", msg);
+
+ before_cnt = netstat_get(before, "TCPAOGood", NULL);
+ after_cnt = netstat_get(after, "TCPAOGood", NULL);
+ if (after_cnt <= before_cnt)
+ test_fail("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
+ msg, before_cnt, after_cnt);
+ else
+ test_ok("Signed AO segments (%s): %" PRIu64 " => %" PRIu64,
+ msg, before_cnt, after_cnt);
+
+ before_cnt = netstat_get(before, "TCPAOBad", NULL);
+ after_cnt = netstat_get(after, "TCPAOBad", NULL);
+ if (after_cnt > before_cnt)
+ test_fail("Segments with bad AO sign (%s): %" PRIu64 " => %" PRIu64,
+ msg, before_cnt, after_cnt);
+ else
+ test_ok("No segments with bad AO sign (%s)", msg);
+}
+
+/*
+ * Another way to send RST, but not through tcp_v{4,6}_send_reset()
+ * is tcp_send_active_reset(), that is not in reply to inbound segment,
+ * but rather active send. It uses tcp_transmit_skb(), so that should
+ * work, but as it also sends RST - nice that it can be covered as well.
+ */
+static void close_forced(int sk)
+{
+ struct linger sl;
+
+ sl.l_onoff = 1;
+ sl.l_linger = 0;
+ if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)))
+ test_error("setsockopt(SO_LINGER)");
+ close(sk);
+}
+
+static void test_server_active_rst(unsigned int port)
+{
+ struct tcp_ao_counters cnt1, cnt2;
+ ssize_t bytes;
+ int sk, lsk;
+
+ lsk = test_listen_socket(this_ip_addr, port, backlog);
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ if (test_get_tcp_ao_counters(lsk, &cnt1))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* 1: MKT added */
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ synchronize_threads(); /* 2: connection accept()ed, another queued */
+ if (test_get_tcp_ao_counters(lsk, &cnt2))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* 3: close listen socket */
+ close(lsk);
+ bytes = test_server_run(sk, quota, 0);
+ if (bytes != quota)
+ test_error("servered only %zd bytes", bytes);
+ else
+ test_ok("servered %zd bytes", bytes);
+
+ synchronize_threads(); /* 4: finishing up */
+ close_forced(sk);
+
+ synchronize_threads(); /* 5: closed active sk */
+
+ synchronize_threads(); /* 6: counters checks */
+ if (test_tcp_ao_counters_cmp("active RST server", &cnt1, &cnt2, TEST_CNT_GOOD))
+ test_fail("MKT counters (server) have not only good packets");
+ else
+ test_ok("MKT counters are good on server");
+}
+
+static void test_server_passive_rst(unsigned int port)
+{
+ struct tcp_ao_counters ao1, ao2;
+ int sk, lsk;
+ ssize_t bytes;
+
+ lsk = test_listen_socket(this_ip_addr, port, 1);
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* 1: MKT added => connect() */
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ close(lsk);
+ if (test_get_tcp_ao_counters(sk, &ao1))
+ test_error("test_get_tcp_ao_counters()");
+
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota) {
+ if (bytes > 0)
+ test_fail("server served: %zd", bytes);
+ else
+ test_fail("server returned %zd", bytes);
+ }
+
+ synchronize_threads(); /* 3: checkpoint the client */
+ synchronize_threads(); /* 4: close the server, creating twsk */
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ close(sk);
+
+ synchronize_threads(); /* 5: restore the socket, send more data */
+ test_tcp_ao_counters_cmp("passive RST server", &ao1, &ao2, TEST_CNT_GOOD);
+
+ synchronize_threads(); /* 6: server exits */
+}
+
+static void *server_fn(void *arg)
+{
+ struct netstat *ns_before, *ns_after;
+ unsigned int port = test_server_port;
+
+ ns_before = netstat_read();
+
+ test_server_active_rst(port++);
+ test_server_passive_rst(port++);
+
+ ns_after = netstat_read();
+ netstats_check(ns_before, ns_after, "server");
+ netstat_free(ns_after);
+ netstat_free(ns_before);
+ synchronize_threads(); /* exit */
+
+ synchronize_threads(); /* don't race to exit() - client exits */
+ return NULL;
+}
+
+static int test_wait_fds(int sk[], size_t nr, bool is_writable[],
+ ssize_t wait_for, time_t sec)
+{
+ struct timeval tv = { .tv_sec = sec };
+ struct timeval *ptv = NULL;
+ fd_set left;
+ size_t i;
+ int ret;
+
+ FD_ZERO(&left);
+ for (i = 0; i < nr; i++) {
+ FD_SET(sk[i], &left);
+ if (is_writable)
+ is_writable[i] = false;
+ }
+
+ if (sec)
+ ptv = &tv;
+
+ do {
+ bool is_empty = true;
+ fd_set fds, efds;
+ int nfd = 0;
+
+ FD_ZERO(&fds);
+ FD_ZERO(&efds);
+ for (i = 0; i < nr; i++) {
+ if (!FD_ISSET(sk[i], &left))
+ continue;
+
+ if (sk[i] > nfd)
+ nfd = sk[i];
+
+ FD_SET(sk[i], &fds);
+ FD_SET(sk[i], &efds);
+ is_empty = false;
+ }
+ if (is_empty)
+ return -ENOENT;
+
+ errno = 0;
+ ret = select(nfd + 1, NULL, &fds, &efds, ptv);
+ if (ret < 0)
+ return -errno;
+ if (!ret)
+ return -ETIMEDOUT;
+ for (i = 0; i < nr; i++) {
+ if (FD_ISSET(sk[i], &fds)) {
+ if (is_writable)
+ is_writable[i] = true;
+ FD_CLR(sk[i], &left);
+ wait_for--;
+ continue;
+ }
+ if (FD_ISSET(sk[i], &efds)) {
+ FD_CLR(sk[i], &left);
+ wait_for--;
+ }
+ }
+ } while (wait_for > 0);
+
+ return 0;
+}
+
+static void test_client_active_rst(unsigned int port)
+{
+ /* one in queue, another accept()ed */
+ unsigned int wait_for = backlog + 2;
+ int i, sk[3], err;
+ bool is_writable[ARRAY_SIZE(sk)] = {false};
+ unsigned int last = ARRAY_SIZE(sk) - 1;
+
+ for (i = 0; i < ARRAY_SIZE(sk); i++) {
+ sk[i] = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk[i] < 0)
+ test_error("socket()");
+ if (test_add_key(sk[i], DEFAULT_TEST_PASSWORD,
+ this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ }
+
+ synchronize_threads(); /* 1: MKT added */
+ for (i = 0; i < last; i++) {
+ err = _test_connect_socket(sk[i], this_ip_dest, port,
+ (i == 0) ? TEST_TIMEOUT_SEC : -1);
+
+ if (err < 0)
+ test_error("failed to connect()");
+ }
+
+ synchronize_threads(); /* 2: connection accept()ed, another queued */
+ err = test_wait_fds(sk, last, is_writable, wait_for, TEST_TIMEOUT_SEC);
+ if (err < 0)
+ test_error("test_wait_fds(): %d", err);
+
+ synchronize_threads(); /* 3: close listen socket */
+ if (test_client_verify(sk[0], packet_sz, quota / packet_sz, TEST_TIMEOUT_SEC))
+ test_fail("Failed to send data on connected socket");
+ else
+ test_ok("Verified established tcp connection");
+
+ synchronize_threads(); /* 4: finishing up */
+ err = _test_connect_socket(sk[last], this_ip_dest, port, -1);
+ if (err < 0)
+ test_error("failed to connect()");
+
+ synchronize_threads(); /* 5: closed active sk */
+ err = test_wait_fds(sk, ARRAY_SIZE(sk), NULL,
+ wait_for, TEST_TIMEOUT_SEC);
+ if (err < 0)
+ test_error("select(): %d", err);
+
+ for (i = 0; i < ARRAY_SIZE(sk); i++) {
+ socklen_t slen = sizeof(err);
+
+ if (getsockopt(sk[i], SOL_SOCKET, SO_ERROR, &err, &slen))
+ test_error("getsockopt()");
+ if (is_writable[i] && err != ECONNRESET) {
+ test_fail("sk[%d] = %d, err = %d, connection wasn't reset",
+ i, sk[i], err);
+ } else {
+ test_ok("sk[%d] = %d%s", i, sk[i],
+ is_writable[i] ? ", connection was reset" : "");
+ }
+ }
+ synchronize_threads(); /* 6: counters checks */
+}
+
+static void test_client_passive_rst(unsigned int port)
+{
+ struct tcp_ao_counters ao1, ao2;
+ struct tcp_ao_repair ao_img;
+ struct tcp_sock_state img;
+ sockaddr_af saddr;
+ int sk, err;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* 1: MKT added => connect() */
+ if (test_connect_socket(sk, this_ip_dest, port) <= 0)
+ test_error("failed to connect()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ if (test_client_verify(sk, packet_sz, quota / packet_sz, TEST_TIMEOUT_SEC))
+ test_fail("Failed to send data on connected socket");
+ else
+ test_ok("Verified established tcp connection");
+
+ synchronize_threads(); /* 3: checkpoint the client */
+ test_enable_repair(sk);
+ test_sock_checkpoint(sk, &img, &saddr);
+ test_ao_checkpoint(sk, &ao_img);
+ test_disable_repair(sk);
+
+ synchronize_threads(); /* 4: close the server, creating twsk */
+
+ /*
+ * The "corruption" in SEQ has to be small enough to fit into TCP
+ * window, see tcp_timewait_state_process() for out-of-window
+ * segments.
+ */
+ img.out.seq += 5; /* 5 is more noticeable in tcpdump than 1 */
+
+ /*
+ * FIXME: This is kind-of ugly and dirty, but it works.
+ *
+ * At this moment, the server has close'ed(sk).
+ * The passive RST that is being targeted here is new data after
+ * half-duplex close, see tcp_timewait_state_process() => TCP_TW_RST
+ *
+ * What is needed here is:
+ * (1) wait for FIN from the server
+ * (2) make sure that the ACK from the client went out
+ * (3) make sure that the ACK was received and processed by the server
+ *
+ * Otherwise, the data that will be sent from "repaired" socket
+ * post SEQ corruption may get to the server before it's in
+ * TCP_FIN_WAIT2.
+ *
+ * (1) is easy with select()/poll()
+ * (2) is possible by polling tcpi_state from TCP_INFO
+ * (3) is quite complex: as server's socket was already closed,
+ * probably the way to do it would be tcp-diag.
+ */
+ sleep(TEST_RETRANSMIT_SEC);
+
+ synchronize_threads(); /* 5: restore the socket, send more data */
+ test_kill_sk(sk);
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ test_enable_repair(sk);
+ test_sock_restore(sk, &img, &saddr, this_ip_dest, port);
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ test_ao_restore(sk, &ao_img);
+
+ if (test_get_tcp_ao_counters(sk, &ao1))
+ test_error("test_get_tcp_ao_counters()");
+
+ test_disable_repair(sk);
+ test_sock_state_free(&img);
+
+ /*
+ * This is how "passive reset" is acquired in this test from TCP_TW_RST:
+ *
+ * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [P.], seq 901:1001, ack 1001, win 249,
+ * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x10217d6c36a22379086ef3b1], length 100
+ * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [F.], seq 1001, ack 1001, win 249,
+ * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x104ffc99b98c10a5298cc268], length 0
+ * IP 10.0.1.1.59772 > 10.0.254.1.7011: Flags [.], ack 1002, win 251,
+ * options [tcp-ao keyid 100 rnextkeyid 100 mac 0xe496dd4f7f5a8a66873c6f93,nop,nop,sack 1 {1001:1002}], length 0
+ * IP 10.0.1.1.59772 > 10.0.254.1.7011: Flags [P.], seq 1006:1106, ack 1001, win 251,
+ * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x1b5f3330fb23fbcd0c77d0ca], length 100
+ * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [R], seq 3215596252, win 0,
+ * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x0bcfbbf497bce844312304b2], length 0
+ */
+ err = test_client_verify(sk, packet_sz, quota / packet_sz, 2 * TEST_TIMEOUT_SEC);
+ /* Make sure that the connection was reset, not timeouted */
+ if (err && err == -ECONNRESET)
+ test_ok("client sock was passively reset post-seq-adjust");
+ else if (err)
+ test_fail("client sock was not reset post-seq-adjust: %d", err);
+ else
+ test_fail("client sock is yet connected post-seq-adjust");
+
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* 6: server exits */
+ close(sk);
+ test_tcp_ao_counters_cmp("client passive RST", &ao1, &ao2, TEST_CNT_GOOD);
+}
+
+static void *client_fn(void *arg)
+{
+ struct netstat *ns_before, *ns_after;
+ unsigned int port = test_server_port;
+
+ ns_before = netstat_read();
+
+ test_client_active_rst(port++);
+ test_client_passive_rst(port++);
+
+ ns_after = netstat_read();
+ netstats_check(ns_before, ns_after, "client");
+ netstat_free(ns_after);
+ netstat_free(ns_before);
+
+ synchronize_threads(); /* exit */
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(14, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/self-connect.c b/tools/testing/selftests/net/tcp_ao/self-connect.c
new file mode 100644
index 000000000000..e154d9e198a9
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/self-connect.c
@@ -0,0 +1,197 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "aolib.h"
+
+static union tcp_addr local_addr;
+
+static void __setup_lo_intf(const char *lo_intf,
+ const char *addr_str, uint8_t prefix)
+{
+ if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1)
+ test_error("Can't convert local ip address");
+
+ if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix))
+ test_error("Failed to add %s ip address", lo_intf);
+
+ if (link_set_up(lo_intf))
+ test_error("Failed to bring %s up", lo_intf);
+}
+
+static void setup_lo_intf(const char *lo_intf)
+{
+#ifdef IPV6_TEST
+ __setup_lo_intf(lo_intf, "::1", 128);
+#else
+ __setup_lo_intf(lo_intf, "127.0.0.1", 8);
+#endif
+}
+
+static void tcp_self_connect(const char *tst, unsigned int port,
+ bool different_keyids, bool check_restore)
+{
+ uint64_t before_challenge_ack, after_challenge_ack;
+ uint64_t before_syn_challenge, after_syn_challenge;
+ struct tcp_ao_counters before_ao, after_ao;
+ uint64_t before_aogood, after_aogood;
+ struct netstat *ns_before, *ns_after;
+ const size_t nr_packets = 20;
+ struct tcp_ao_repair ao_img;
+ struct tcp_sock_state img;
+ sockaddr_af addr;
+ int sk;
+
+ tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port));
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (different_keyids) {
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ } else {
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ }
+
+ if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0)
+ test_error("bind()");
+
+ ns_before = netstat_read();
+ before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
+ before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL);
+ before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL);
+ if (test_get_tcp_ao_counters(sk, &before_ao))
+ test_error("test_get_tcp_ao_counters()");
+
+ if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr,
+ sizeof(addr), TEST_TIMEOUT_SEC) < 0) {
+ ns_after = netstat_read();
+ netstat_print_diff(ns_before, ns_after);
+ test_error("failed to connect()");
+ }
+
+ if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
+ test_fail("%s: tcp connection verify failed", tst);
+ close(sk);
+ return;
+ }
+
+ ns_after = netstat_read();
+ after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
+ after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL);
+ after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL);
+ if (test_get_tcp_ao_counters(sk, &after_ao))
+ test_error("test_get_tcp_ao_counters()");
+ if (!check_restore) {
+ /* to debug: netstat_print_diff(ns_before, ns_after); */
+ netstat_free(ns_before);
+ }
+ netstat_free(ns_after);
+
+ if (after_aogood <= before_aogood) {
+ test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
+ tst, after_aogood, before_aogood);
+ close(sk);
+ return;
+ }
+ if (after_challenge_ack <= before_challenge_ack ||
+ after_syn_challenge <= before_syn_challenge) {
+ /*
+ * It's also meant to test simultaneous open, so check
+ * these counters as well.
+ */
+ test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu",
+ tst, after_challenge_ack, before_challenge_ack,
+ after_syn_challenge, before_syn_challenge);
+ close(sk);
+ return;
+ }
+
+ if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) {
+ close(sk);
+ return;
+ }
+
+ if (!check_restore) {
+ test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
+ tst, before_aogood, after_aogood);
+ close(sk);
+ return;
+ }
+
+ test_enable_repair(sk);
+ test_sock_checkpoint(sk, &img, &addr);
+#ifdef IPV6_TEST
+ addr.sin6_port = htons(port + 1);
+#else
+ addr.sin_port = htons(port + 1);
+#endif
+ test_ao_checkpoint(sk, &ao_img);
+ test_kill_sk(sk);
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ test_enable_repair(sk);
+ __test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr));
+ if (different_keyids) {
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
+ local_addr, -1, 7, 5))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
+ local_addr, -1, 5, 7))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ } else {
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0,
+ local_addr, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ }
+ test_ao_restore(sk, &ao_img);
+ test_disable_repair(sk);
+ test_sock_state_free(&img);
+ if (test_client_verify(sk, 100, nr_packets, TEST_TIMEOUT_SEC)) {
+ test_fail("%s: tcp connection verify failed", tst);
+ close(sk);
+ return;
+ }
+ ns_after = netstat_read();
+ after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
+ /* to debug: netstat_print_diff(ns_before, ns_after); */
+ netstat_free(ns_before);
+ netstat_free(ns_after);
+ close(sk);
+ if (after_aogood <= before_aogood) {
+ test_fail("%s: TCPAOGood counter mismatch: %zu <= %zu",
+ tst, after_aogood, before_aogood);
+ return;
+ }
+ test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64,
+ tst, before_aogood, after_aogood);
+}
+
+static void *client_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+
+ setup_lo_intf("lo");
+
+ tcp_self_connect("self-connect(same keyids)", port++, false, false);
+ tcp_self_connect("self-connect(different keyids)", port++, true, false);
+ tcp_self_connect("self-connect(restore)", port, false, true);
+ port += 2;
+ tcp_self_connect("self-connect(restore, different keyids)", port, true, true);
+ port += 2;
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(4, client_fn, NULL);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/seq-ext.c b/tools/testing/selftests/net/tcp_ao/seq-ext.c
new file mode 100644
index 000000000000..ad4e77d6823e
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/seq-ext.c
@@ -0,0 +1,245 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Check that after SEQ number wrap-around:
+ * 1. SEQ-extension has upper bytes set
+ * 2. TCP conneciton is alive and no TCPAOBad segments
+ * In order to test (2), the test doesn't just adjust seq number for a queue
+ * on a connected socket, but migrates it to another sk+port number, so
+ * that there won't be any delayed packets that will fail to verify
+ * with the new SEQ numbers.
+ */
+#include <inttypes.h>
+#include "aolib.h"
+
+const unsigned int nr_packets = 1000;
+const unsigned int msg_len = 1000;
+const unsigned int quota = nr_packets * msg_len;
+unsigned int client_new_port;
+
+/* Move them closer to roll-over */
+static void test_adjust_seqs(struct tcp_sock_state *img,
+ struct tcp_ao_repair *ao_img,
+ bool server)
+{
+ uint32_t new_seq1, new_seq2;
+
+ /* make them roll-over during quota, but on different segments */
+ if (server) {
+ new_seq1 = ((uint32_t)-1) - msg_len;
+ new_seq2 = ((uint32_t)-1) - (quota - 2 * msg_len);
+ } else {
+ new_seq1 = ((uint32_t)-1) - (quota - 2 * msg_len);
+ new_seq2 = ((uint32_t)-1) - msg_len;
+ }
+
+ img->in.seq = new_seq1;
+ img->trw.snd_wl1 = img->in.seq - msg_len;
+ img->out.seq = new_seq2;
+ img->trw.rcv_wup = img->in.seq;
+}
+
+static int test_sk_restore(struct tcp_sock_state *img,
+ struct tcp_ao_repair *ao_img, sockaddr_af *saddr,
+ const union tcp_addr daddr, unsigned int dport,
+ struct tcp_ao_counters *cnt)
+{
+ int sk;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ test_enable_repair(sk);
+ test_sock_restore(sk, img, saddr, daddr, dport);
+ if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, daddr, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ test_ao_restore(sk, ao_img);
+
+ if (test_get_tcp_ao_counters(sk, cnt))
+ test_error("test_get_tcp_ao_counters()");
+
+ test_disable_repair(sk);
+ test_sock_state_free(img);
+ return sk;
+}
+
+static void *server_fn(void *arg)
+{
+ uint64_t before_good, after_good, after_bad;
+ struct tcp_ao_counters ao1, ao2;
+ struct tcp_sock_state img;
+ struct tcp_ao_repair ao_img;
+ sockaddr_af saddr;
+ ssize_t bytes;
+ int sk, lsk;
+
+ lsk = test_listen_socket(this_ip_addr, test_server_port, 1);
+
+ if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* 1: MKT added => connect() */
+
+ if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
+ test_error("test_wait_fd()");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0)
+ test_error("accept()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ close(lsk);
+
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota) {
+ if (bytes > 0)
+ test_fail("server served: %zd", bytes);
+ else
+ test_fail("server returned: %zd", bytes);
+ goto out;
+ }
+
+ before_good = netstat_get_one("TCPAOGood", NULL);
+
+ synchronize_threads(); /* 3: restore the connection on another port */
+
+ test_enable_repair(sk);
+ test_sock_checkpoint(sk, &img, &saddr);
+ test_ao_checkpoint(sk, &ao_img);
+ test_kill_sk(sk);
+#ifdef IPV6_TEST
+ saddr.sin6_port = htons(ntohs(saddr.sin6_port) + 1);
+#else
+ saddr.sin_port = htons(ntohs(saddr.sin_port) + 1);
+#endif
+ test_adjust_seqs(&img, &ao_img, true);
+ synchronize_threads(); /* 4: dump finished */
+ sk = test_sk_restore(&img, &ao_img, &saddr, this_ip_dest,
+ client_new_port, &ao1);
+
+ synchronize_threads(); /* 5: verify counters during SEQ-number rollover */
+ bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
+ if (bytes != quota) {
+ if (bytes > 0)
+ test_fail("server served: %zd", bytes);
+ else
+ test_fail("server returned: %zd", bytes);
+ } else {
+ test_ok("server alive");
+ }
+
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ after_good = netstat_get_one("TCPAOGood", NULL);
+
+ test_tcp_ao_counters_cmp(NULL, &ao1, &ao2, TEST_CNT_GOOD);
+
+ if (after_good <= before_good) {
+ test_fail("TCPAOGood counter did not increase: %zu <= %zu",
+ after_good, before_good);
+ } else {
+ test_ok("TCPAOGood counter increased %zu => %zu",
+ before_good, after_good);
+ }
+ after_bad = netstat_get_one("TCPAOBad", NULL);
+ if (after_bad)
+ test_fail("TCPAOBad counter is non-zero: %zu", after_bad);
+ else
+ test_ok("TCPAOBad counter didn't increase");
+ test_enable_repair(sk);
+ test_ao_checkpoint(sk, &ao_img);
+ if (ao_img.snd_sne && ao_img.rcv_sne) {
+ test_ok("SEQ extension incremented: %u/%u",
+ ao_img.snd_sne, ao_img.rcv_sne);
+ } else {
+ test_fail("SEQ extension was not incremented: %u/%u",
+ ao_img.snd_sne, ao_img.rcv_sne);
+ }
+
+ synchronize_threads(); /* 6: verified => closed */
+out:
+ close(sk);
+ return NULL;
+}
+
+static void *client_fn(void *arg)
+{
+ uint64_t before_good, after_good, after_bad;
+ struct tcp_ao_counters ao1, ao2;
+ struct tcp_sock_state img;
+ struct tcp_ao_repair ao_img;
+ sockaddr_af saddr;
+ int sk;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* 1: MKT added => connect() */
+ if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0)
+ test_error("failed to connect()");
+
+ synchronize_threads(); /* 2: accepted => send data */
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC)) {
+ test_fail("pre-migrate verify failed");
+ return NULL;
+ }
+
+ before_good = netstat_get_one("TCPAOGood", NULL);
+
+ synchronize_threads(); /* 3: restore the connection on another port */
+ test_enable_repair(sk);
+ test_sock_checkpoint(sk, &img, &saddr);
+ test_ao_checkpoint(sk, &ao_img);
+ test_kill_sk(sk);
+#ifdef IPV6_TEST
+ client_new_port = ntohs(saddr.sin6_port) + 1;
+ saddr.sin6_port = htons(ntohs(saddr.sin6_port) + 1);
+#else
+ client_new_port = ntohs(saddr.sin_port) + 1;
+ saddr.sin_port = htons(ntohs(saddr.sin_port) + 1);
+#endif
+ test_adjust_seqs(&img, &ao_img, false);
+ synchronize_threads(); /* 4: dump finished */
+ sk = test_sk_restore(&img, &ao_img, &saddr, this_ip_dest,
+ test_server_port + 1, &ao1);
+
+ synchronize_threads(); /* 5: verify counters during SEQ-number rollover */
+ if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
+ test_fail("post-migrate verify failed");
+ else
+ test_ok("post-migrate connection alive");
+
+ if (test_get_tcp_ao_counters(sk, &ao2))
+ test_error("test_get_tcp_ao_counters()");
+ after_good = netstat_get_one("TCPAOGood", NULL);
+
+ test_tcp_ao_counters_cmp(NULL, &ao1, &ao2, TEST_CNT_GOOD);
+
+ if (after_good <= before_good) {
+ test_fail("TCPAOGood counter did not increase: %zu <= %zu",
+ after_good, before_good);
+ } else {
+ test_ok("TCPAOGood counter increased %zu => %zu",
+ before_good, after_good);
+ }
+ after_bad = netstat_get_one("TCPAOBad", NULL);
+ if (after_bad)
+ test_fail("TCPAOBad counter is non-zero: %zu", after_bad);
+ else
+ test_ok("TCPAOBad counter didn't increase");
+
+ synchronize_threads(); /* 6: verified => closed */
+ close(sk);
+
+ synchronize_threads(); /* don't race to exit: let server exit() */
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(7, server_fn, client_fn);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c
new file mode 100644
index 000000000000..452de131fa3a
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c
@@ -0,0 +1,835 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "../../../../include/linux/kernel.h"
+#include "aolib.h"
+
+static union tcp_addr tcp_md5_client;
+
+static int test_port = 7788;
+static void make_listen(int sk)
+{
+ sockaddr_af addr;
+
+ tcp_addr_to_sockaddr_in(&addr, &this_ip_addr, htons(test_port++));
+ if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0)
+ test_error("bind()");
+ if (listen(sk, 1))
+ test_error("listen()");
+}
+
+static void test_vefify_ao_info(int sk, struct tcp_ao_info_opt *info,
+ const char *tst)
+{
+ struct tcp_ao_info_opt tmp;
+ socklen_t len = sizeof(tmp);
+
+ if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, &tmp, &len))
+ test_error("getsockopt(TCP_AO_INFO) failed");
+
+#define __cmp_ao(member) \
+do { \
+ if (info->member != tmp.member) { \
+ test_fail("%s: getsockopt(): " __stringify(member) " %zu != %zu", \
+ tst, (size_t)info->member, (size_t)tmp.member); \
+ return; \
+ } \
+} while(0)
+ if (info->set_current)
+ __cmp_ao(current_key);
+ if (info->set_rnext)
+ __cmp_ao(rnext);
+ if (info->set_counters) {
+ __cmp_ao(pkt_good);
+ __cmp_ao(pkt_bad);
+ __cmp_ao(pkt_key_not_found);
+ __cmp_ao(pkt_ao_required);
+ __cmp_ao(pkt_dropped_icmp);
+ }
+ __cmp_ao(ao_required);
+ __cmp_ao(accept_icmps);
+
+ test_ok("AO info get: %s", tst);
+#undef __cmp_ao
+}
+
+static void __setsockopt_checked(int sk, int optname, bool get,
+ void *optval, socklen_t *len,
+ int err, const char *tst, const char *tst2)
+{
+ int ret;
+
+ if (!tst)
+ tst = "";
+ if (!tst2)
+ tst2 = "";
+
+ errno = 0;
+ if (get)
+ ret = getsockopt(sk, IPPROTO_TCP, optname, optval, len);
+ else
+ ret = setsockopt(sk, IPPROTO_TCP, optname, optval, *len);
+ if (ret == -1) {
+ if (errno == err)
+ test_ok("%s%s", tst ?: "", tst2 ?: "");
+ else
+ test_fail("%s%s: %setsockopt() failed",
+ tst, tst2, get ? "g" : "s");
+ close(sk);
+ return;
+ }
+
+ if (err) {
+ test_fail("%s%s: %setsockopt() was expected to fail with %d",
+ tst, tst2, get ? "g" : "s", err);
+ } else {
+ test_ok("%s%s", tst ?: "", tst2 ?: "");
+ if (optname == TCP_AO_ADD_KEY) {
+ test_verify_socket_key(sk, optval);
+ } else if (optname == TCP_AO_INFO && !get) {
+ test_vefify_ao_info(sk, optval, tst2);
+ } else if (optname == TCP_AO_GET_KEYS) {
+ if (*len != sizeof(struct tcp_ao_getsockopt))
+ test_fail("%s%s: get keys returned wrong tcp_ao_getsockopt size",
+ tst, tst2);
+ }
+ }
+ close(sk);
+}
+
+static void setsockopt_checked(int sk, int optname, void *optval,
+ int err, const char *tst)
+{
+ const char *cmd = NULL;
+ socklen_t len;
+
+ switch (optname) {
+ case TCP_AO_ADD_KEY:
+ cmd = "key add: ";
+ len = sizeof(struct tcp_ao_add);
+ break;
+ case TCP_AO_DEL_KEY:
+ cmd = "key del: ";
+ len = sizeof(struct tcp_ao_del);
+ break;
+ case TCP_AO_INFO:
+ cmd = "AO info set: ";
+ len = sizeof(struct tcp_ao_info_opt);
+ break;
+ default:
+ break;
+ }
+
+ __setsockopt_checked(sk, optname, false, optval, &len, err, cmd, tst);
+}
+
+static int prepare_defs(int cmd, void *optval)
+{
+ int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sk < 0)
+ test_error("socket()");
+
+ switch (cmd) {
+ case TCP_AO_ADD_KEY: {
+ struct tcp_ao_add *add = optval;
+
+ if (test_prepare_def_key(add, DEFAULT_TEST_PASSWORD, 0, this_ip_dest,
+ -1, 0, 100, 100))
+ test_error("prepare default tcp_ao_add");
+ break;
+ }
+ case TCP_AO_DEL_KEY: {
+ struct tcp_ao_del *del = optval;
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
+ DEFAULT_TEST_PREFIX, 100, 100))
+ test_error("add default key");
+ memset(del, 0, sizeof(struct tcp_ao_del));
+ del->sndid = 100;
+ del->rcvid = 100;
+ del->prefix = DEFAULT_TEST_PREFIX;
+ tcp_addr_to_sockaddr_in(&del->addr, &this_ip_dest, 0);
+ break;
+ }
+ case TCP_AO_INFO: {
+ struct tcp_ao_info_opt *info = optval;
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
+ DEFAULT_TEST_PREFIX, 100, 100))
+ test_error("add default key");
+ memset(info, 0, sizeof(struct tcp_ao_info_opt));
+ break;
+ }
+ case TCP_AO_GET_KEYS: {
+ struct tcp_ao_getsockopt *get = optval;
+
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest,
+ DEFAULT_TEST_PREFIX, 100, 100))
+ test_error("add default key");
+ memset(get, 0, sizeof(struct tcp_ao_getsockopt));
+ get->nkeys = 1;
+ get->get_all = 1;
+ break;
+ }
+ default:
+ test_error("unknown cmd");
+ }
+
+ return sk;
+}
+
+static void test_extend(int cmd, bool get, const char *tst, socklen_t under_size)
+{
+ struct {
+ union {
+ struct tcp_ao_add add;
+ struct tcp_ao_del del;
+ struct tcp_ao_getsockopt get;
+ struct tcp_ao_info_opt info;
+ };
+ char *extend[100];
+ } tmp_opt;
+ socklen_t extended_size = sizeof(tmp_opt);
+ int sk;
+
+ memset(&tmp_opt, 0, sizeof(tmp_opt));
+ sk = prepare_defs(cmd, &tmp_opt);
+ __setsockopt_checked(sk, cmd, get, &tmp_opt, &under_size,
+ EINVAL, tst, ": minimum size");
+
+ memset(&tmp_opt, 0, sizeof(tmp_opt));
+ sk = prepare_defs(cmd, &tmp_opt);
+ __setsockopt_checked(sk, cmd, get, &tmp_opt, &extended_size,
+ 0, tst, ": extended size");
+
+ memset(&tmp_opt, 0, sizeof(tmp_opt));
+ sk = prepare_defs(cmd, &tmp_opt);
+ __setsockopt_checked(sk, cmd, get, NULL, &extended_size,
+ EFAULT, tst, ": null optval");
+
+ if (get) {
+ memset(&tmp_opt, 0, sizeof(tmp_opt));
+ sk = prepare_defs(cmd, &tmp_opt);
+ __setsockopt_checked(sk, cmd, get, &tmp_opt, NULL,
+ EFAULT, tst, ": null optlen");
+ }
+}
+
+static void extend_tests(void)
+{
+ test_extend(TCP_AO_ADD_KEY, false, "AO add",
+ offsetof(struct tcp_ao_add, key));
+ test_extend(TCP_AO_DEL_KEY, false, "AO del",
+ offsetof(struct tcp_ao_del, keyflags));
+ test_extend(TCP_AO_INFO, false, "AO set info",
+ offsetof(struct tcp_ao_info_opt, pkt_dropped_icmp));
+ test_extend(TCP_AO_INFO, true, "AO get info", -1);
+ test_extend(TCP_AO_GET_KEYS, true, "AO get keys", -1);
+}
+
+static void test_optmem_limit(void)
+{
+ size_t i, keys_limit, current_optmem = test_get_optmem();
+ struct tcp_ao_add ao;
+ union tcp_addr net = {};
+ int sk;
+
+ if (inet_pton(TEST_FAMILY, TEST_NETWORK, &net) != 1)
+ test_error("Can't convert ip address %s", TEST_NETWORK);
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ keys_limit = current_optmem / KERNEL_TCP_AO_KEY_SZ_ROUND_UP;
+ for (i = 0;; i++) {
+ union tcp_addr key_peer;
+ int err;
+
+ key_peer = gen_tcp_addr(net, i + 1);
+ tcp_addr_to_sockaddr_in(&ao.addr, &key_peer, 0);
+ err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY,
+ &ao, sizeof(ao));
+ if (!err) {
+ /*
+ * TCP_AO_ADD_KEY should be the same order as the real
+ * sizeof(struct tcp_ao_key) in kernel.
+ */
+ if (i <= keys_limit * 10)
+ continue;
+ test_fail("optmem limit test failed: added %zu key", i);
+ break;
+ }
+ if (i < keys_limit) {
+ test_fail("optmem limit test failed: couldn't add %zu key", i);
+ break;
+ }
+ test_ok("optmem limit was hit on adding %zu key", i);
+ break;
+ }
+ close(sk);
+}
+
+static void test_einval_add_key(void)
+{
+ struct tcp_ao_add ao;
+ int sk;
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.keylen = TCP_AO_MAXKEYLEN + 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too big keylen");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.reserved = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "using reserved padding");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.reserved2 = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "using reserved2 padding");
+
+ /* tcp_ao_verify_ipv{4,6}() checks */
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.addr.ss_family = AF_UNIX;
+ memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "wrong address family");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ tcp_addr_to_sockaddr_in(&ao.addr, &this_ip_dest, 1234);
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "port (unsupported)");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.prefix = 0;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "no prefix, addr");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.prefix = 0;
+ memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "no prefix, any addr");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.prefix = 32;
+ memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "prefix, any addr");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.prefix = 129;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too big prefix");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.prefix = 2;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too short prefix");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.keyflags = (uint8_t)(-1);
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "bad key flags");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ make_listen(sk);
+ ao.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add current key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ make_listen(sk);
+ ao.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ make_listen(sk);
+ ao.set_current = 1;
+ ao.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add current+rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as current");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as rnext");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.set_current = 1;
+ ao.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as current+rnext");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.ifindex = 42;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL,
+ "ifindex without TCP_AO_KEYF_IFNINDEX");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.keyflags |= TCP_AO_KEYF_IFINDEX;
+ ao.ifindex = 42;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "non-existent VRF");
+ /*
+ * tcp_md5_do_lookup{,_any_l3index}() are checked in unsigned-md5
+ * see client_vrf_tests().
+ */
+
+ test_optmem_limit();
+
+ /* tcp_ao_parse_crypto() */
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao.maclen = 100;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EMSGSIZE, "maclen bigger than TCP hdr");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ strcpy(ao.alg_name, "imaginary hash algo");
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, ENOENT, "bad algo");
+}
+
+static void test_einval_del_key(void)
+{
+ struct tcp_ao_del del;
+ int sk;
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.reserved = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "using reserved padding");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.reserved2 = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "using reserved2 padding");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ make_listen(sk);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set current key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ make_listen(sk);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ make_listen(sk);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_current = 1;
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set current+rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.keyflags = (uint8_t)(-1);
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "bad key flags");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.ifindex = 42;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL,
+ "ifindex without TCP_AO_KEYF_IFNINDEX");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.keyflags |= TCP_AO_KEYF_IFINDEX;
+ del.ifindex = 42;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existent VRF");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing current key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing rnext key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_current = 1;
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing current+rnext key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set current key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set rnext key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0))
+ test_error("add key");
+ del.set_current = 1;
+ del.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set current+rnext key");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_current = 1;
+ del.current_key = 100;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as current key to be removed");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_rnext = 1;
+ del.rnext = 100;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as rnext key to be removed");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.set_current = 1;
+ del.current_key = 100;
+ del.set_rnext = 1;
+ del.rnext = 100;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as current+rnext key to be removed");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.del_async = 1;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "async on non-listen");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.sndid = 101;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existing sndid");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ del.rcvid = 101;
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existing rcvid");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ tcp_addr_to_sockaddr_in(&del.addr, &this_ip_addr, 0);
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "incorrect addr");
+
+ sk = prepare_defs(TCP_AO_DEL_KEY, &del);
+ setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "correct key delete");
+}
+
+static void test_einval_ao_info(void)
+{
+ struct tcp_ao_info_opt info;
+ int sk;
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ make_listen(sk);
+ info.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set current key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ make_listen(sk);
+ info.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ make_listen(sk);
+ info.set_current = 1;
+ info.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set current+rnext key on a listen socket");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.reserved = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "using reserved padding");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.reserved2 = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "using reserved2 padding");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.accept_icmps = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "accept_icmps");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.ao_required = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "ao required");
+
+ if (!should_skip_test("ao required with MD5 key", KCONFIG_TCP_MD5)) {
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.ao_required = 1;
+ if (test_set_md5(sk, tcp_md5_client, TEST_PREFIX, -1,
+ "long long secret")) {
+ test_error("setsockopt(TCP_MD5SIG_EXT)");
+ close(sk);
+ } else {
+ setsockopt_checked(sk, TCP_AO_INFO, &info, EKEYREJECTED,
+ "ao required with MD5 key");
+ }
+ }
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_current = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing current key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing rnext key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_current = 1;
+ info.set_rnext = 1;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing current+rnext key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_current = 1;
+ info.current_key = 100;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set current key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_rnext = 1;
+ info.rnext = 100;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set rnext key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_current = 1;
+ info.set_rnext = 1;
+ info.current_key = 100;
+ info.rnext = 100;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set current+rnext key");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ info.set_counters = 1;
+ info.pkt_good = 321;
+ info.pkt_bad = 888;
+ info.pkt_key_not_found = 654;
+ info.pkt_ao_required = 987654;
+ info.pkt_dropped_icmp = 10000;
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set counters");
+
+ sk = prepare_defs(TCP_AO_INFO, &info);
+ setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "no-op");
+}
+
+static void getsockopt_checked(int sk, struct tcp_ao_getsockopt *optval,
+ int err, const char *tst)
+{
+ socklen_t len = sizeof(struct tcp_ao_getsockopt);
+
+ __setsockopt_checked(sk, TCP_AO_GET_KEYS, true, optval, &len, err,
+ "get keys: ", tst);
+}
+
+static void test_einval_get_keys(void)
+{
+ struct tcp_ao_getsockopt out;
+ int sk;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+ getsockopt_checked(sk, &out, ENOENT, "no ao_info");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ getsockopt_checked(sk, &out, 0, "proper tcp_ao_get_mkts()");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.pkt_good = 643;
+ getsockopt_checked(sk, &out, EINVAL, "set out-only pkt_good counter");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.pkt_bad = 94;
+ getsockopt_checked(sk, &out, EINVAL, "set out-only pkt_bad counter");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.keyflags = (uint8_t)(-1);
+ getsockopt_checked(sk, &out, EINVAL, "bad keyflags");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.ifindex = 42;
+ getsockopt_checked(sk, &out, EINVAL,
+ "ifindex without TCP_AO_KEYF_IFNINDEX");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.reserved = 1;
+ getsockopt_checked(sk, &out, EINVAL, "using reserved field");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = 0;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "no prefix, addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = 0;
+ memcpy(&out.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ getsockopt_checked(sk, &out, 0, "no prefix, any addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = 32;
+ memcpy(&out.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ getsockopt_checked(sk, &out, EINVAL, "prefix, any addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = 129;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "too big prefix");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = 2;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "too short prefix");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.prefix = DEFAULT_TEST_PREFIX;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, 0, "prefix + addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ out.prefix = DEFAULT_TEST_PREFIX;
+ getsockopt_checked(sk, &out, EINVAL, "get_all + prefix");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "get_all + addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ out.sndid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "get_all + sndid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ out.rcvid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "get_all + rcvid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_current = 1;
+ out.prefix = DEFAULT_TEST_PREFIX;
+ getsockopt_checked(sk, &out, EINVAL, "current + prefix");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_current = 1;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "current + addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_current = 1;
+ out.sndid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "current + sndid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_current = 1;
+ out.rcvid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "current + rcvid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_rnext = 1;
+ out.prefix = DEFAULT_TEST_PREFIX;
+ getsockopt_checked(sk, &out, EINVAL, "rnext + prefix");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_rnext = 1;
+ tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0);
+ getsockopt_checked(sk, &out, EINVAL, "rnext + addr");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_rnext = 1;
+ out.sndid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "rnext + sndid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_rnext = 1;
+ out.rcvid = 1;
+ getsockopt_checked(sk, &out, EINVAL, "rnext + rcvid");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ out.is_current = 1;
+ getsockopt_checked(sk, &out, EINVAL, "get_all + current");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 1;
+ out.is_rnext = 1;
+ getsockopt_checked(sk, &out, EINVAL, "get_all + rnext");
+
+ sk = prepare_defs(TCP_AO_GET_KEYS, &out);
+ out.get_all = 0;
+ out.is_current = 1;
+ out.is_rnext = 1;
+ getsockopt_checked(sk, &out, 0, "current + rnext");
+}
+
+static void einval_tests(void)
+{
+ test_einval_add_key();
+ test_einval_del_key();
+ test_einval_ao_info();
+ test_einval_get_keys();
+}
+
+static void duplicate_tests(void)
+{
+ union tcp_addr network_dup;
+ struct tcp_ao_add ao, ao2;
+ int sk;
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: full copy");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ ao2 = ao;
+ memcpy(&ao2.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ ao2.prefix = 0;
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao2, sizeof(ao)))
+ test_error("setsockopt()");
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: any addr key on the socket");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY));
+ ao.prefix = 0;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: add any addr key");
+
+ if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_dup) != 1)
+ test_error("Can't convert ip address %s", TEST_NETWORK);
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ if (test_prepare_def_key(&ao, "password", 0, network_dup,
+ 16, 0, 100, 100))
+ test_error("prepare default tcp_ao_add");
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: add any addr for the same subnet");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: full copy of a key");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ ao.rcvid = 101;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: RecvID differs");
+
+ sk = prepare_defs(TCP_AO_ADD_KEY, &ao);
+ if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao)))
+ test_error("setsockopt()");
+ ao.sndid = 101;
+ setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: SendID differs");
+}
+
+static void *client_fn(void *arg)
+{
+ if (inet_pton(TEST_FAMILY, __TEST_CLIENT_IP(2), &tcp_md5_client) != 1)
+ test_error("Can't convert ip address");
+ extend_tests();
+ einval_tests();
+ duplicate_tests();
+ /*
+ * TODO: check getsockopt(TCP_AO_GET_KEYS) with different filters
+ * returning proper nr & keys;
+ */
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(120, client_fn, NULL);
+ return 0;
+}
diff --git a/tools/testing/selftests/net/tcp_ao/settings b/tools/testing/selftests/net/tcp_ao/settings
new file mode 100644
index 000000000000..6091b45d226b
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/settings
@@ -0,0 +1 @@
+timeout=120
diff --git a/tools/testing/selftests/net/tcp_ao/unsigned-md5.c b/tools/testing/selftests/net/tcp_ao/unsigned-md5.c
new file mode 100644
index 000000000000..6b59a652159f
--- /dev/null
+++ b/tools/testing/selftests/net/tcp_ao/unsigned-md5.c
@@ -0,0 +1,741 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Author: Dmitry Safonov <dima@arista.com> */
+#include <inttypes.h>
+#include "aolib.h"
+
+#define fault(type) (inj == FAULT_ ## type)
+static const char *md5_password = "Some evil genius, enemy to mankind, must have been the first contriver.";
+static const char *ao_password = DEFAULT_TEST_PASSWORD;
+
+static union tcp_addr client2;
+static union tcp_addr client3;
+
+static const int test_vrf_ifindex = 200;
+static const uint8_t test_vrf_tabid = 42;
+static void setup_vrfs(void)
+{
+ int err;
+
+ if (!kernel_config_has(KCONFIG_NET_VRF))
+ return;
+
+ err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1);
+ if (err)
+ test_error("Failed to add a VRF: %d", err);
+
+ err = link_set_up("ksft-vrf");
+ if (err)
+ test_error("Failed to bring up a VRF");
+
+ err = ip_route_add_vrf(veth_name, TEST_FAMILY,
+ this_ip_addr, this_ip_dest, test_vrf_tabid);
+ if (err)
+ test_error("Failed to add a route to VRF: %d", err);
+}
+
+static void try_accept(const char *tst_name, unsigned int port,
+ union tcp_addr *md5_addr, uint8_t md5_prefix,
+ union tcp_addr *ao_addr, uint8_t ao_prefix,
+ bool set_ao_required,
+ uint8_t sndid, uint8_t rcvid, uint8_t vrf,
+ const char *cnt_name, test_cnt cnt_expected,
+ int needs_tcp_md5, fault_t inj)
+{
+ struct tcp_ao_counters ao_cnt1, ao_cnt2;
+ uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */
+ int lsk, err, sk = 0;
+ time_t timeout;
+
+ if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5))
+ return;
+
+ lsk = test_listen_socket(this_ip_addr, port, 1);
+
+ if (md5_addr && test_set_md5(lsk, *md5_addr, md5_prefix, -1, md5_password))
+ test_error("setsockopt(TCP_MD5SIG_EXT)");
+
+ if (ao_addr && test_add_key(lsk, ao_password,
+ *ao_addr, ao_prefix, sndid, rcvid))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ if (set_ao_required && test_set_ao_flags(lsk, true, false))
+ test_error("setsockopt(TCP_AO_INFO)");
+
+ if (cnt_name)
+ before_cnt = netstat_get_one(cnt_name, NULL);
+ if (ao_addr && test_get_tcp_ao_counters(lsk, &ao_cnt1))
+ test_error("test_get_tcp_ao_counters()");
+
+ synchronize_threads(); /* preparations done */
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ err = test_wait_fd(lsk, timeout, 0);
+ if (err == -ETIMEDOUT) {
+ if (!fault(TIMEOUT))
+ test_fail("timed out for accept()");
+ } else if (err < 0) {
+ test_error("test_wait_fd()");
+ } else {
+ if (fault(TIMEOUT))
+ test_fail("ready to accept");
+
+ sk = accept(lsk, NULL, NULL);
+ if (sk < 0) {
+ test_error("accept()");
+ } else {
+ if (fault(TIMEOUT))
+ test_fail("%s: accepted", tst_name);
+ }
+ }
+
+ if (ao_addr && test_get_tcp_ao_counters(lsk, &ao_cnt2))
+ test_error("test_get_tcp_ao_counters()");
+ close(lsk);
+
+ if (!cnt_name) {
+ test_ok("%s: no counter checks", tst_name);
+ goto out;
+ }
+
+ after_cnt = netstat_get_one(cnt_name, NULL);
+
+ if (after_cnt <= before_cnt) {
+ test_fail("%s: %s counter did not increase: %zu <= %zu",
+ tst_name, cnt_name, after_cnt, before_cnt);
+ } else {
+ test_ok("%s: counter %s increased %zu => %zu",
+ tst_name, cnt_name, before_cnt, after_cnt);
+ }
+ if (ao_addr)
+ test_tcp_ao_counters_cmp(tst_name, &ao_cnt1, &ao_cnt2, cnt_expected);
+
+out:
+ synchronize_threads(); /* test_kill_sk() */
+ if (sk > 0)
+ test_kill_sk(sk);
+}
+
+static void server_add_routes(void)
+{
+ int family = TEST_FAMILY;
+
+ synchronize_threads(); /* client_add_ips() */
+
+ if (ip_route_add(veth_name, family, this_ip_addr, client2))
+ test_error("Failed to add route");
+ if (ip_route_add(veth_name, family, this_ip_addr, client3))
+ test_error("Failed to add route");
+}
+
+static void server_add_fail_tests(unsigned int *port)
+{
+ union tcp_addr addr_any = {};
+
+ try_accept("TCP-AO established: add TCP-MD5 key", (*port)++, NULL, 0,
+ &addr_any, 0, 0, 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD,
+ 1, 0);
+ try_accept("TCP-MD5 established: add TCP-AO key", (*port)++, &addr_any,
+ 0, NULL, 0, 0, 0, 0, 0, NULL, 0, 1, 0);
+ try_accept("non-signed established: add TCP-AO key", (*port)++, NULL, 0,
+ NULL, 0, 0, 0, 0, 0, "CurrEstab", 0, 0, 0);
+}
+
+static void server_vrf_tests(unsigned int *port)
+{
+ setup_vrfs();
+}
+
+static void *server_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+ union tcp_addr addr_any = {};
+
+ server_add_routes();
+
+ try_accept("AO server (INADDR_ANY): AO client", port++, NULL, 0,
+ &addr_any, 0, 0, 100, 100, 0, "TCPAOGood",
+ TEST_CNT_GOOD, 0, 0);
+ try_accept("AO server (INADDR_ANY): MD5 client", port++, NULL, 0,
+ &addr_any, 0, 0, 100, 100, 0, "TCPMD5Unexpected",
+ 0, 1, FAULT_TIMEOUT);
+ try_accept("AO server (INADDR_ANY): no sign client", port++, NULL, 0,
+ &addr_any, 0, 0, 100, 100, 0, "TCPAORequired",
+ TEST_CNT_AO_REQUIRED, 0, FAULT_TIMEOUT);
+ try_accept("AO server (AO_REQUIRED): AO client", port++, NULL, 0,
+ &this_ip_dest, TEST_PREFIX, true,
+ 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD, 0, 0);
+ try_accept("AO server (AO_REQUIRED): unsigned client", port++, NULL, 0,
+ &this_ip_dest, TEST_PREFIX, true,
+ 100, 100, 0, "TCPAORequired",
+ TEST_CNT_AO_REQUIRED, 0, FAULT_TIMEOUT);
+
+ try_accept("MD5 server (INADDR_ANY): AO client", port++, &addr_any, 0,
+ NULL, 0, 0, 0, 0, 0, "TCPAOKeyNotFound",
+ 0, 1, FAULT_TIMEOUT);
+ try_accept("MD5 server (INADDR_ANY): MD5 client", port++, &addr_any, 0,
+ NULL, 0, 0, 0, 0, 0, NULL, 0, 1, 0);
+ try_accept("MD5 server (INADDR_ANY): no sign client", port++, &addr_any,
+ 0, NULL, 0, 0, 0, 0, 0, "TCPMD5NotFound",
+ 0, 1, FAULT_TIMEOUT);
+
+ try_accept("no sign server: AO client", port++, NULL, 0,
+ NULL, 0, 0, 0, 0, 0, "TCPAOKeyNotFound",
+ TEST_CNT_AO_KEY_NOT_FOUND, 0, FAULT_TIMEOUT);
+ try_accept("no sign server: MD5 client", port++, NULL, 0,
+ NULL, 0, 0, 0, 0, 0, "TCPMD5Unexpected",
+ 0, 1, FAULT_TIMEOUT);
+ try_accept("no sign server: no sign client", port++, NULL, 0,
+ NULL, 0, 0, 0, 0, 0, "CurrEstab", 0, 0, 0);
+
+ try_accept("AO+MD5 server: AO client (matching)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD, 1, 0);
+ try_accept("AO+MD5 server: AO client (misconfig, matching MD5)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND,
+ 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: AO client (misconfig, non-matching)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND,
+ 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: MD5 client (matching)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, NULL, 0, 1, 0);
+ try_accept("AO+MD5 server: MD5 client (misconfig, matching AO)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPMD5Unexpected", 0, 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: MD5 client (misconfig, non-matching)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPMD5Unexpected", 0, 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: no sign client (unmatched)", port++,
+ &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "CurrEstab", 0, 1, 0);
+ try_accept("AO+MD5 server: no sign client (misconfig, matching AO)",
+ port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPAORequired",
+ TEST_CNT_AO_REQUIRED, 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: no sign client (misconfig, matching MD5)",
+ port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, "TCPMD5NotFound", 0, 1, FAULT_TIMEOUT);
+
+ try_accept("AO+MD5 server: client with both [TCP-MD5] and TCP-AO keys",
+ port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, NULL, 0, 1, FAULT_TIMEOUT);
+ try_accept("AO+MD5 server: client with both TCP-MD5 and [TCP-AO] keys",
+ port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0,
+ 100, 100, 0, NULL, 0, 1, FAULT_TIMEOUT);
+
+ server_add_fail_tests(&port);
+
+ server_vrf_tests(&port);
+
+ /* client exits */
+ synchronize_threads();
+ return NULL;
+}
+
+static int client_bind(int sk, union tcp_addr bind_addr)
+{
+#ifdef IPV6_TEST
+ struct sockaddr_in6 addr = {
+ .sin6_family = AF_INET6,
+ .sin6_port = 0,
+ .sin6_addr = bind_addr.a6,
+ };
+#else
+ struct sockaddr_in addr = {
+ .sin_family = AF_INET,
+ .sin_port = 0,
+ .sin_addr = bind_addr.a4,
+ };
+#endif
+ return bind(sk, &addr, sizeof(addr));
+}
+
+static void try_connect(const char *tst_name, unsigned int port,
+ union tcp_addr *md5_addr, uint8_t md5_prefix,
+ union tcp_addr *ao_addr, uint8_t ao_prefix,
+ uint8_t sndid, uint8_t rcvid, uint8_t vrf,
+ fault_t inj, int needs_tcp_md5, union tcp_addr *bind_addr)
+{
+ time_t timeout;
+ int sk, ret;
+
+ if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5))
+ return;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (bind_addr && client_bind(sk, *bind_addr))
+ test_error("bind()");
+
+ if (md5_addr && test_set_md5(sk, *md5_addr, md5_prefix, -1, md5_password))
+ test_error("setsockopt(TCP_MD5SIG_EXT)");
+
+ if (ao_addr && test_add_key(sk, ao_password, *ao_addr,
+ ao_prefix, sndid, rcvid))
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+
+ synchronize_threads(); /* preparations done */
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ ret = _test_connect_socket(sk, this_ip_dest, port, timeout);
+
+ if (ret < 0) {
+ if (fault(KEYREJECT) && ret == -EKEYREJECTED)
+ test_ok("%s: connect() was prevented", tst_name);
+ else if (ret == -ETIMEDOUT && fault(TIMEOUT))
+ test_ok("%s", tst_name);
+ else if (ret == -ECONNREFUSED &&
+ (fault(TIMEOUT) || fault(KEYREJECT)))
+ test_ok("%s: refused to connect", tst_name);
+ else
+ test_error("%s: connect() returned %d", tst_name, ret);
+ goto out;
+ }
+
+ if (fault(TIMEOUT) || fault(KEYREJECT))
+ test_fail("%s: connected", tst_name);
+ else
+ test_ok("%s: connected", tst_name);
+
+out:
+ synchronize_threads(); /* test_kill_sk() */
+ /* _test_connect_socket() cleans up on failure */
+ if (ret > 0)
+ test_kill_sk(sk);
+}
+
+#define PREINSTALL_MD5_FIRST BIT(0)
+#define PREINSTALL_AO BIT(1)
+#define POSTINSTALL_AO BIT(2)
+#define PREINSTALL_MD5 BIT(3)
+#define POSTINSTALL_MD5 BIT(4)
+
+static int try_add_key_vrf(int sk, union tcp_addr in_addr, uint8_t prefix,
+ int vrf, uint8_t sndid, uint8_t rcvid,
+ bool set_ao_required)
+{
+ uint8_t keyflags = 0;
+
+ if (vrf >= 0)
+ keyflags |= TCP_AO_KEYF_IFINDEX;
+ else
+ vrf = 0;
+ if (set_ao_required) {
+ int err = test_set_ao_flags(sk, true, 0);
+
+ if (err)
+ return err;
+ }
+ return test_add_key_vrf(sk, ao_password, keyflags, in_addr, prefix,
+ (uint8_t)vrf, sndid, rcvid);
+}
+
+static bool test_continue(const char *tst_name, int err,
+ fault_t inj, bool added_ao)
+{
+ bool expected_to_fail;
+
+ expected_to_fail = fault(PREINSTALL_AO) && added_ao;
+ expected_to_fail |= fault(PREINSTALL_MD5) && !added_ao;
+
+ if (!err) {
+ if (!expected_to_fail)
+ return true;
+ test_fail("%s: setsockopt()s were expected to fail", tst_name);
+ return false;
+ }
+ if (err != -EKEYREJECTED || !expected_to_fail) {
+ test_error("%s: setsockopt(%s) = %d", tst_name,
+ added_ao ? "TCP_AO_ADD_KEY" : "TCP_MD5SIG_EXT", err);
+ return false;
+ }
+ test_ok("%s: prefailed as expected: %m", tst_name);
+ return false;
+}
+
+static int open_add(const char *tst_name, unsigned int port,
+ unsigned int strategy,
+ union tcp_addr md5_addr, uint8_t md5_prefix, int md5_vrf,
+ union tcp_addr ao_addr, uint8_t ao_prefix,
+ int ao_vrf, bool set_ao_required,
+ uint8_t sndid, uint8_t rcvid,
+ fault_t inj)
+{
+ int sk;
+
+ sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
+ if (sk < 0)
+ test_error("socket()");
+
+ if (client_bind(sk, this_ip_addr))
+ test_error("bind()");
+
+ if (strategy & PREINSTALL_MD5_FIRST) {
+ if (test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password))
+ test_error("setsockopt(TCP_MD5SIG_EXT)");
+ }
+
+ if (strategy & PREINSTALL_AO) {
+ int err = try_add_key_vrf(sk, ao_addr, ao_prefix, ao_vrf,
+ sndid, rcvid, set_ao_required);
+
+ if (!test_continue(tst_name, err, inj, true)) {
+ close(sk);
+ return -1;
+ }
+ }
+
+ if (strategy & PREINSTALL_MD5) {
+ errno = 0;
+ test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password);
+ if (!test_continue(tst_name, -errno, inj, false)) {
+ close(sk);
+ return -1;
+ }
+ }
+
+ return sk;
+}
+
+static void try_to_preadd(const char *tst_name, unsigned int port,
+ unsigned int strategy,
+ union tcp_addr md5_addr, uint8_t md5_prefix,
+ int md5_vrf,
+ union tcp_addr ao_addr, uint8_t ao_prefix,
+ int ao_vrf, bool set_ao_required,
+ uint8_t sndid, uint8_t rcvid,
+ int needs_tcp_md5, int needs_vrf, fault_t inj)
+{
+ int sk;
+
+ if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5))
+ return;
+ if (needs_vrf && should_skip_test(tst_name, KCONFIG_NET_VRF))
+ return;
+
+ sk = open_add(tst_name, port, strategy, md5_addr, md5_prefix, md5_vrf,
+ ao_addr, ao_prefix, ao_vrf, set_ao_required,
+ sndid, rcvid, inj);
+ if (sk < 0)
+ return;
+
+ test_ok("%s", tst_name);
+ close(sk);
+}
+
+static void try_to_add(const char *tst_name, unsigned int port,
+ unsigned int strategy,
+ union tcp_addr md5_addr, uint8_t md5_prefix,
+ int md5_vrf,
+ union tcp_addr ao_addr, uint8_t ao_prefix,
+ int ao_vrf, uint8_t sndid, uint8_t rcvid,
+ int needs_tcp_md5, fault_t inj)
+{
+ time_t timeout;
+ int sk, ret;
+
+ if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5))
+ return;
+
+ sk = open_add(tst_name, port, strategy, md5_addr, md5_prefix, md5_vrf,
+ ao_addr, ao_prefix, ao_vrf, 0, sndid, rcvid, inj);
+ if (sk < 0)
+ return;
+
+ synchronize_threads(); /* preparations done */
+
+ timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
+ ret = _test_connect_socket(sk, this_ip_dest, port, timeout);
+
+ if (ret <= 0) {
+ test_error("%s: connect() returned %d", tst_name, ret);
+ goto out;
+ }
+
+ if (strategy & POSTINSTALL_MD5) {
+ if (test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password)) {
+ if (fault(POSTINSTALL)) {
+ test_ok("%s: postfailed as expected", tst_name);
+ goto out;
+ } else {
+ test_error("setsockopt(TCP_MD5SIG_EXT)");
+ }
+ } else if (fault(POSTINSTALL)) {
+ test_fail("%s: post setsockopt() was expected to fail", tst_name);
+ goto out;
+ }
+ }
+
+ if (strategy & POSTINSTALL_AO) {
+ if (try_add_key_vrf(sk, ao_addr, ao_prefix, ao_vrf,
+ sndid, rcvid, 0)) {
+ if (fault(POSTINSTALL)) {
+ test_ok("%s: postfailed as expected", tst_name);
+ goto out;
+ } else {
+ test_error("setsockopt(TCP_AO_ADD_KEY)");
+ }
+ } else if (fault(POSTINSTALL)) {
+ test_fail("%s: post setsockopt() was expected to fail", tst_name);
+ goto out;
+ }
+ }
+
+out:
+ synchronize_threads(); /* test_kill_sk() */
+ /* _test_connect_socket() cleans up on failure */
+ if (ret > 0)
+ test_kill_sk(sk);
+}
+
+static void client_add_ip(union tcp_addr *client, const char *ip)
+{
+ int err, family = TEST_FAMILY;
+
+ if (inet_pton(family, ip, client) != 1)
+ test_error("Can't convert ip address %s", ip);
+
+ err = ip_addr_add(veth_name, family, *client, TEST_PREFIX);
+ if (err)
+ test_error("Failed to add ip address: %d", err);
+}
+
+static void client_add_ips(void)
+{
+ client_add_ip(&client2, __TEST_CLIENT_IP(2));
+ client_add_ip(&client3, __TEST_CLIENT_IP(3));
+ synchronize_threads(); /* server_add_routes() */
+}
+
+static void client_add_fail_tests(unsigned int *port)
+{
+ try_to_add("TCP-AO established: add TCP-MD5 key",
+ (*port)++, POSTINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0,
+ 100, 100, 1, FAULT_POSTINSTALL);
+ try_to_add("TCP-MD5 established: add TCP-AO key",
+ (*port)++, PREINSTALL_MD5 | POSTINSTALL_AO,
+ this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0,
+ 100, 100, 1, FAULT_POSTINSTALL);
+ try_to_add("non-signed established: add TCP-AO key",
+ (*port)++, POSTINSTALL_AO,
+ this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0,
+ 100, 100, 0, FAULT_POSTINSTALL);
+
+ try_to_add("TCP-AO key intersects with existing TCP-MD5 key",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1, this_ip_addr, TEST_PREFIX, -1,
+ 100, 100, 1, FAULT_PREINSTALL_AO);
+ try_to_add("TCP-MD5 key intersects with existing TCP-AO key",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1, this_ip_addr, TEST_PREFIX, -1,
+ 100, 100, 1, FAULT_PREINSTALL_MD5);
+
+ try_to_preadd("TCP-MD5 key + TCP-AO required",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, -1, true,
+ 100, 100, 1, 0, FAULT_PREINSTALL_AO);
+ try_to_preadd("TCP-AO required on socket + TCP-MD5 key",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, -1, true,
+ 100, 100, 1, 0, FAULT_PREINSTALL_MD5);
+}
+
+static void client_vrf_tests(unsigned int *port)
+{
+ setup_vrfs();
+
+ /* The following restrictions for setsockopt()s are expected:
+ *
+ * |--------------|-----------------|-------------|-------------|
+ * | | MD5 key without | MD5 key | MD5 key |
+ * | | l3index | l3index=0 | l3index=N |
+ * |--------------|-----------------|-------------|-------------|
+ * | TCP-AO key | | | |
+ * | without | reject | reject | reject |
+ * | l3index | | | |
+ * |--------------|-----------------|-------------|-------------|
+ * | TCP-AO key | | | |
+ * | l3index=0 | reject | reject | allow |
+ * |--------------|-----------------|-------------|-------------|
+ * | TCP-AO key | | | |
+ * | l3index=N | reject | allow | reject |
+ * |--------------|-----------------|-------------|-------------|
+ */
+ try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (no l3index)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (no l3index)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+ try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (l3index=0)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (no l3index)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+ try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (l3index=N)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (no l3index)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+
+ try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (no l3index)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (l3index=0)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+ try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (l3index=0)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (l3index=0)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+ try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (l3index=N)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, 0);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (l3index=0)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, 0, 0, 100, 100,
+ 1, 1, 0);
+
+ try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (no l3index)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, -1, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (l3index=N)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, -1,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+ try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (l3index=0)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100,
+ 1, 1, 0);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (l3index=N)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, 0,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100,
+ 1, 1, 0);
+ try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (l3index=N)",
+ (*port)++, PREINSTALL_MD5 | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_MD5);
+ try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (l3index=N)",
+ (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex,
+ this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100,
+ 1, 1, FAULT_PREINSTALL_AO);
+}
+
+static void *client_fn(void *arg)
+{
+ unsigned int port = test_server_port;
+ union tcp_addr addr_any = {};
+
+ client_add_ips();
+
+ try_connect("AO server (INADDR_ANY): AO client", port++, NULL, 0,
+ &addr_any, 0, 100, 100, 0, 0, 0, &this_ip_addr);
+ try_connect("AO server (INADDR_ANY): MD5 client", port++, &addr_any, 0,
+ NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr);
+ try_connect("AO server (INADDR_ANY): unsigned client", port++, NULL, 0,
+ NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &this_ip_addr);
+ try_connect("AO server (AO_REQUIRED): AO client", port++, NULL, 0,
+ &addr_any, 0, 100, 100, 0, 0, 0, &this_ip_addr);
+ try_connect("AO server (AO_REQUIRED): unsigned client", port++, NULL, 0,
+ NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &client2);
+
+ try_connect("MD5 server (INADDR_ANY): AO client", port++, NULL, 0,
+ &addr_any, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr);
+ try_connect("MD5 server (INADDR_ANY): MD5 client", port++, &addr_any, 0,
+ NULL, 0, 100, 100, 0, 0, 1, &this_ip_addr);
+ try_connect("MD5 server (INADDR_ANY): no sign client", port++, NULL, 0,
+ NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr);
+
+ try_connect("no sign server: AO client", port++, NULL, 0,
+ &addr_any, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &this_ip_addr);
+ try_connect("no sign server: MD5 client", port++, &addr_any, 0,
+ NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr);
+ try_connect("no sign server: no sign client", port++, NULL, 0,
+ NULL, 0, 100, 100, 0, 0, 0, &this_ip_addr);
+
+ try_connect("AO+MD5 server: AO client (matching)", port++, NULL, 0,
+ &addr_any, 0, 100, 100, 0, 0, 1, &client2);
+ try_connect("AO+MD5 server: AO client (misconfig, matching MD5)",
+ port++, NULL, 0, &addr_any, 0, 100, 100, 0,
+ FAULT_TIMEOUT, 1, &this_ip_addr);
+ try_connect("AO+MD5 server: AO client (misconfig, non-matching)",
+ port++, NULL, 0, &addr_any, 0, 100, 100, 0,
+ FAULT_TIMEOUT, 1, &client3);
+ try_connect("AO+MD5 server: MD5 client (matching)", port++, &addr_any, 0,
+ NULL, 0, 100, 100, 0, 0, 1, &this_ip_addr);
+ try_connect("AO+MD5 server: MD5 client (misconfig, matching AO)",
+ port++, &addr_any, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT,
+ 1, &client2);
+ try_connect("AO+MD5 server: MD5 client (misconfig, non-matching)",
+ port++, &addr_any, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT,
+ 1, &client3);
+ try_connect("AO+MD5 server: no sign client (unmatched)",
+ port++, NULL, 0, NULL, 0, 100, 100, 0, 0, 1, &client3);
+ try_connect("AO+MD5 server: no sign client (misconfig, matching AO)",
+ port++, NULL, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT,
+ 1, &client2);
+ try_connect("AO+MD5 server: no sign client (misconfig, matching MD5)",
+ port++, NULL, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT,
+ 1, &this_ip_addr);
+
+ try_connect("AO+MD5 server: client with both [TCP-MD5] and TCP-AO keys",
+ port++, &this_ip_addr, TEST_PREFIX,
+ &client2, TEST_PREFIX, 100, 100, 0, FAULT_KEYREJECT,
+ 1, &this_ip_addr);
+ try_connect("AO+MD5 server: client with both TCP-MD5 and [TCP-AO] keys",
+ port++, &this_ip_addr, TEST_PREFIX,
+ &client2, TEST_PREFIX, 100, 100, 0, FAULT_KEYREJECT,
+ 1, &client2);
+
+ client_add_fail_tests(&port);
+ client_vrf_tests(&port);
+
+ return NULL;
+}
+
+int main(int argc, char *argv[])
+{
+ test_init(72, server_fn, client_fn);
+ return 0;
+}