diff options
Diffstat (limited to 'tools/testing/selftests/net/tcp_ao')
26 files changed, 9580 insertions, 0 deletions
diff --git a/tools/testing/selftests/net/tcp_ao/.gitignore b/tools/testing/selftests/net/tcp_ao/.gitignore new file mode 100644 index 000000000000..e8bb81b715b7 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/.gitignore @@ -0,0 +1,2 @@ +*_ipv4 +*_ipv6 diff --git a/tools/testing/selftests/net/tcp_ao/Makefile b/tools/testing/selftests/net/tcp_ao/Makefile new file mode 100644 index 000000000000..5b0205c70c39 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/Makefile @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: GPL-2.0 +TEST_BOTH_AF := bench-lookups +TEST_BOTH_AF += connect +TEST_BOTH_AF += connect-deny +TEST_BOTH_AF += icmps-accept icmps-discard +TEST_BOTH_AF += key-management +TEST_BOTH_AF += restore +TEST_BOTH_AF += rst +TEST_BOTH_AF += self-connect +TEST_BOTH_AF += seq-ext +TEST_BOTH_AF += setsockopt-closed +TEST_BOTH_AF += unsigned-md5 + +TEST_IPV4_PROGS := $(TEST_BOTH_AF:%=%_ipv4) +TEST_IPV6_PROGS := $(TEST_BOTH_AF:%=%_ipv6) + +TEST_GEN_PROGS := $(TEST_IPV4_PROGS) $(TEST_IPV6_PROGS) + +top_srcdir := ../../../../.. +include ../../lib.mk + +HOSTAR ?= ar + +LIBDIR := $(OUTPUT)/lib +LIB := $(LIBDIR)/libaotst.a +LDLIBS += $(LIB) -pthread +LIBDEPS := lib/aolib.h Makefile + +CFLAGS += -Wall -O2 -g -fno-strict-aliasing +CFLAGS += $(KHDR_INCLUDES) +CFLAGS += -iquote ./lib/ -I ../../../../include/ + +# Library +LIBSRC := ftrace.c ftrace-tcp.c kconfig.c netlink.c +LIBSRC += proc.c repair.c setup.c sock.c utils.c +LIBOBJ := $(LIBSRC:%.c=$(LIBDIR)/%.o) +EXTRA_CLEAN += $(LIBOBJ) $(LIB) + +$(LIB): $(LIBOBJ) + $(HOSTAR) rcs $@ $^ + +$(LIBDIR)/%.o: ./lib/%.c $(LIBDEPS) + mkdir -p $(LIBDIR) + $(CC) $< $(CFLAGS) $(CPPFLAGS) -o $@ -c + +$(TEST_GEN_PROGS): $(LIB) + +$(OUTPUT)/%_ipv4: %.c + $(LINK.c) $^ $(LDLIBS) -o $@ + +$(OUTPUT)/%_ipv6: %.c + $(LINK.c) -DIPV6_TEST $^ $(LDLIBS) -o $@ + +$(OUTPUT)/icmps-accept_ipv4: CFLAGS+= -DTEST_ICMPS_ACCEPT +$(OUTPUT)/icmps-accept_ipv6: CFLAGS+= -DTEST_ICMPS_ACCEPT +$(OUTPUT)/bench-lookups_ipv4: LDLIBS+= -lm +$(OUTPUT)/bench-lookups_ipv6: LDLIBS+= -lm diff --git a/tools/testing/selftests/net/tcp_ao/bench-lookups.c b/tools/testing/selftests/net/tcp_ao/bench-lookups.c new file mode 100644 index 000000000000..6736484996a3 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/bench-lookups.c @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <arpa/inet.h> +#include <inttypes.h> +#include <math.h> +#include <stdlib.h> +#include <stdio.h> +#include <time.h> + +#include "../../../../include/linux/bits.h" +#include "../../../../include/linux/kernel.h" +#include "aolib.h" + +#define BENCH_NR_ITERS 100 /* number of times to run gathering statistics */ + +static void gen_test_ips(union tcp_addr *ips, size_t ips_nr, bool use_rand) +{ + union tcp_addr net = {}; + size_t i, j; + + if (inet_pton(TEST_FAMILY, TEST_NETWORK, &net) != 1) + test_error("Can't convert ip address %s", TEST_NETWORK); + + if (!use_rand) { + for (i = 0; i < ips_nr; i++) + ips[i] = gen_tcp_addr(net, 2 * i + 1); + return; + } + for (i = 0; i < ips_nr; i++) { + size_t r = (size_t)random() | 0x1; + + ips[i] = gen_tcp_addr(net, r); + + for (j = i - 1; j > 0 && i > 0; j--) { + if (!memcmp(&ips[i], &ips[j], sizeof(union tcp_addr))) { + i--; /* collision */ + break; + } + } + } +} + +static void test_add_routes(union tcp_addr *ips, size_t ips_nr) +{ + size_t i; + + for (i = 0; i < ips_nr; i++) { + union tcp_addr *p = (union tcp_addr *)&ips[i]; + int err; + + err = ip_route_add(veth_name, TEST_FAMILY, this_ip_addr, *p); + if (err && err != -EEXIST) + test_error("Failed to add route"); + } +} + +static void server_apply_keys(int lsk, union tcp_addr *ips, size_t ips_nr) +{ + size_t i; + + for (i = 0; i < ips_nr; i++) { + union tcp_addr *p = (union tcp_addr *)&ips[i]; + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, *p, -1, 100, 100)) + test_error("setsockopt(TCP_AO)"); + } +} + +static const size_t nr_keys[] = { 512, 1024, 2048, 4096, 8192 }; +static union tcp_addr *test_ips; + +struct bench_stats { + uint64_t min; + uint64_t max; + uint64_t nr; + double mean; + double s2; +}; + +static struct bench_tests { + struct bench_stats delete_last_key; + struct bench_stats add_key; + struct bench_stats delete_rand_key; + struct bench_stats connect_last_key; + struct bench_stats connect_rand_key; + struct bench_stats delete_async; +} bench_results[ARRAY_SIZE(nr_keys)]; + +#define NSEC_PER_SEC 1000000000ULL + +static void measure_call(struct bench_stats *st, + void (*f)(int, void *), int sk, void *arg) +{ + struct timespec start = {}, end = {}; + double delta; + uint64_t nsec; + + if (clock_gettime(CLOCK_MONOTONIC, &start)) + test_error("clock_gettime()"); + + f(sk, arg); + + if (clock_gettime(CLOCK_MONOTONIC, &end)) + test_error("clock_gettime()"); + + nsec = (end.tv_sec - start.tv_sec) * NSEC_PER_SEC; + if (end.tv_nsec >= start.tv_nsec) + nsec += end.tv_nsec - start.tv_nsec; + else + nsec -= start.tv_nsec - end.tv_nsec; + + if (st->nr == 0) { + st->min = st->max = nsec; + } else { + if (st->min > nsec) + st->min = nsec; + if (st->max < nsec) + st->max = nsec; + } + + /* Welford-Knuth algorithm */ + st->nr++; + delta = (double)nsec - st->mean; + st->mean += delta / st->nr; + st->s2 += delta * ((double)nsec - st->mean); +} + +static void delete_mkt(int sk, void *arg) +{ + struct tcp_ao_del *ao = arg; + + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, ao, sizeof(*ao))) + test_error("setsockopt(TCP_AO_DEL_KEY)"); +} + +static void add_back_mkt(int sk, void *arg) +{ + union tcp_addr *p = arg; + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, *p, -1, 100, 100)) + test_error("setsockopt(TCP_AO)"); +} + +static void bench_delete(int lsk, struct bench_stats *add, + struct bench_stats *del, + union tcp_addr *ips, size_t ips_nr, + bool rand_order, bool async) +{ + struct tcp_ao_del ao_del = {}; + union tcp_addr *p; + size_t i; + + ao_del.sndid = 100; + ao_del.rcvid = 100; + ao_del.del_async = !!async; + ao_del.prefix = DEFAULT_TEST_PREFIX; + + /* Remove the first added */ + p = (union tcp_addr *)&ips[0]; + tcp_addr_to_sockaddr_in(&ao_del.addr, p, 0); + + for (i = 0; i < BENCH_NR_ITERS; i++) { + measure_call(del, delete_mkt, lsk, (void *)&ao_del); + + /* Restore it back */ + measure_call(add, add_back_mkt, lsk, (void *)p); + + /* + * Slowest for FILO-linked-list: + * on (i) iteration removing ips[i] element. When it gets + * added to the list back - it becomes first to fetch, so + * on (i + 1) iteration go to ips[i + 1] element. + */ + if (rand_order) + p = (union tcp_addr *)&ips[rand() % ips_nr]; + else + p = (union tcp_addr *)&ips[i % ips_nr]; + tcp_addr_to_sockaddr_in(&ao_del.addr, p, 0); + } +} + +static void bench_connect_srv(int lsk, union tcp_addr *ips, size_t ips_nr) +{ + size_t i; + + for (i = 0; i < BENCH_NR_ITERS; i++) { + int sk; + + synchronize_threads(); + + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + close(sk); + } +} + +static void test_print_stats(const char *desc, size_t nr, struct bench_stats *bs) +{ + test_ok("%-20s\t%zu keys: min=%" PRIu64 "ms max=%" PRIu64 "ms mean=%gms stddev=%g", + desc, nr, bs->min / 1000000, bs->max / 1000000, + bs->mean / 1000000, sqrt((bs->mean / 1000000) / bs->nr)); +} + +static void *server_fn(void *arg) +{ + size_t i; + + for (i = 0; i < ARRAY_SIZE(nr_keys); i++) { + struct bench_tests *bt = &bench_results[i]; + int lsk; + + test_ips = malloc(nr_keys[i] * sizeof(union tcp_addr)); + if (!test_ips) + test_error("malloc()"); + + lsk = test_listen_socket(this_ip_addr, test_server_port + i, 1); + + gen_test_ips(test_ips, nr_keys[i], false); + test_add_routes(test_ips, nr_keys[i]); + test_set_optmem(KERNEL_TCP_AO_KEY_SZ_ROUND_UP * nr_keys[i]); + server_apply_keys(lsk, test_ips, nr_keys[i]); + + synchronize_threads(); + bench_connect_srv(lsk, test_ips, nr_keys[i]); + bench_connect_srv(lsk, test_ips, nr_keys[i]); + + /* The worst case for FILO-list */ + bench_delete(lsk, &bt->add_key, &bt->delete_last_key, + test_ips, nr_keys[i], false, false); + test_print_stats("Add a new key", + nr_keys[i], &bt->add_key); + test_print_stats("Delete: worst case", + nr_keys[i], &bt->delete_last_key); + + bench_delete(lsk, &bt->add_key, &bt->delete_rand_key, + test_ips, nr_keys[i], true, false); + test_print_stats("Delete: random-search", + nr_keys[i], &bt->delete_rand_key); + + bench_delete(lsk, &bt->add_key, &bt->delete_async, + test_ips, nr_keys[i], false, true); + test_print_stats("Delete: async", nr_keys[i], &bt->delete_async); + + free(test_ips); + close(lsk); + } + + return NULL; +} + +static void connect_client(int sk, void *arg) +{ + size_t *p = arg; + + if (test_connect_socket(sk, this_ip_dest, test_server_port + *p) <= 0) + test_error("failed to connect()"); +} + +static void client_addr_setup(int sk, union tcp_addr taddr) +{ +#ifdef IPV6_TEST + struct sockaddr_in6 addr = { + .sin6_family = AF_INET6, + .sin6_port = 0, + .sin6_addr = taddr.a6, + }; +#else + struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_port = 0, + .sin_addr = taddr.a4, + }; +#endif + int ret; + + ret = ip_addr_add(veth_name, TEST_FAMILY, taddr, TEST_PREFIX); + if (ret && ret != -EEXIST) + test_error("Failed to add ip address"); + ret = ip_route_add(veth_name, TEST_FAMILY, taddr, this_ip_dest); + if (ret && ret != -EEXIST) + test_error("Failed to add route"); + + if (bind(sk, &addr, sizeof(addr))) + test_error("bind()"); +} + +static void bench_connect_client(size_t port_off, struct bench_tests *bt, + union tcp_addr *ips, size_t ips_nr, bool rand_order) +{ + struct bench_stats *con; + union tcp_addr *p; + size_t i; + + if (rand_order) + con = &bt->connect_rand_key; + else + con = &bt->connect_last_key; + + p = (union tcp_addr *)&ips[0]; + + for (i = 0; i < BENCH_NR_ITERS; i++) { + int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + + if (sk < 0) + test_error("socket()"); + + client_addr_setup(sk, *p); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, + -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); + + measure_call(con, connect_client, sk, (void *)&port_off); + + close(sk); + + /* + * Slowest for FILO-linked-list: + * on (i) iteration removing ips[i] element. When it gets + * added to the list back - it becomes first to fetch, so + * on (i + 1) iteration go to ips[i + 1] element. + */ + if (rand_order) + p = (union tcp_addr *)&ips[rand() % ips_nr]; + else + p = (union tcp_addr *)&ips[i % ips_nr]; + } +} + +static void *client_fn(void *arg) +{ + size_t i; + + for (i = 0; i < ARRAY_SIZE(nr_keys); i++) { + struct bench_tests *bt = &bench_results[i]; + + synchronize_threads(); + bench_connect_client(i, bt, test_ips, nr_keys[i], false); + test_print_stats("Connect: worst case", + nr_keys[i], &bt->connect_last_key); + + bench_connect_client(i, bt, test_ips, nr_keys[i], false); + test_print_stats("Connect: random-search", + nr_keys[i], &bt->connect_last_key); + } + synchronize_threads(); + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(31, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/config b/tools/testing/selftests/net/tcp_ao/config new file mode 100644 index 000000000000..3605e38711cb --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/config @@ -0,0 +1,11 @@ +CONFIG_CRYPTO_HMAC=y +CONFIG_CRYPTO_RMD160=y +CONFIG_CRYPTO_SHA1=y +CONFIG_IPV6_MULTIPLE_TABLES=y +CONFIG_IPV6=y +CONFIG_NET_L3_MASTER_DEV=y +CONFIG_NET_VRF=y +CONFIG_TCP_AO=y +CONFIG_TCP_MD5SIG=y +CONFIG_TRACEPOINTS=y +CONFIG_VETH=m diff --git a/tools/testing/selftests/net/tcp_ao/connect-deny.c b/tools/testing/selftests/net/tcp_ao/connect-deny.c new file mode 100644 index 000000000000..93b61e9a36f1 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/connect-deny.c @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "aolib.h" + +#define fault(type) (inj == FAULT_ ## type) +static volatile int sk_pair; + +static inline int test_add_key_maclen(int sk, const char *key, uint8_t maclen, + union tcp_addr in_addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + if (prefix > DEFAULT_TEST_PREFIX) + prefix = DEFAULT_TEST_PREFIX; + + err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, false, false, + prefix, 0, sndid, rcvid, maclen, + 0, strlen(key), key); + if (err) + return err; + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); + if (err < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +static void try_accept(const char *tst_name, unsigned int port, const char *pwd, + union tcp_addr addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid, uint8_t maclen, + const char *cnt_name, test_cnt cnt_expected, + fault_t inj) +{ + struct tcp_counters cnt1, cnt2; + uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */ + test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected; + int lsk, err, sk = 0; + + lsk = test_listen_socket(this_ip_addr, port, 1); + + if (pwd && test_add_key_maclen(lsk, pwd, maclen, addr, prefix, sndid, rcvid)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + if (cnt_name) + before_cnt = netstat_get_one(cnt_name, NULL); + if (pwd && test_get_tcp_counters(lsk, &cnt1)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* preparations done */ + + err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair); + if (err == -ETIMEDOUT) { + sk_pair = err; + if (!fault(TIMEOUT)) + test_fail("%s: timed out for accept()", tst_name); + } else if (err == -EKEYREJECTED) { + if (!fault(KEYREJECT)) + test_fail("%s: key was rejected", tst_name); + } else if (err < 0) { + test_error("test_skpair_wait_poll()"); + } else { + if (fault(TIMEOUT)) + test_fail("%s: ready to accept", tst_name); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) { + test_error("accept()"); + } else { + if (fault(TIMEOUT)) + test_fail("%s: accepted", tst_name); + } + } + + synchronize_threads(); /* before counter checks */ + if (pwd && test_get_tcp_counters(lsk, &cnt2)) + test_error("test_get_tcp_counters()"); + + close(lsk); + + if (pwd) + test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); + + if (!cnt_name) + goto out; + + after_cnt = netstat_get_one(cnt_name, NULL); + + if (after_cnt <= before_cnt) { + test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64, + tst_name, cnt_name, after_cnt, before_cnt); + } else { + test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64, + tst_name, cnt_name, before_cnt, after_cnt); + } + +out: + synchronize_threads(); /* close() */ + if (sk > 0) + close(sk); +} + +static void *server_fn(void *arg) +{ + union tcp_addr wrong_addr, network_addr; + unsigned int port = test_server_port; + + if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1) + test_error("Can't convert ip address %s", TEST_WRONG_IP); + + try_accept("Non-AO server + AO client", port++, NULL, + this_ip_dest, -1, 100, 100, 0, + "TCPAOKeyNotFound", TEST_CNT_NS_KEY_NOT_FOUND, FAULT_TIMEOUT); + + try_accept("AO server + Non-AO client", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, + "TCPAORequired", TEST_CNT_AO_REQUIRED, FAULT_TIMEOUT); + + try_accept("Wrong password", port++, "something that is not DEFAULT_TEST_PASSWORD", + this_ip_dest, -1, 100, 100, 0, + "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT); + + try_accept("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 101, 0, + "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT); + + try_accept("Wrong snd id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 101, 100, 0, + "TCPAOGood", TEST_CNT_GOOD, FAULT_TIMEOUT); + + try_accept("Different maclen", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 8, + "TCPAOBad", TEST_CNT_BAD, FAULT_TIMEOUT); + + try_accept("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD, + wrong_addr, -1, 100, 100, 0, + "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, FAULT_TIMEOUT); + + /* Key rejected by the other side, failing short through skpair */ + try_accept("Client: Wrong addr", port++, NULL, + this_ip_dest, -1, 100, 100, 0, NULL, 0, FAULT_KEYREJECT); + + try_accept("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 200, 100, 0, + "TCPAOGood", TEST_CNT_GOOD, 0); + + if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1) + test_error("Can't convert ip address %s", TEST_NETWORK); + + try_accept("Server: prefix match", port++, DEFAULT_TEST_PASSWORD, + network_addr, 16, 100, 100, 0, + "TCPAOGood", TEST_CNT_GOOD, 0); + + try_accept("Client: prefix match", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, + "TCPAOGood", TEST_CNT_GOOD, 0); + + /* client exits */ + synchronize_threads(); + return NULL; +} + +static void try_connect(const char *tst_name, unsigned int port, + const char *pwd, union tcp_addr addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid, + test_cnt cnt_expected, fault_t inj) +{ + struct tcp_counters cnt1, cnt2; + int sk, ret; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (pwd && test_add_key(sk, pwd, addr, prefix, sndid, rcvid)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + if (pwd && test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* preparations done */ + + ret = test_skpair_connect_poll(sk, this_ip_dest, port, cnt_expected, &sk_pair); + synchronize_threads(); /* before counter checks */ + if (ret < 0) { + sk_pair = ret; + if (fault(KEYREJECT) && ret == -EKEYREJECTED) { + test_ok("%s: connect() was prevented", tst_name); + } else if (ret == -ETIMEDOUT && fault(TIMEOUT)) { + test_ok("%s", tst_name); + } else if (ret == -ECONNREFUSED && + (fault(TIMEOUT) || fault(KEYREJECT))) { + test_ok("%s: refused to connect", tst_name); + } else { + test_error("%s: connect() returned %d", tst_name, ret); + } + goto out; + } + + if (fault(TIMEOUT) || fault(KEYREJECT)) + test_fail("%s: connected", tst_name); + else + test_ok("%s: connected", tst_name); + if (pwd && ret > 0) { + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); + } else if (pwd) { + test_tcp_counters_free(&cnt1); + } +out: + synchronize_threads(); /* close() */ + + if (ret > 0) + close(sk); +} + +static void *client_fn(void *arg) +{ + union tcp_addr wrong_addr, network_addr, addr_any = {}; + unsigned int port = test_server_port; + + if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1) + test_error("Can't convert ip address %s", TEST_WRONG_IP); + + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("Non-AO server + AO client", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO server + Non-AO client", port++, NULL, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("Wrong password", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("Wrong rcv id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + /* + * XXX: The test doesn't increase any counters, see tcp_make_synack(). + * Potentially, it can be speed up by setting sk_pair = -ETIMEDOUT + * but the price would be increased complexity of the tracer thread. + */ + trace_ao_event_sk_expect(TCP_AO_SYNACK_NO_KEY, this_ip_dest, addr_any, + port, 0, 100, 100); + try_connect("Wrong snd id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + trace_ao_event_expect(TCP_AO_WRONG_MACLEN, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("Different maclen", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("Server: Wrong addr", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, 0, FAULT_TIMEOUT); + + try_connect("Client: Wrong addr", port++, DEFAULT_TEST_PASSWORD, + wrong_addr, -1, 100, 100, 0, FAULT_KEYREJECT); + + try_connect("rcv id != snd id", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 200, TEST_CNT_GOOD, 0); + + if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_addr) != 1) + test_error("Can't convert ip address %s", TEST_NETWORK); + + try_connect("Server: prefix match", port++, DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100, TEST_CNT_GOOD, 0); + + try_connect("Client: prefix match", port++, DEFAULT_TEST_PASSWORD, + network_addr, 16, 100, 100, TEST_CNT_GOOD, 0); + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(22, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/connect.c b/tools/testing/selftests/net/tcp_ao/connect.c new file mode 100644 index 000000000000..340f00e979ea --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/connect.c @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "aolib.h" + +static void *server_fn(void *arg) +{ + int sk, lsk; + ssize_t bytes; + + lsk = test_listen_socket(this_ip_addr, test_server_port, 1); + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + synchronize_threads(); + + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + synchronize_threads(); + + bytes = test_server_run(sk, 0, 0); + + test_fail("server served: %zd", bytes); + return NULL; +} + +static void *client_fn(void *arg) +{ + int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + uint64_t before_aogood, after_aogood; + const size_t nr_packets = 20; + struct netstat *ns_before, *ns_after; + struct tcp_counters ao1, ao2; + + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); + if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0) + test_error("failed to connect()"); + synchronize_threads(); + + ns_before = netstat_read(); + before_aogood = netstat_get(ns_before, "TCPAOGood", NULL); + if (test_get_tcp_counters(sk, &ao1)) + test_error("test_get_tcp_counters()"); + + if (test_client_verify(sk, 100, nr_packets)) { + test_fail("verify failed"); + return NULL; + } + + ns_after = netstat_read(); + after_aogood = netstat_get(ns_after, "TCPAOGood", NULL); + if (test_get_tcp_counters(sk, &ao2)) + test_error("test_get_tcp_counters()"); + netstat_print_diff(ns_before, ns_after); + netstat_free(ns_before); + netstat_free(ns_after); + + if (nr_packets > (after_aogood - before_aogood)) { + test_fail("TCPAOGood counter mismatch: %zu > (%" PRIu64 " - %" PRIu64 ")", + nr_packets, after_aogood, before_aogood); + return NULL; + } + if (test_assert_counters("connect", &ao1, &ao2, TEST_CNT_GOOD)) + return NULL; + + test_ok("connect TCPAOGood %" PRIu64 "/%" PRIu64 "/%" PRIu64 " => %" PRIu64 "/%" PRIu64 "/%" PRIu64 ", sent %zu", + before_aogood, ao1.ao.ao_info_pkt_good, + ao1.ao.key_cnts[0].pkt_good, + after_aogood, ao2.ao.ao_info_pkt_good, + ao2.ao.key_cnts[0].pkt_good, + nr_packets); + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(2, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/icmps-accept.c b/tools/testing/selftests/net/tcp_ao/icmps-accept.c new file mode 120000 index 000000000000..0a5bb85eb260 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/icmps-accept.c @@ -0,0 +1 @@ +icmps-discard.c
\ No newline at end of file diff --git a/tools/testing/selftests/net/tcp_ao/icmps-discard.c b/tools/testing/selftests/net/tcp_ao/icmps-discard.c new file mode 100644 index 000000000000..85c1a1e958c6 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/icmps-discard.c @@ -0,0 +1,448 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Selftest that verifies that incomping ICMPs are ignored, + * the TCP connection stays alive, no hard or soft errors get reported + * to the usespace and the counter for ignored ICMPs is updated. + * + * RFC5925, 7.8: + * >> A TCP-AO implementation MUST default to ignore incoming ICMPv4 + * messages of Type 3 (destination unreachable), Codes 2-4 (protocol + * unreachable, port unreachable, and fragmentation needed -- ’hard + * errors’), and ICMPv6 Type 1 (destination unreachable), Code 1 + * (administratively prohibited) and Code 4 (port unreachable) intended + * for connections in synchronized states (ESTABLISHED, FIN-WAIT-1, FIN- + * WAIT-2, CLOSE-WAIT, CLOSING, LAST-ACK, TIME-WAIT) that match MKTs. + * + * Author: Dmitry Safonov <dima@arista.com> + */ +#include <inttypes.h> +#include <linux/icmp.h> +#include <linux/icmpv6.h> +#include <linux/ipv6.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <sys/socket.h> +#include "aolib.h" +#include "../../../../include/linux/compiler.h" + +const size_t packets_nr = 20; +const size_t packet_size = 100; +const char *tcpao_icmps = "TCPAODroppedIcmps"; + +#ifdef IPV6_TEST +const char *dst_unreach = "Icmp6InDestUnreachs"; +const int sk_ip_level = SOL_IPV6; +const int sk_recverr = IPV6_RECVERR; +#else +const char *dst_unreach = "InDestUnreachs"; +const int sk_ip_level = SOL_IP; +const int sk_recverr = IP_RECVERR; +#endif + +/* Server is expected to fail with hard error if ::accept_icmp is set */ +#ifdef TEST_ICMPS_ACCEPT +# define test_icmps_fail test_ok +# define test_icmps_ok test_fail +#else +# define test_icmps_fail test_fail +# define test_icmps_ok test_ok +#endif + +static void serve_interfered(int sk) +{ + ssize_t test_quota = packet_size * packets_nr * 10; + uint64_t dest_unreach_a, dest_unreach_b; + uint64_t icmp_ignored_a, icmp_ignored_b; + struct tcp_counters cnt1, cnt2; + bool counter_not_found; + struct netstat *ns_after, *ns_before; + ssize_t bytes; + + ns_before = netstat_read(); + dest_unreach_a = netstat_get(ns_before, dst_unreach, NULL); + icmp_ignored_a = netstat_get(ns_before, tcpao_icmps, NULL); + if (test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + bytes = test_server_run(sk, test_quota, 0); + ns_after = netstat_read(); + netstat_print_diff(ns_before, ns_after); + dest_unreach_b = netstat_get(ns_after, dst_unreach, NULL); + icmp_ignored_b = netstat_get(ns_after, tcpao_icmps, + &counter_not_found); + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + + netstat_free(ns_before); + netstat_free(ns_after); + + if (dest_unreach_a >= dest_unreach_b) { + test_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64, + dst_unreach, dest_unreach_a, dest_unreach_b); + return; + } + test_ok("%s delivered %" PRIu64, + dst_unreach, dest_unreach_b - dest_unreach_a); + if (bytes < 0) + test_icmps_fail("Server failed with %zd: %s", bytes, strerrordesc_np(-bytes)); + else + test_icmps_ok("Server survived %zd bytes of traffic", test_quota); + if (counter_not_found) { + test_fail("Not found %s counter", tcpao_icmps); + return; + } +#ifdef TEST_ICMPS_ACCEPT + test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD); +#else + test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD | TEST_CNT_AO_DROPPED_ICMP); +#endif + if (icmp_ignored_a >= icmp_ignored_b) { + test_icmps_fail("%s counter didn't change: %" PRIu64 " >= %" PRIu64, + tcpao_icmps, icmp_ignored_a, icmp_ignored_b); + return; + } + test_icmps_ok("ICMPs ignored %" PRIu64, icmp_ignored_b - icmp_ignored_a); +} + +static void *server_fn(void *arg) +{ + int val, sk, lsk; + bool accept_icmps = false; + + lsk = test_listen_socket(this_ip_addr, test_server_port, 1); + +#ifdef TEST_ICMPS_ACCEPT + accept_icmps = true; +#endif + + if (test_set_ao_flags(lsk, false, accept_icmps)) + test_error("setsockopt(TCP_AO_INFO)"); + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + synchronize_threads(); + + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + /* Fail on hard ip errors, such as dest unreachable (RFC1122) */ + val = 1; + if (setsockopt(sk, sk_ip_level, sk_recverr, &val, sizeof(val))) + test_error("setsockopt()"); + + synchronize_threads(); + + serve_interfered(sk); + return NULL; +} + +static size_t packets_sent; +static size_t icmps_sent; + +static uint32_t checksum4_nofold(void *data, size_t len, uint32_t sum) +{ + uint16_t *words = data; + size_t i; + + for (i = 0; i < len / sizeof(uint16_t); i++) + sum += words[i]; + if (len & 1) + sum += ((char *)data)[len - 1]; + return sum; +} + +static uint16_t checksum4_fold(void *data, size_t len, uint32_t sum) +{ + sum = checksum4_nofold(data, len, sum); + while (sum > 0xFFFF) + sum = (sum & 0xFFFF) + (sum >> 16); + return ~sum; +} + +static void set_ip4hdr(struct iphdr *iph, size_t packet_len, int proto, + struct sockaddr_in *src, struct sockaddr_in *dst) +{ + iph->version = 4; + iph->ihl = 5; + iph->tos = 0; + iph->tot_len = htons(packet_len); + iph->ttl = 2; + iph->protocol = proto; + iph->saddr = src->sin_addr.s_addr; + iph->daddr = dst->sin_addr.s_addr; + iph->check = checksum4_fold((void *)iph, iph->ihl << 1, 0); +} + +static void icmp_interfere4(uint8_t type, uint8_t code, uint32_t rcv_nxt, + struct sockaddr_in *src, struct sockaddr_in *dst) +{ + int sk = socket(AF_INET, SOCK_RAW, IPPROTO_RAW); + struct { + struct iphdr iph; + struct icmphdr icmph; + struct iphdr iphe; + struct { + uint16_t sport; + uint16_t dport; + uint32_t seq; + } tcph; + } packet = {}; + size_t packet_len; + ssize_t bytes; + + if (sk < 0) + test_error("socket(AF_INET, SOCK_RAW, IPPROTO_RAW)"); + + packet_len = sizeof(packet); + set_ip4hdr(&packet.iph, packet_len, IPPROTO_ICMP, src, dst); + + packet.icmph.type = type; + packet.icmph.code = code; + if (code == ICMP_FRAG_NEEDED) { + randomize_buffer(&packet.icmph.un.frag.mtu, + sizeof(packet.icmph.un.frag.mtu)); + } + + packet_len = sizeof(packet.iphe) + sizeof(packet.tcph); + set_ip4hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src); + + packet.tcph.sport = dst->sin_port; + packet.tcph.dport = src->sin_port; + packet.tcph.seq = htonl(rcv_nxt); + + packet_len = sizeof(packet) - sizeof(packet.iph); + packet.icmph.checksum = checksum4_fold((void *)&packet.icmph, + packet_len, 0); + + bytes = sendto(sk, &packet, sizeof(packet), 0, + (struct sockaddr *)dst, sizeof(*dst)); + if (bytes != sizeof(packet)) + test_error("send(): %zd", bytes); + icmps_sent++; + + close(sk); +} + +static void set_ip6hdr(struct ipv6hdr *iph, size_t packet_len, int proto, + struct sockaddr_in6 *src, struct sockaddr_in6 *dst) +{ + iph->version = 6; + iph->payload_len = htons(packet_len); + iph->nexthdr = proto; + iph->hop_limit = 2; + iph->saddr = src->sin6_addr; + iph->daddr = dst->sin6_addr; +} + +static inline uint16_t csum_fold(uint32_t csum) +{ + uint32_t sum = csum; + + sum = (sum & 0xffff) + (sum >> 16); + sum = (sum & 0xffff) + (sum >> 16); + return (uint16_t)~sum; +} + +static inline uint32_t csum_add(uint32_t csum, uint32_t addend) +{ + uint32_t res = csum; + + res += addend; + return res + (res < addend); +} + +noinline uint32_t checksum6_nofold(void *data, size_t len, uint32_t sum) +{ + uint16_t *words = data; + size_t i; + + for (i = 0; i < len / sizeof(uint16_t); i++) + sum = csum_add(sum, words[i]); + if (len & 1) + sum = csum_add(sum, ((char *)data)[len - 1]); + return sum; +} + +noinline uint16_t icmp6_checksum(struct sockaddr_in6 *src, + struct sockaddr_in6 *dst, + void *ptr, size_t len, uint8_t proto) +{ + struct { + struct in6_addr saddr; + struct in6_addr daddr; + uint32_t payload_len; + uint8_t zero[3]; + uint8_t nexthdr; + } pseudo_header = {}; + uint32_t sum; + + pseudo_header.saddr = src->sin6_addr; + pseudo_header.daddr = dst->sin6_addr; + pseudo_header.payload_len = htonl(len); + pseudo_header.nexthdr = proto; + + sum = checksum6_nofold(&pseudo_header, sizeof(pseudo_header), 0); + sum = checksum6_nofold(ptr, len, sum); + + return csum_fold(sum); +} + +static void icmp6_interfere(int type, int code, uint32_t rcv_nxt, + struct sockaddr_in6 *src, struct sockaddr_in6 *dst) +{ + int sk = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW); + struct sockaddr_in6 dst_raw = *dst; + struct { + struct ipv6hdr iph; + struct icmp6hdr icmph; + struct ipv6hdr iphe; + struct { + uint16_t sport; + uint16_t dport; + uint32_t seq; + } tcph; + } packet = {}; + size_t packet_len; + ssize_t bytes; + + + if (sk < 0) + test_error("socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)"); + + packet_len = sizeof(packet) - sizeof(packet.iph); + set_ip6hdr(&packet.iph, packet_len, IPPROTO_ICMPV6, src, dst); + + packet.icmph.icmp6_type = type; + packet.icmph.icmp6_code = code; + + packet_len = sizeof(packet.iphe) + sizeof(packet.tcph); + set_ip6hdr(&packet.iphe, packet_len, IPPROTO_TCP, dst, src); + + packet.tcph.sport = dst->sin6_port; + packet.tcph.dport = src->sin6_port; + packet.tcph.seq = htonl(rcv_nxt); + + packet_len = sizeof(packet) - sizeof(packet.iph); + + packet.icmph.icmp6_cksum = icmp6_checksum(src, dst, + (void *)&packet.icmph, packet_len, IPPROTO_ICMPV6); + + dst_raw.sin6_port = htons(IPPROTO_RAW); + bytes = sendto(sk, &packet, sizeof(packet), 0, + (struct sockaddr *)&dst_raw, sizeof(dst_raw)); + if (bytes != sizeof(packet)) + test_error("send(): %zd", bytes); + icmps_sent++; + + close(sk); +} + +static uint32_t get_rcv_nxt(int sk) +{ + int val = TCP_REPAIR_ON; + uint32_t ret; + socklen_t sz = sizeof(ret); + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); + val = TCP_RECV_QUEUE; + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + if (getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &ret, &sz)) + test_error("getsockopt(TCP_QUEUE_SEQ)"); + val = TCP_REPAIR_OFF_NO_WP; + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); + return ret; +} + +static void icmp_interfere(const size_t nr, uint32_t rcv_nxt, void *src, void *dst) +{ + struct sockaddr_in *saddr4 = src; + struct sockaddr_in *daddr4 = dst; + struct sockaddr_in6 *saddr6 = src; + struct sockaddr_in6 *daddr6 = dst; + size_t i; + + if (saddr4->sin_family != daddr4->sin_family) + test_error("Different address families"); + + for (i = 0; i < nr; i++) { + if (saddr4->sin_family == AF_INET) { + icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PROT_UNREACH, + rcv_nxt, saddr4, daddr4); + icmp_interfere4(ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, + rcv_nxt, saddr4, daddr4); + icmp_interfere4(ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + rcv_nxt, saddr4, daddr4); + icmps_sent += 3; + } else if (saddr4->sin_family == AF_INET6) { + icmp6_interfere(ICMPV6_DEST_UNREACH, + ICMPV6_ADM_PROHIBITED, + rcv_nxt, saddr6, daddr6); + icmp6_interfere(ICMPV6_DEST_UNREACH, + ICMPV6_PORT_UNREACH, + rcv_nxt, saddr6, daddr6); + icmps_sent += 2; + } else { + test_error("Not ip address family"); + } + } +} + +static void send_interfered(int sk) +{ + struct sockaddr_in6 src, dst; + socklen_t addr_sz; + + addr_sz = sizeof(src); + if (getsockname(sk, &src, &addr_sz)) + test_error("getsockname()"); + addr_sz = sizeof(dst); + if (getpeername(sk, &dst, &addr_sz)) + test_error("getpeername()"); + + while (1) { + uint32_t rcv_nxt; + + if (test_client_verify(sk, packet_size, packets_nr)) { + test_fail("client: connection is broken"); + return; + } + packets_sent += packets_nr; + rcv_nxt = get_rcv_nxt(sk); + icmp_interfere(packets_nr, rcv_nxt, (void *)&src, (void *)&dst); + } +} + +static void *client_fn(void *arg) +{ + int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); + if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0) + test_error("failed to connect()"); + synchronize_threads(); + + send_interfered(sk); + + /* Not expecting client to quit */ + test_fail("client disconnected"); + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(4, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/key-management.c b/tools/testing/selftests/net/tcp_ao/key-management.c new file mode 100644 index 000000000000..69d9a7a05d5c --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/key-management.c @@ -0,0 +1,1198 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "../../../../include/linux/kernel.h" +#include "aolib.h" + +const size_t nr_packets = 20; +const size_t msg_len = 100; +const size_t quota = nr_packets * msg_len; +union tcp_addr wrong_addr; +#define SECOND_PASSWORD "at all times sincere friends of freedom have been rare" +#define fault(type) (inj == FAULT_ ## type) + +static const int test_vrf_ifindex = 200; +static const uint8_t test_vrf_tabid = 42; +static void setup_vrfs(void) +{ + int err; + + if (!kernel_config_has(KCONFIG_NET_VRF)) + return; + + err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1); + if (err) + test_error("Failed to add a VRF: %d", err); + + err = link_set_up("ksft-vrf"); + if (err) + test_error("Failed to bring up a VRF"); + + err = ip_route_add_vrf(veth_name, TEST_FAMILY, + this_ip_addr, this_ip_dest, test_vrf_tabid); + if (err) + test_error("Failed to add a route to VRF"); +} + + +static int prepare_sk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid) +{ + int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, + DEFAULT_TEST_PREFIX, 100, 100)) + test_error("test_add_key()"); + + if (addr && test_add_key(sk, SECOND_PASSWORD, *addr, + DEFAULT_TEST_PREFIX, sndid, rcvid)) + test_error("test_add_key()"); + + return sk; +} + +static int prepare_lsk(union tcp_addr *addr, uint8_t sndid, uint8_t rcvid) +{ + int sk = prepare_sk(addr, sndid, rcvid); + + if (listen(sk, 10)) + test_error("listen()"); + + return sk; +} + +static int test_del_key(int sk, uint8_t sndid, uint8_t rcvid, bool async, + int current_key, int rnext_key) +{ + struct tcp_ao_info_opt ao_info = {}; + struct tcp_ao_getsockopt key = {}; + struct tcp_ao_del del = {}; + sockaddr_af sockaddr; + int err; + + tcp_addr_to_sockaddr_in(&del.addr, &this_ip_dest, 0); + del.prefix = DEFAULT_TEST_PREFIX; + del.sndid = sndid; + del.rcvid = rcvid; + + if (current_key >= 0) { + del.set_current = 1; + del.current_key = (uint8_t)current_key; + } + if (rnext_key >= 0) { + del.set_rnext = 1; + del.rnext = (uint8_t)rnext_key; + } + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_DEL_KEY, &del, sizeof(del)); + if (err < 0) + return -errno; + + if (async) + return 0; + + tcp_addr_to_sockaddr_in(&sockaddr, &this_ip_dest, 0); + err = test_get_one_ao(sk, &key, &sockaddr, sizeof(sockaddr), + DEFAULT_TEST_PREFIX, sndid, rcvid); + if (!err) + return -EEXIST; + if (err != -E2BIG) + test_error("getsockopt()"); + if (current_key < 0 && rnext_key < 0) + return 0; + if (test_get_ao_info(sk, &ao_info)) + test_error("getsockopt(TCP_AO_INFO) failed"); + if (current_key >= 0 && ao_info.current_key != (uint8_t)current_key) + return -ENOTRECOVERABLE; + if (rnext_key >= 0 && ao_info.rnext != (uint8_t)rnext_key) + return -ENOTRECOVERABLE; + return 0; +} + +static void try_delete_key(char *tst_name, int sk, uint8_t sndid, uint8_t rcvid, + bool async, int current_key, int rnext_key, + fault_t inj) +{ + int err; + + err = test_del_key(sk, sndid, rcvid, async, current_key, rnext_key); + if ((err == -EBUSY && fault(BUSY)) || (err == -EINVAL && fault(CURRNEXT))) { + test_ok("%s: key deletion was prevented", tst_name); + return; + } + if (err && fault(FIXME)) { + test_xfail("%s: failed to delete the key %u:%u %d", + tst_name, sndid, rcvid, err); + return; + } + if (!err) { + if (fault(BUSY) || fault(CURRNEXT)) { + test_fail("%s: the key was deleted %u:%u %d", tst_name, + sndid, rcvid, err); + } else { + test_ok("%s: the key was deleted", tst_name); + } + return; + } + test_fail("%s: can't delete the key %u:%u %d", tst_name, sndid, rcvid, err); +} + +static int test_set_key(int sk, int current_keyid, int rnext_keyid) +{ + struct tcp_ao_info_opt ao_info = {}; + int err; + + if (current_keyid >= 0) { + ao_info.set_current = 1; + ao_info.current_key = (uint8_t)current_keyid; + } + if (rnext_keyid >= 0) { + ao_info.set_rnext = 1; + ao_info.rnext = (uint8_t)rnext_keyid; + } + + err = test_set_ao_info(sk, &ao_info); + if (err) + return err; + if (test_get_ao_info(sk, &ao_info)) + test_error("getsockopt(TCP_AO_INFO) failed"); + if (current_keyid >= 0 && ao_info.current_key != (uint8_t)current_keyid) + return -ENOTRECOVERABLE; + if (rnext_keyid >= 0 && ao_info.rnext != (uint8_t)rnext_keyid) + return -ENOTRECOVERABLE; + return 0; +} + +static int test_add_current_rnext_key(int sk, const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + bool set_current, bool set_rnext, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + err = test_prepare_key(&tmp, DEFAULT_TEST_ALGO, in_addr, + set_current, set_rnext, + prefix, 0, sndid, rcvid, 0, keyflags, + strlen(key), key); + if (err) + return err; + + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); + if (err < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +static int __try_add_current_rnext_key(int sk, const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + bool set_current, bool set_rnext, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_info_opt ao_info = {}; + int err; + + err = test_add_current_rnext_key(sk, key, keyflags, in_addr, prefix, + set_current, set_rnext, sndid, rcvid); + if (err) + return err; + + if (test_get_ao_info(sk, &ao_info)) + test_error("getsockopt(TCP_AO_INFO) failed"); + if (set_current && ao_info.current_key != sndid) + return -ENOTRECOVERABLE; + if (set_rnext && ao_info.rnext != rcvid) + return -ENOTRECOVERABLE; + return 0; +} + +static void try_add_current_rnext_key(char *tst_name, int sk, const char *key, + uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + bool set_current, bool set_rnext, + uint8_t sndid, uint8_t rcvid, fault_t inj) +{ + int err; + + err = __try_add_current_rnext_key(sk, key, keyflags, in_addr, prefix, + set_current, set_rnext, sndid, rcvid); + if (!err && !fault(CURRNEXT)) { + test_ok("%s", tst_name); + return; + } + if (err == -EINVAL && fault(CURRNEXT)) { + test_ok("%s", tst_name); + return; + } + test_fail("%s", tst_name); +} + +static void check_closed_socket(void) +{ + int sk; + + sk = prepare_sk(&this_ip_dest, 200, 200); + try_delete_key("closed socket, delete a key", sk, 200, 200, 0, -1, -1, 0); + try_delete_key("closed socket, delete all keys", sk, 100, 100, 0, -1, -1, 0); + close(sk); + + sk = prepare_sk(&this_ip_dest, 200, 200); + if (test_set_key(sk, 100, 200)) + test_error("failed to set current/rnext keys"); + try_delete_key("closed socket, delete current key", sk, 100, 100, 0, -1, -1, FAULT_BUSY); + try_delete_key("closed socket, delete rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY); + close(sk); + + sk = prepare_sk(&this_ip_dest, 200, 200); + if (test_add_key(sk, "Glory to heros!", this_ip_dest, + DEFAULT_TEST_PREFIX, 10, 11)) + test_error("test_add_key()"); + if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest, + DEFAULT_TEST_PREFIX, 12, 13)) + test_error("test_add_key()"); + try_delete_key("closed socket, delete a key + set current/rnext", sk, 100, 100, 0, 10, 13, 0); + try_delete_key("closed socket, force-delete current key", sk, 10, 11, 0, 200, -1, 0); + try_delete_key("closed socket, force-delete rnext key", sk, 12, 13, 0, -1, 200, 0); + try_delete_key("closed socket, delete current+rnext key", sk, 200, 200, 0, -1, -1, FAULT_BUSY); + close(sk); + + sk = prepare_sk(&this_ip_dest, 200, 200); + if (test_set_key(sk, 100, 200)) + test_error("failed to set current/rnext keys"); + try_add_current_rnext_key("closed socket, add + change current key", + sk, "Laaaa! Lalala-la-la-lalala...", 0, + this_ip_dest, DEFAULT_TEST_PREFIX, + true, false, 10, 20, 0); + try_add_current_rnext_key("closed socket, add + change rnext key", + sk, "Laaaa! Lalala-la-la-lalala...", 0, + this_ip_dest, DEFAULT_TEST_PREFIX, + false, true, 20, 10, 0); + close(sk); +} + +static void assert_no_current_rnext(const char *tst_msg, int sk) +{ + struct tcp_ao_info_opt ao_info = {}; + + if (test_get_ao_info(sk, &ao_info)) + test_error("getsockopt(TCP_AO_INFO) failed"); + + errno = 0; + if (ao_info.set_current || ao_info.set_rnext) { + test_xfail("%s: the socket has current/rnext keys: %d:%d", + tst_msg, + (ao_info.set_current) ? ao_info.current_key : -1, + (ao_info.set_rnext) ? ao_info.rnext : -1); + } else { + test_ok("%s: the socket has no current/rnext keys", tst_msg); + } +} + +static void assert_no_tcp_repair(void) +{ + struct tcp_ao_repair ao_img = {}; + socklen_t len = sizeof(ao_img); + int sk, err; + + sk = prepare_sk(&this_ip_dest, 200, 200); + test_enable_repair(sk); + if (listen(sk, 10)) + test_error("listen()"); + errno = 0; + err = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, &len); + if (err && errno == EPERM) + test_ok("listen socket, getsockopt(TCP_AO_REPAIR) is restricted"); + else + test_fail("listen socket, getsockopt(TCP_AO_REPAIR) works"); + errno = 0; + err = setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, &ao_img, sizeof(ao_img)); + if (err && errno == EPERM) + test_ok("listen socket, setsockopt(TCP_AO_REPAIR) is restricted"); + else + test_fail("listen socket, setsockopt(TCP_AO_REPAIR) works"); + close(sk); +} + +static void check_listen_socket(void) +{ + int sk, err; + + sk = prepare_lsk(&this_ip_dest, 200, 200); + try_delete_key("listen socket, delete a key", sk, 200, 200, 0, -1, -1, 0); + try_delete_key("listen socket, delete all keys", sk, 100, 100, 0, -1, -1, 0); + close(sk); + + sk = prepare_lsk(&this_ip_dest, 200, 200); + err = test_set_key(sk, 100, -1); + if (err == -EINVAL) + test_ok("listen socket, setting current key not allowed"); + else + test_fail("listen socket, set current key"); + err = test_set_key(sk, -1, 200); + if (err == -EINVAL) + test_ok("listen socket, setting rnext key not allowed"); + else + test_fail("listen socket, set rnext key"); + close(sk); + + sk = prepare_sk(&this_ip_dest, 200, 200); + if (test_set_key(sk, 100, 200)) + test_error("failed to set current/rnext keys"); + if (listen(sk, 10)) + test_error("listen()"); + assert_no_current_rnext("listen() after current/rnext keys set", sk); + try_delete_key("listen socket, delete current key from before listen()", sk, 100, 100, 0, -1, -1, FAULT_FIXME); + try_delete_key("listen socket, delete rnext key from before listen()", sk, 200, 200, 0, -1, -1, FAULT_FIXME); + close(sk); + + assert_no_tcp_repair(); + + sk = prepare_lsk(&this_ip_dest, 200, 200); + if (test_add_key(sk, "Glory to heros!", this_ip_dest, + DEFAULT_TEST_PREFIX, 10, 11)) + test_error("test_add_key()"); + if (test_add_key(sk, "Glory to Ukraine!", this_ip_dest, + DEFAULT_TEST_PREFIX, 12, 13)) + test_error("test_add_key()"); + try_delete_key("listen socket, delete a key + set current/rnext", sk, + 100, 100, 0, 10, 13, FAULT_CURRNEXT); + try_delete_key("listen socket, force-delete current key", sk, + 10, 11, 0, 200, -1, FAULT_CURRNEXT); + try_delete_key("listen socket, force-delete rnext key", sk, + 12, 13, 0, -1, 200, FAULT_CURRNEXT); + try_delete_key("listen socket, delete a key", sk, + 200, 200, 0, -1, -1, 0); + close(sk); + + sk = prepare_lsk(&this_ip_dest, 200, 200); + try_add_current_rnext_key("listen socket, add + change current key", + sk, "Laaaa! Lalala-la-la-lalala...", 0, + this_ip_dest, DEFAULT_TEST_PREFIX, + true, false, 10, 20, FAULT_CURRNEXT); + try_add_current_rnext_key("listen socket, add + change rnext key", + sk, "Laaaa! Lalala-la-la-lalala...", 0, + this_ip_dest, DEFAULT_TEST_PREFIX, + false, true, 20, 10, FAULT_CURRNEXT); + close(sk); +} + +static const char *fips_fpath = "/proc/sys/crypto/fips_enabled"; +static bool is_fips_enabled(void) +{ + static int fips_checked = -1; + FILE *fenabled; + int enabled; + + if (fips_checked >= 0) + return !!fips_checked; + if (access(fips_fpath, R_OK)) { + if (errno != ENOENT) + test_error("Can't open %s", fips_fpath); + fips_checked = 0; + return false; + } + fenabled = fopen(fips_fpath, "r"); + if (!fenabled) + test_error("Can't open %s", fips_fpath); + if (fscanf(fenabled, "%d", &enabled) != 1) + test_error("Can't read from %s", fips_fpath); + fclose(fenabled); + fips_checked = !!enabled; + return !!fips_checked; +} + +struct test_key { + char password[TCP_AO_MAXKEYLEN]; + const char *alg; + unsigned int len; + uint8_t client_keyid; + uint8_t server_keyid; + uint8_t maclen; + uint8_t matches_client : 1, + matches_server : 1, + matches_vrf : 1, + is_current : 1, + is_rnext : 1, + used_on_server_tx : 1, + used_on_client_tx : 1, + skip_counters_checks : 1; +}; + +struct key_collection { + unsigned int nr_keys; + struct test_key *keys; +}; + +static struct key_collection collection; + +#define TEST_MAX_MACLEN 16 +const char *test_algos[] = { + "cmac(aes128)", + "hmac(sha1)", "hmac(sha512)", "hmac(sha384)", "hmac(sha256)", + "hmac(sha224)", "hmac(sha3-512)", + /* only if !CONFIG_FIPS */ +#define TEST_NON_FIPS_ALGOS 2 + "hmac(rmd160)", "hmac(md5)" +}; +const unsigned int test_maclens[] = { 1, 4, 12, 16 }; +#define MACLEN_SHIFT 2 +#define ALGOS_SHIFT 4 + +static unsigned int make_mask(unsigned int shift, unsigned int prev_shift) +{ + unsigned int ret = BIT(shift) - 1; + + return ret << prev_shift; +} + +static void init_key_in_collection(unsigned int index, bool randomized) +{ + struct test_key *key = &collection.keys[index]; + unsigned int algos_nr, algos_index; + + /* Same for randomized and non-randomized test flows */ + key->client_keyid = index; + key->server_keyid = 127 + index; + key->matches_client = 1; + key->matches_server = 1; + key->matches_vrf = 1; + /* not really even random, but good enough for a test */ + key->len = rand() % (TCP_AO_MAXKEYLEN - TEST_TCP_AO_MINKEYLEN); + key->len += TEST_TCP_AO_MINKEYLEN; + randomize_buffer(key->password, key->len); + + if (randomized) { + key->maclen = (rand() % TEST_MAX_MACLEN) + 1; + algos_index = rand(); + } else { + unsigned int shift = MACLEN_SHIFT; + + key->maclen = test_maclens[index & make_mask(shift, 0)]; + algos_index = index & make_mask(ALGOS_SHIFT, shift); + } + algos_nr = ARRAY_SIZE(test_algos); + if (is_fips_enabled()) + algos_nr -= TEST_NON_FIPS_ALGOS; + key->alg = test_algos[algos_index % algos_nr]; +} + +static int init_default_key_collection(unsigned int nr_keys, bool randomized) +{ + size_t key_sz = sizeof(collection.keys[0]); + + if (!nr_keys) { + free(collection.keys); + collection.keys = NULL; + return 0; + } + + /* + * All keys have uniq sndid/rcvid and sndid != rcvid in order to + * check for any bugs/issues for different keyids, visible to both + * peers. Keyid == 254 is unused. + */ + if (nr_keys > 127) + test_error("Test requires too many keys, correct the source"); + + collection.keys = reallocarray(collection.keys, nr_keys, key_sz); + if (!collection.keys) + return -ENOMEM; + + memset(collection.keys, 0, nr_keys * key_sz); + collection.nr_keys = nr_keys; + while (nr_keys--) + init_key_in_collection(nr_keys, randomized); + + return 0; +} + +static void test_key_error(const char *msg, struct test_key *key) +{ + test_error("%s: key: { %s, %u:%u, %u, %u:%u:%u:%u:%u (%u)}", + msg, key->alg, key->client_keyid, key->server_keyid, + key->maclen, key->matches_client, key->matches_server, + key->matches_vrf, key->is_current, key->is_rnext, key->len); +} + +static int test_add_key_cr(int sk, const char *pwd, unsigned int pwd_len, + union tcp_addr addr, uint8_t vrf, + uint8_t sndid, uint8_t rcvid, + uint8_t maclen, const char *alg, + bool set_current, bool set_rnext) +{ + struct tcp_ao_add tmp = {}; + uint8_t keyflags = 0; + int err; + + if (!alg) + alg = DEFAULT_TEST_ALGO; + + if (vrf) + keyflags |= TCP_AO_KEYF_IFINDEX; + err = test_prepare_key(&tmp, alg, addr, set_current, set_rnext, + DEFAULT_TEST_PREFIX, vrf, sndid, rcvid, maclen, + keyflags, pwd_len, pwd); + if (err) + return err; + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); + if (err < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +static void verify_current_rnext(const char *tst, int sk, + int current_keyid, int rnext_keyid) +{ + struct tcp_ao_info_opt ao_info = {}; + + if (test_get_ao_info(sk, &ao_info)) + test_error("getsockopt(TCP_AO_INFO) failed"); + + errno = 0; + if (current_keyid >= 0) { + if (!ao_info.set_current) + test_fail("%s: the socket doesn't have current key", tst); + else if (ao_info.current_key != current_keyid) + test_fail("%s: current key is not the expected one %d != %u", + tst, current_keyid, ao_info.current_key); + else + test_ok("%s: current key %u as expected", + tst, ao_info.current_key); + } + if (rnext_keyid >= 0) { + if (!ao_info.set_rnext) + test_fail("%s: the socket doesn't have rnext key", tst); + else if (ao_info.rnext != rnext_keyid) + test_fail("%s: rnext key is not the expected one %d != %u", + tst, rnext_keyid, ao_info.rnext); + else + test_ok("%s: rnext key %u as expected", tst, ao_info.rnext); + } +} + + +static int key_collection_socket(bool server, unsigned int port) +{ + unsigned int i; + int sk; + + if (server) + sk = test_listen_socket(this_ip_addr, port, 1); + else + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + for (i = 0; i < collection.nr_keys; i++) { + struct test_key *key = &collection.keys[i]; + union tcp_addr *addr = &wrong_addr; + uint8_t sndid, rcvid, vrf; + bool set_current = false, set_rnext = false; + + if (key->matches_vrf) + vrf = 0; + else + vrf = test_vrf_ifindex; + if (server) { + if (key->matches_client) + addr = &this_ip_dest; + sndid = key->server_keyid; + rcvid = key->client_keyid; + } else { + if (key->matches_server) + addr = &this_ip_dest; + sndid = key->client_keyid; + rcvid = key->server_keyid; + key->used_on_client_tx = set_current = key->is_current; + key->used_on_server_tx = set_rnext = key->is_rnext; + } + + if (test_add_key_cr(sk, key->password, key->len, + *addr, vrf, sndid, rcvid, key->maclen, + key->alg, set_current, set_rnext)) + test_key_error("setsockopt(TCP_AO_ADD_KEY)", key); +#ifdef DEBUG + test_print("%s [%u/%u] key: { %s, %u:%u, %u, %u:%u:%u:%u (%u)}", + server ? "server" : "client", i, collection.nr_keys, + key->alg, rcvid, sndid, key->maclen, + key->matches_client, key->matches_server, + key->is_current, key->is_rnext, key->len); +#endif + } + return sk; +} + +static void verify_counters(const char *tst_name, bool is_listen_sk, bool server, + struct tcp_counters *a, struct tcp_counters *b) +{ + unsigned int i; + + test_assert_counters_sk(tst_name, a, b, TEST_CNT_GOOD); + + for (i = 0; i < collection.nr_keys; i++) { + struct test_key *key = &collection.keys[i]; + uint8_t sndid, rcvid; + bool rx_cnt_expected; + + if (key->skip_counters_checks) + continue; + if (server) { + sndid = key->server_keyid; + rcvid = key->client_keyid; + rx_cnt_expected = key->used_on_client_tx; + } else { + sndid = key->client_keyid; + rcvid = key->server_keyid; + rx_cnt_expected = key->used_on_server_tx; + } + + test_assert_counters_key(tst_name, &a->ao, &b->ao, + rx_cnt_expected ? TEST_CNT_KEY_GOOD : 0, + sndid, rcvid); + } + test_tcp_counters_free(a); + test_tcp_counters_free(b); + test_ok("%s: passed counters checks", tst_name); +} + +static struct tcp_ao_getsockopt *lookup_key(struct tcp_ao_getsockopt *buf, + size_t len, int sndid, int rcvid) +{ + size_t i; + + for (i = 0; i < len; i++) { + if (sndid >= 0 && buf[i].sndid != sndid) + continue; + if (rcvid >= 0 && buf[i].rcvid != rcvid) + continue; + return &buf[i]; + } + return NULL; +} + +static void verify_keys(const char *tst_name, int sk, + bool is_listen_sk, bool server) +{ + socklen_t len = sizeof(struct tcp_ao_getsockopt); + struct tcp_ao_getsockopt *keys; + bool passed_test = true; + unsigned int i; + + keys = calloc(collection.nr_keys, len); + if (!keys) + test_error("calloc()"); + + keys->nkeys = collection.nr_keys; + keys->get_all = 1; + + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, keys, &len)) { + free(keys); + test_error("getsockopt(TCP_AO_GET_KEYS)"); + } + + for (i = 0; i < collection.nr_keys; i++) { + struct test_key *key = &collection.keys[i]; + struct tcp_ao_getsockopt *dump_key; + bool is_kdf_aes_128_cmac = false; + bool is_cmac_aes = false; + uint8_t sndid, rcvid; + bool matches = false; + + if (server) { + if (key->matches_client) + matches = true; + sndid = key->server_keyid; + rcvid = key->client_keyid; + } else { + if (key->matches_server) + matches = true; + sndid = key->client_keyid; + rcvid = key->server_keyid; + } + if (!key->matches_vrf) + matches = false; + /* no keys get removed on the original listener socket */ + if (is_listen_sk) + matches = true; + + dump_key = lookup_key(keys, keys->nkeys, sndid, rcvid); + if (matches != !!dump_key) { + test_fail("%s: key %u:%u %s%s on the socket", + tst_name, sndid, rcvid, + key->matches_vrf ? "" : "[vrf] ", + matches ? "disappeared" : "yet present"); + passed_test = false; + goto out; + } + if (!dump_key) + continue; + + if (!strcmp("cmac(aes128)", key->alg)) { + is_kdf_aes_128_cmac = (key->len != 16); + is_cmac_aes = true; + } + + if (is_cmac_aes) { + if (strcmp(dump_key->alg_name, "cmac(aes)")) { + test_fail("%s: key %u:%u cmac(aes) has unexpected alg %s", + tst_name, sndid, rcvid, + dump_key->alg_name); + passed_test = false; + continue; + } + } else if (strcmp(dump_key->alg_name, key->alg)) { + test_fail("%s: key %u:%u has unexpected alg %s != %s", + tst_name, sndid, rcvid, + dump_key->alg_name, key->alg); + passed_test = false; + continue; + } + if (is_kdf_aes_128_cmac) { + if (dump_key->keylen != 16) { + test_fail("%s: key %u:%u cmac(aes128) has unexpected len %u", + tst_name, sndid, rcvid, + dump_key->keylen); + continue; + } + } else if (dump_key->keylen != key->len) { + test_fail("%s: key %u:%u changed password len %u != %u", + tst_name, sndid, rcvid, + dump_key->keylen, key->len); + passed_test = false; + continue; + } + if (!is_kdf_aes_128_cmac && + memcmp(dump_key->key, key->password, key->len)) { + test_fail("%s: key %u:%u has different password", + tst_name, sndid, rcvid); + passed_test = false; + continue; + } + if (dump_key->maclen != key->maclen) { + test_fail("%s: key %u:%u changed maclen %u != %u", + tst_name, sndid, rcvid, + dump_key->maclen, key->maclen); + passed_test = false; + continue; + } + } + + if (passed_test) + test_ok("%s: The socket keys are consistent with the expectations", + tst_name); +out: + free(keys); +} + +static int start_server(const char *tst_name, unsigned int port, size_t quota, + struct tcp_counters *begin, + unsigned int current_index, unsigned int rnext_index) +{ + struct tcp_counters lsk_c1, lsk_c2; + ssize_t bytes; + int sk, lsk; + + synchronize_threads(); /* 1: key collection initialized */ + lsk = key_collection_socket(true, port); + if (test_get_tcp_counters(lsk, &lsk_c1)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 2: MKTs added => connect() */ + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + if (test_get_tcp_counters(sk, begin)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* 3: accepted => send data */ + if (test_get_tcp_counters(lsk, &lsk_c2)) + test_error("test_get_tcp_counters()"); + verify_keys(tst_name, lsk, true, true); + close(lsk); + + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) + test_fail("%s: server served: %zd", tst_name, bytes); + else + test_ok("%s: server alive", tst_name); + + verify_counters(tst_name, true, true, &lsk_c1, &lsk_c2); + + return sk; +} + +static void end_server(const char *tst_name, int sk, + struct tcp_counters *begin) +{ + struct tcp_counters end; + + if (test_get_tcp_counters(sk, &end)) + test_error("test_get_tcp_counters()"); + verify_keys(tst_name, sk, false, true); + + synchronize_threads(); /* 4: verified => closed */ + close(sk); + + verify_counters(tst_name, false, true, begin, &end); + synchronize_threads(); /* 5: counters */ +} + +static void try_server_run(const char *tst_name, unsigned int port, size_t quota, + unsigned int current_index, unsigned int rnext_index) +{ + struct tcp_counters tmp; + int sk; + + sk = start_server(tst_name, port, quota, &tmp, + current_index, rnext_index); + end_server(tst_name, sk, &tmp); +} + +static void server_rotations(const char *tst_name, unsigned int port, + size_t quota, unsigned int rotations, + unsigned int current_index, unsigned int rnext_index) +{ + struct tcp_counters tmp; + unsigned int i; + int sk; + + sk = start_server(tst_name, port, quota, &tmp, + current_index, rnext_index); + + for (i = current_index + 1; rotations > 0; i++, rotations--) { + ssize_t bytes; + + if (i >= collection.nr_keys) + i = 0; + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) { + test_fail("%s: server served: %zd", tst_name, bytes); + return; + } + verify_current_rnext(tst_name, sk, + collection.keys[i].server_keyid, -1); + synchronize_threads(); /* verify current/rnext */ + } + end_server(tst_name, sk, &tmp); +} + +static int run_client(const char *tst_name, unsigned int port, + unsigned int nr_keys, int current_index, int rnext_index, + struct tcp_counters *before, + const size_t msg_sz, const size_t msg_nr) +{ + int sk; + + synchronize_threads(); /* 1: key collection initialized */ + sk = key_collection_socket(false, port); + + if (current_index >= 0 || rnext_index >= 0) { + int sndid = -1, rcvid = -1; + + if (current_index >= 0) + sndid = collection.keys[current_index].client_keyid; + if (rnext_index >= 0) + rcvid = collection.keys[rnext_index].server_keyid; + if (test_set_key(sk, sndid, rcvid)) + test_error("failed to set current/rnext keys"); + } + if (before && test_get_tcp_counters(sk, before)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* 2: MKTs added => connect() */ + if (test_connect_socket(sk, this_ip_dest, port++) <= 0) + test_error("failed to connect()"); + if (current_index < 0) + current_index = nr_keys - 1; + if (rnext_index < 0) + rnext_index = nr_keys - 1; + collection.keys[current_index].used_on_client_tx = 1; + collection.keys[rnext_index].used_on_server_tx = 1; + + synchronize_threads(); /* 3: accepted => send data */ + if (test_client_verify(sk, msg_sz, msg_nr)) { + test_fail("verify failed"); + close(sk); + if (before) + test_tcp_counters_free(before); + return -1; + } + + return sk; +} + +static int start_client(const char *tst_name, unsigned int port, + unsigned int nr_keys, int current_index, int rnext_index, + struct tcp_counters *before, + const size_t msg_sz, const size_t msg_nr) +{ + if (init_default_key_collection(nr_keys, true)) + test_error("Failed to init the key collection"); + + return run_client(tst_name, port, nr_keys, current_index, + rnext_index, before, msg_sz, msg_nr); +} + +static void end_client(const char *tst_name, int sk, unsigned int nr_keys, + int current_index, int rnext_index, + struct tcp_counters *start) +{ + struct tcp_counters end; + + /* Some application may become dependent on this kernel choice */ + if (current_index < 0) + current_index = nr_keys - 1; + if (rnext_index < 0) + rnext_index = nr_keys - 1; + verify_current_rnext(tst_name, sk, + collection.keys[current_index].client_keyid, + collection.keys[rnext_index].server_keyid); + if (start && test_get_tcp_counters(sk, &end)) + test_error("test_get_tcp_counters()"); + verify_keys(tst_name, sk, false, false); + synchronize_threads(); /* 4: verify => closed */ + close(sk); + if (start) + verify_counters(tst_name, false, false, start, &end); + synchronize_threads(); /* 5: counters */ +} + +static void try_unmatched_keys(int sk, int *rnext_index, unsigned int port) +{ + struct test_key *key; + unsigned int i = 0; + int err; + + do { + key = &collection.keys[i]; + if (!key->matches_server) + break; + } while (++i < collection.nr_keys); + if (key->matches_server) + test_error("all keys on client match the server"); + + err = test_add_key_cr(sk, key->password, key->len, wrong_addr, + 0, key->client_keyid, key->server_keyid, + key->maclen, key->alg, 0, 0); + if (!err) { + test_fail("Added a key with non-matching ip-address for established sk"); + return; + } + if (err == -EINVAL) + test_ok("Can't add a key with non-matching ip-address for established sk"); + else + test_error("Failed to add a key"); + + err = test_add_key_cr(sk, key->password, key->len, this_ip_dest, + test_vrf_ifindex, + key->client_keyid, key->server_keyid, + key->maclen, key->alg, 0, 0); + if (!err) { + test_fail("Added a key with non-matching VRF for established sk"); + return; + } + if (err == -EINVAL) + test_ok("Can't add a key with non-matching VRF for established sk"); + else + test_error("Failed to add a key"); + + for (i = 0; i < collection.nr_keys; i++) { + key = &collection.keys[i]; + if (!key->matches_client) + break; + } + if (key->matches_client) + test_error("all keys on server match the client"); + if (test_set_key(sk, -1, key->server_keyid)) + test_error("Can't change the current key"); + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, this_ip_addr, this_ip_dest, + -1, port, 0, -1, -1, -1, -1, -1, + -1, key->server_keyid, -1); + if (test_client_verify(sk, msg_len, nr_packets)) + test_fail("verify failed"); + *rnext_index = i; +} + +static int client_non_matching(const char *tst_name, unsigned int port, + unsigned int nr_keys, + int current_index, int rnext_index, + const size_t msg_sz, const size_t msg_nr) +{ + unsigned int i; + + if (init_default_key_collection(nr_keys, true)) + test_error("Failed to init the key collection"); + + for (i = 0; i < nr_keys; i++) { + /* key (0, 0) matches */ + collection.keys[i].matches_client = !!((i + 3) % 4); + collection.keys[i].matches_server = !!((i + 2) % 4); + if (kernel_config_has(KCONFIG_NET_VRF)) + collection.keys[i].matches_vrf = !!((i + 1) % 4); + } + + return run_client(tst_name, port, nr_keys, current_index, + rnext_index, NULL, msg_sz, msg_nr); +} + +static void check_current_back(const char *tst_name, unsigned int port, + unsigned int nr_keys, + unsigned int current_index, unsigned int rnext_index, + unsigned int rotate_to_index) +{ + struct tcp_counters tmp; + int sk; + + sk = start_client(tst_name, port, nr_keys, current_index, rnext_index, + &tmp, msg_len, nr_packets); + if (sk < 0) + return; + if (test_set_key(sk, collection.keys[rotate_to_index].client_keyid, -1)) + test_error("Can't change the current key"); + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, this_ip_dest, this_ip_addr, + port, -1, 0, -1, -1, -1, -1, -1, + collection.keys[rotate_to_index].client_keyid, + collection.keys[current_index].client_keyid, -1); + if (test_client_verify(sk, msg_len, nr_packets)) + test_fail("verify failed"); + /* There is a race here: between setting the current_key with + * setsockopt(TCP_AO_INFO) and starting to send some data - there + * might have been a segment received with the desired + * RNext_key set. In turn that would mean that the first outgoing + * segment will have the desired current_key (flipped back). + * Which is what the user/test wants. As it's racy, skip checking + * the counters, yet check what are the resulting current/rnext + * keys on both sides. + */ + collection.keys[rotate_to_index].skip_counters_checks = 1; + + end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp); +} + +static void roll_over_keys(const char *tst_name, unsigned int port, + unsigned int nr_keys, unsigned int rotations, + unsigned int current_index, unsigned int rnext_index) +{ + struct tcp_counters tmp; + unsigned int i; + int sk; + + sk = start_client(tst_name, port, nr_keys, current_index, rnext_index, + &tmp, msg_len, nr_packets); + if (sk < 0) + return; + for (i = rnext_index + 1; rotations > 0; i++, rotations--) { + if (i >= collection.nr_keys) + i = 0; + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, + this_ip_addr, this_ip_dest, + -1, port, 0, -1, -1, -1, -1, -1, + i == 0 ? -1 : collection.keys[i - 1].server_keyid, + collection.keys[i].server_keyid, -1); + if (test_set_key(sk, -1, collection.keys[i].server_keyid)) + test_error("Can't change the Rnext key"); + if (test_client_verify(sk, msg_len, nr_packets)) { + test_fail("verify failed"); + close(sk); + test_tcp_counters_free(&tmp); + return; + } + verify_current_rnext(tst_name, sk, -1, + collection.keys[i].server_keyid); + collection.keys[i].used_on_server_tx = 1; + synchronize_threads(); /* verify current/rnext */ + } + end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp); +} + +static void try_client_run(const char *tst_name, unsigned int port, + unsigned int nr_keys, int current_index, int rnext_index) +{ + struct tcp_counters tmp; + int sk; + + sk = start_client(tst_name, port, nr_keys, current_index, rnext_index, + &tmp, msg_len, nr_packets); + if (sk < 0) + return; + end_client(tst_name, sk, nr_keys, current_index, rnext_index, &tmp); +} + +static void try_client_match(const char *tst_name, unsigned int port, + unsigned int nr_keys, + int current_index, int rnext_index) +{ + int sk; + + sk = client_non_matching(tst_name, port, nr_keys, current_index, + rnext_index, msg_len, nr_packets); + if (sk < 0) + return; + try_unmatched_keys(sk, &rnext_index, port); + end_client(tst_name, sk, nr_keys, current_index, rnext_index, NULL); +} + +static void *server_fn(void *arg) +{ + unsigned int port = test_server_port; + + setup_vrfs(); + try_server_run("server: Check current/rnext keys unset before connect()", + port++, quota, 19, 19); + try_server_run("server: Check current/rnext keys set before connect()", + port++, quota, 10, 10); + try_server_run("server: Check current != rnext keys set before connect()", + port++, quota, 5, 10); + try_server_run("server: Check current flapping back on peer's RnextKey request", + port++, quota * 2, 5, 10); + server_rotations("server: Rotate over all different keys", port++, + quota, 20, 0, 0); + try_server_run("server: Check accept() => established key matching", + port++, quota * 2, 0, 0); + + synchronize_threads(); /* don't race to exit: client exits */ + return NULL; +} + +static void check_established_socket(void) +{ + unsigned int port = test_server_port; + + setup_vrfs(); + try_client_run("client: Check current/rnext keys unset before connect()", + port++, 20, -1, -1); + try_client_run("client: Check current/rnext keys set before connect()", + port++, 20, 10, 10); + try_client_run("client: Check current != rnext keys set before connect()", + port++, 20, 10, 5); + check_current_back("client: Check current flapping back on peer's RnextKey request", + port++, 20, 10, 5, 2); + roll_over_keys("client: Rotate over all different keys", port++, + 20, 20, 0, 0); + try_client_match("client: Check connect() => established key matching", + port++, 20, 0, 0); +} + +static void *client_fn(void *arg) +{ + if (inet_pton(TEST_FAMILY, TEST_WRONG_IP, &wrong_addr) != 1) + test_error("Can't convert ip address %s", TEST_WRONG_IP); + check_closed_socket(); + check_listen_socket(); + check_established_socket(); + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(121, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/aolib.h b/tools/testing/selftests/net/tcp_ao/lib/aolib.h new file mode 100644 index 000000000000..ebb2899c12fe --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/aolib.h @@ -0,0 +1,832 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * TCP-AO selftest library. Provides helpers to unshare network + * namespaces, create veth, assign ip addresses, set routes, + * manipulate socket options, read network counter and etc. + * Author: Dmitry Safonov <dima@arista.com> + */ +#ifndef _AOLIB_H_ +#define _AOLIB_H_ + +#include <arpa/inet.h> +#include <errno.h> +#include <linux/snmp.h> +#include <linux/tcp.h> +#include <netinet/in.h> +#include <stdarg.h> +#include <stdbool.h> +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include <sys/syscall.h> +#include <unistd.h> + +#include "../../../../../include/linux/stringify.h" +#include "../../../../../include/linux/bits.h" + +#ifndef SOL_TCP +/* can't include <netinet/tcp.h> as including <linux/tcp.h> */ +# define SOL_TCP 6 /* TCP level */ +#endif + +/* Working around ksft, see the comment in lib/setup.c */ +extern void __test_msg(const char *buf); +extern void __test_ok(const char *buf); +extern void __test_fail(const char *buf); +extern void __test_xfail(const char *buf); +extern void __test_error(const char *buf); +extern void __test_skip(const char *buf); + +static inline char *test_snprintf(const char *fmt, va_list vargs) +{ + char *ret = NULL; + size_t size = 0; + va_list tmp; + int n = 0; + + va_copy(tmp, vargs); + n = vsnprintf(ret, size, fmt, tmp); + va_end(tmp); + if (n < 0) + return NULL; + + size = n + 1; + ret = malloc(size); + if (!ret) + return NULL; + + n = vsnprintf(ret, size, fmt, vargs); + if (n < 0 || n > size - 1) { + free(ret); + return NULL; + } + return ret; +} + +static __printf(1, 2) inline char *test_sprintf(const char *fmt, ...) +{ + va_list vargs; + char *ret; + + va_start(vargs, fmt); + ret = test_snprintf(fmt, vargs); + va_end(vargs); + + return ret; +} + +static __printf(2, 3) inline void __test_print(void (*fn)(const char *), + const char *fmt, ...) +{ + va_list vargs; + char *msg; + + va_start(vargs, fmt); + msg = test_snprintf(fmt, vargs); + va_end(vargs); + + if (!msg) + return; + + fn(msg); + free(msg); +} + +#define test_print(fmt, ...) \ + __test_print(__test_msg, "%ld[%s:%u] " fmt "\n", \ + syscall(SYS_gettid), \ + __FILE__, __LINE__, ##__VA_ARGS__) + +#define test_ok(fmt, ...) \ + __test_print(__test_ok, fmt "\n", ##__VA_ARGS__) +#define test_skip(fmt, ...) \ + __test_print(__test_skip, fmt "\n", ##__VA_ARGS__) +#define test_xfail(fmt, ...) \ + __test_print(__test_xfail, fmt "\n", ##__VA_ARGS__) + +#define test_fail(fmt, ...) \ +do { \ + if (errno) \ + __test_print(__test_fail, fmt ": %m\n", ##__VA_ARGS__); \ + else \ + __test_print(__test_fail, fmt "\n", ##__VA_ARGS__); \ + test_failed(); \ +} while (0) + +#define KSFT_FAIL 1 +#define test_error(fmt, ...) \ +do { \ + if (errno) \ + __test_print(__test_error, "%ld[%s:%u] " fmt ": %m\n", \ + syscall(SYS_gettid), __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + else \ + __test_print(__test_error, "%ld[%s:%u] " fmt "\n", \ + syscall(SYS_gettid), __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + exit(KSFT_FAIL); \ +} while (0) + +enum test_fault { + FAULT_TIMEOUT = 1, + FAULT_KEYREJECT, + FAULT_PREINSTALL_AO, + FAULT_PREINSTALL_MD5, + FAULT_POSTINSTALL, + FAULT_BUSY, + FAULT_CURRNEXT, + FAULT_FIXME, +}; +typedef enum test_fault fault_t; + +enum test_needs_kconfig { + KCONFIG_NET_NS = 0, /* required */ + KCONFIG_VETH, /* required */ + KCONFIG_TCP_AO, /* required */ + KCONFIG_TCP_MD5, /* optional, for TCP-MD5 features */ + KCONFIG_NET_VRF, /* optional, for L3/VRF testing */ + KCONFIG_FTRACE, /* optional, for tracepoints checks */ + __KCONFIG_LAST__ +}; +extern bool kernel_config_has(enum test_needs_kconfig k); +extern const char *tests_skip_reason[__KCONFIG_LAST__]; +static inline bool should_skip_test(const char *tst_name, + enum test_needs_kconfig k) +{ + if (kernel_config_has(k)) + return false; + test_skip("%s: %s", tst_name, tests_skip_reason[k]); + return true; +} + +union tcp_addr { + struct in_addr a4; + struct in6_addr a6; +}; + +typedef void *(*thread_fn)(void *); +extern void test_failed(void); +extern void __test_init(unsigned int ntests, int family, unsigned int prefix, + union tcp_addr addr1, union tcp_addr addr2, + thread_fn peer1, thread_fn peer2); + +static inline void test_init2(unsigned int ntests, + thread_fn peer1, thread_fn peer2, + int family, unsigned int prefix, + const char *addr1, const char *addr2) +{ + union tcp_addr taddr1, taddr2; + + if (inet_pton(family, addr1, &taddr1) != 1) + test_error("Can't convert ip address %s", addr1); + if (inet_pton(family, addr2, &taddr2) != 1) + test_error("Can't convert ip address %s", addr2); + + __test_init(ntests, family, prefix, taddr1, taddr2, peer1, peer2); +} +extern void test_add_destructor(void (*d)(void)); +extern void test_init_ftrace(int nsfd1, int nsfd2); +extern int test_setup_tracing(void); + +/* To adjust optmem socket limit, approximately estimate a number, + * that is bigger than sizeof(struct tcp_ao_key). + */ +#define KERNEL_TCP_AO_KEY_SZ_ROUND_UP 300 + +extern void test_set_optmem(size_t value); +extern size_t test_get_optmem(void); + +extern const struct sockaddr_in6 addr_any6; +extern const struct sockaddr_in addr_any4; + +#ifdef IPV6_TEST +# define __TEST_CLIENT_IP(n) ("2001:db8:" __stringify(n) "::1") +# define TEST_CLIENT_IP __TEST_CLIENT_IP(1) +# define TEST_WRONG_IP "2001:db8:253::1" +# define TEST_SERVER_IP "2001:db8:254::1" +# define TEST_NETWORK "2001::" +# define TEST_PREFIX 128 +# define TEST_FAMILY AF_INET6 +# define SOCKADDR_ANY addr_any6 +# define sockaddr_af struct sockaddr_in6 +#else +# define __TEST_CLIENT_IP(n) ("10.0." __stringify(n) ".1") +# define TEST_CLIENT_IP __TEST_CLIENT_IP(1) +# define TEST_WRONG_IP "10.0.253.1" +# define TEST_SERVER_IP "10.0.254.1" +# define TEST_NETWORK "10.0.0.0" +# define TEST_PREFIX 32 +# define TEST_FAMILY AF_INET +# define SOCKADDR_ANY addr_any4 +# define sockaddr_af struct sockaddr_in +#endif + +static inline union tcp_addr gen_tcp_addr(union tcp_addr net, size_t n) +{ + union tcp_addr ret = net; + +#ifdef IPV6_TEST + ret.a6.s6_addr32[3] = htonl(n & (BIT(32) - 1)); + ret.a6.s6_addr32[2] = htonl((n >> 32) & (BIT(32) - 1)); +#else + ret.a4.s_addr = htonl(ntohl(net.a4.s_addr) + n); +#endif + + return ret; +} + +static inline void tcp_addr_to_sockaddr_in(void *dest, + const union tcp_addr *src, + unsigned int port) +{ + sockaddr_af *out = dest; + + memset(out, 0, sizeof(*out)); +#ifdef IPV6_TEST + out->sin6_family = AF_INET6; + out->sin6_port = port; + out->sin6_addr = src->a6; +#else + out->sin_family = AF_INET; + out->sin_port = port; + out->sin_addr = src->a4; +#endif +} + +static inline void test_init(unsigned int ntests, + thread_fn peer1, thread_fn peer2) +{ + test_init2(ntests, peer1, peer2, TEST_FAMILY, TEST_PREFIX, + TEST_SERVER_IP, TEST_CLIENT_IP); +} +extern void synchronize_threads(void); +extern void switch_ns(int fd); +extern int switch_save_ns(int fd); +extern void switch_close_ns(int fd); + +extern __thread union tcp_addr this_ip_addr; +extern __thread union tcp_addr this_ip_dest; +extern int test_family; + +extern void randomize_buffer(void *buf, size_t buflen); +extern __printf(3, 4) int test_echo(const char *fname, bool append, + const char *fmt, ...); + +extern int open_netns(void); +extern int unshare_open_netns(void); +extern const char veth_name[]; +extern int add_veth(const char *name, int nsfda, int nsfdb); +extern int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd); +extern int ip_addr_add(const char *intf, int family, + union tcp_addr addr, uint8_t prefix); +extern int ip_route_add(const char *intf, int family, + union tcp_addr src, union tcp_addr dst); +extern int ip_route_add_vrf(const char *intf, int family, + union tcp_addr src, union tcp_addr dst, + uint8_t vrf); +extern int link_set_up(const char *intf); + +extern const unsigned int test_server_port; +extern int test_wait_fd(int sk, time_t sec, bool write); +extern int __test_connect_socket(int sk, const char *device, + void *addr, size_t addr_sz, bool async); +extern int __test_listen_socket(int backlog, void *addr, size_t addr_sz); + +static inline int test_listen_socket(const union tcp_addr taddr, + unsigned int port, int backlog) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return __test_listen_socket(backlog, (void *)&addr, sizeof(addr)); +} + +/* + * In order for selftests to work under CONFIG_CRYPTO_FIPS=y, + * the password should be loger than 14 bytes, see hmac_setkey() + */ +#define TEST_TCP_AO_MINKEYLEN 14 +#define DEFAULT_TEST_PASSWORD "In this hour, I do not believe that any darkness will endure." + +#ifndef DEFAULT_TEST_ALGO +#define DEFAULT_TEST_ALGO "cmac(aes128)" +#endif + +#ifdef IPV6_TEST +#define DEFAULT_TEST_PREFIX 128 +#else +#define DEFAULT_TEST_PREFIX 32 +#endif + +/* + * Timeout on syscalls where failure is not expected. + * You may want to rise it if the test machine is very busy. + */ +#ifndef TEST_TIMEOUT_SEC +#define TEST_TIMEOUT_SEC 5 +#endif + +/* + * Timeout on connect() where a failure is expected. + * If set to 0 - kernel will try to retransmit SYN number of times, set in + * /proc/sys/net/ipv4/tcp_syn_retries + * By default set to 1 to make tests pass faster on non-busy machine. + * [in process of removal, don't use in new tests] + */ +#ifndef TEST_RETRANSMIT_SEC +#define TEST_RETRANSMIT_SEC 1 +#endif + +static inline int _test_connect_socket(int sk, const union tcp_addr taddr, + unsigned int port, bool async) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return __test_connect_socket(sk, veth_name, + (void *)&addr, sizeof(addr), async); +} + +static inline int test_connect_socket(int sk, const union tcp_addr taddr, + unsigned int port) +{ + return _test_connect_socket(sk, taddr, port, false); +} + +extern int __test_set_md5(int sk, void *addr, size_t addr_sz, + uint8_t prefix, int vrf, const char *password); +static inline int test_set_md5(int sk, const union tcp_addr in_addr, + uint8_t prefix, int vrf, const char *password) +{ + sockaddr_af addr; + + if (prefix > DEFAULT_TEST_PREFIX) + prefix = DEFAULT_TEST_PREFIX; + + tcp_addr_to_sockaddr_in(&addr, &in_addr, 0); + return __test_set_md5(sk, (void *)&addr, sizeof(addr), + prefix, vrf, password); +} + +extern int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, + void *addr, size_t addr_sz, bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid, uint8_t maclen, + uint8_t keyflags, uint8_t keylen, const char *key); + +static inline int test_prepare_key(struct tcp_ao_add *ao, + const char *alg, union tcp_addr taddr, + bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid, uint8_t maclen, + uint8_t keyflags, uint8_t keylen, const char *key) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, 0); + return test_prepare_key_sockaddr(ao, alg, (void *)&addr, sizeof(addr), + set_current, set_rnext, prefix, vrf, sndid, rcvid, + maclen, keyflags, keylen, key); +} + +static inline int test_prepare_def_key(struct tcp_ao_add *ao, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid) +{ + if (prefix > DEFAULT_TEST_PREFIX) + prefix = DEFAULT_TEST_PREFIX; + + return test_prepare_key(ao, DEFAULT_TEST_ALGO, in_addr, false, false, + prefix, vrf, sndid, rcvid, 0, keyflags, + strlen(key), key); +} + +extern int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, + void *addr, size_t addr_sz, + uint8_t prefix, uint8_t sndid, uint8_t rcvid); +extern int test_get_ao_info(int sk, struct tcp_ao_info_opt *out); +extern int test_set_ao_info(int sk, struct tcp_ao_info_opt *in); +extern int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, + const struct tcp_ao_getsockopt *b); +extern int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, + const struct tcp_ao_info_opt *b); + +static inline int test_verify_socket_key(int sk, struct tcp_ao_add *key) +{ + struct tcp_ao_getsockopt key2 = {}; + int err; + + err = test_get_one_ao(sk, &key2, &key->addr, sizeof(key->addr), + key->prefix, key->sndid, key->rcvid); + if (err) + return err; + + return test_cmp_getsockopt_setsockopt(key, &key2); +} + +static inline int test_add_key_vrf(int sk, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + uint8_t vrf, uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix, + vrf, sndid, rcvid); + if (err) + return err; + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); + if (err < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +static inline int test_add_key(int sk, const char *key, + union tcp_addr in_addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + return test_add_key_vrf(sk, key, 0, in_addr, prefix, 0, sndid, rcvid); +} + +static inline int test_verify_socket_ao(int sk, struct tcp_ao_info_opt *ao) +{ + struct tcp_ao_info_opt ao2 = {}; + int err; + + err = test_get_ao_info(sk, &ao2); + if (err) + return err; + + return test_cmp_getsockopt_setsockopt_ao(ao, &ao2); +} + +static inline int test_set_ao_flags(int sk, bool ao_required, bool accept_icmps) +{ + struct tcp_ao_info_opt ao = {}; + int err; + + err = test_get_ao_info(sk, &ao); + /* Maybe ao_info wasn't allocated yet */ + if (err && err != -ENOENT) + return err; + + ao.ao_required = !!ao_required; + ao.accept_icmps = !!accept_icmps; + err = test_set_ao_info(sk, &ao); + if (err) + return err; + + return test_verify_socket_ao(sk, &ao); +} + +extern ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec); +extern int test_client_verify(int sk, const size_t msg_len, const size_t nr); + +struct tcp_ao_key_counters { + uint8_t sndid; + uint8_t rcvid; + uint64_t pkt_good; + uint64_t pkt_bad; +}; + +struct tcp_ao_counters { + /* per-netns */ + uint64_t netns_ao_good; + uint64_t netns_ao_bad; + uint64_t netns_ao_key_not_found; + uint64_t netns_ao_required; + uint64_t netns_ao_dropped_icmp; + /* per-socket */ + uint64_t ao_info_pkt_good; + uint64_t ao_info_pkt_bad; + uint64_t ao_info_pkt_key_not_found; + uint64_t ao_info_pkt_ao_required; + uint64_t ao_info_pkt_dropped_icmp; + /* per-key */ + size_t nr_keys; + struct tcp_ao_key_counters *key_cnts; +}; + +struct tcp_counters { + struct tcp_ao_counters ao; + uint64_t netns_md5_notfound; + uint64_t netns_md5_unexpected; + uint64_t netns_md5_failure; +}; + +extern int test_get_tcp_counters(int sk, struct tcp_counters *out); + +#define TEST_CNT_KEY_GOOD BIT(0) +#define TEST_CNT_KEY_BAD BIT(1) +#define TEST_CNT_SOCK_GOOD BIT(2) +#define TEST_CNT_SOCK_BAD BIT(3) +#define TEST_CNT_SOCK_KEY_NOT_FOUND BIT(4) +#define TEST_CNT_SOCK_AO_REQUIRED BIT(5) +#define TEST_CNT_SOCK_DROPPED_ICMP BIT(6) +#define TEST_CNT_NS_GOOD BIT(7) +#define TEST_CNT_NS_BAD BIT(8) +#define TEST_CNT_NS_KEY_NOT_FOUND BIT(9) +#define TEST_CNT_NS_AO_REQUIRED BIT(10) +#define TEST_CNT_NS_DROPPED_ICMP BIT(11) +#define TEST_CNT_NS_MD5_NOT_FOUND BIT(12) +#define TEST_CNT_NS_MD5_UNEXPECTED BIT(13) +#define TEST_CNT_NS_MD5_FAILURE BIT(14) +typedef uint16_t test_cnt; + +#define _for_each_counter(f) \ +do { \ + /* per-netns */ \ + f(ao.netns_ao_good, TEST_CNT_NS_GOOD); \ + f(ao.netns_ao_bad, TEST_CNT_NS_BAD); \ + f(ao.netns_ao_key_not_found, TEST_CNT_NS_KEY_NOT_FOUND); \ + f(ao.netns_ao_required, TEST_CNT_NS_AO_REQUIRED); \ + f(ao.netns_ao_dropped_icmp, TEST_CNT_NS_DROPPED_ICMP); \ + /* per-socket */ \ + f(ao.ao_info_pkt_good, TEST_CNT_SOCK_GOOD); \ + f(ao.ao_info_pkt_bad, TEST_CNT_SOCK_BAD); \ + f(ao.ao_info_pkt_key_not_found, TEST_CNT_SOCK_KEY_NOT_FOUND); \ + f(ao.ao_info_pkt_ao_required, TEST_CNT_SOCK_AO_REQUIRED); \ + f(ao.ao_info_pkt_dropped_icmp, TEST_CNT_SOCK_DROPPED_ICMP); \ + /* non-AO */ \ + f(netns_md5_notfound, TEST_CNT_NS_MD5_NOT_FOUND); \ + f(netns_md5_unexpected, TEST_CNT_NS_MD5_UNEXPECTED); \ + f(netns_md5_failure, TEST_CNT_NS_MD5_FAILURE); \ +} while (0) + +#define TEST_CNT_AO_GOOD (TEST_CNT_SOCK_GOOD | TEST_CNT_NS_GOOD) +#define TEST_CNT_AO_BAD (TEST_CNT_SOCK_BAD | TEST_CNT_NS_BAD) +#define TEST_CNT_AO_KEY_NOT_FOUND (TEST_CNT_SOCK_KEY_NOT_FOUND | \ + TEST_CNT_NS_KEY_NOT_FOUND) +#define TEST_CNT_AO_REQUIRED (TEST_CNT_SOCK_AO_REQUIRED | \ + TEST_CNT_NS_AO_REQUIRED) +#define TEST_CNT_AO_DROPPED_ICMP (TEST_CNT_SOCK_DROPPED_ICMP | \ + TEST_CNT_NS_DROPPED_ICMP) +#define TEST_CNT_GOOD (TEST_CNT_KEY_GOOD | TEST_CNT_AO_GOOD) +#define TEST_CNT_BAD (TEST_CNT_KEY_BAD | TEST_CNT_AO_BAD) + +extern test_cnt test_cmp_counters(struct tcp_counters *before, + struct tcp_counters *after); +extern int test_assert_counters_sk(const char *tst_name, + struct tcp_counters *before, struct tcp_counters *after, + test_cnt expected); +extern int test_assert_counters_key(const char *tst_name, + struct tcp_ao_counters *before, struct tcp_ao_counters *after, + test_cnt expected, int sndid, int rcvid); +extern void test_tcp_counters_free(struct tcp_counters *cnts); + +/* + * Polling for netns and socket counters during select()/connect() and also + * client/server messaging. Instead of constant timeout on underlying select(), + * check the counters and return early. This allows to pass the tests where + * timeout is expected without waiting for that fixing timeout (tests speed-up). + * Previously shorter timeouts were used for tests expecting to time out, + * but that leaded to sporadic false positives on counter checks failures, + * as one second timeouts aren't enough for TCP retransmit. + * + * Two sides of the socketpair (client/server) should synchronize failures + * using a shared variable *err, so that they can detect the other side's + * failure. + */ +extern int test_skpair_wait_poll(int sk, bool write, test_cnt cond, + volatile int *err); +extern int _test_skpair_connect_poll(int sk, const char *device, + void *addr, size_t addr_sz, + test_cnt cond, volatile int *err); +static inline int test_skpair_connect_poll(int sk, const union tcp_addr taddr, + unsigned int port, + test_cnt cond, volatile int *err) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return _test_skpair_connect_poll(sk, veth_name, + (void *)&addr, sizeof(addr), cond, err); +} + +extern int test_skpair_client(int sk, const size_t msg_len, const size_t nr, + test_cnt cond, volatile int *err); +extern int test_skpair_server(int sk, ssize_t quota, + test_cnt cond, volatile int *err); + +/* + * Frees buffers allocated in test_get_tcp_counters(). + * The function doesn't expect new keys or keys removed between calls + * to test_get_tcp_counters(). Check key counters manually if they + * may change. + */ +static inline int test_assert_counters(const char *tst_name, + struct tcp_counters *before, + struct tcp_counters *after, + test_cnt expected) +{ + int ret; + + ret = test_assert_counters_sk(tst_name, before, after, expected); + if (ret) + goto out; + ret = test_assert_counters_key(tst_name, &before->ao, &after->ao, + expected, -1, -1); +out: + test_tcp_counters_free(before); + test_tcp_counters_free(after); + return ret; +} + +struct netstat; +extern struct netstat *netstat_read(void); +extern void netstat_free(struct netstat *ns); +extern void netstat_print_diff(struct netstat *nsa, struct netstat *nsb); +extern uint64_t netstat_get(struct netstat *ns, + const char *name, bool *not_found); + +static inline uint64_t netstat_get_one(const char *name, bool *not_found) +{ + struct netstat *ns = netstat_read(); + uint64_t ret; + + ret = netstat_get(ns, name, not_found); + + netstat_free(ns); + return ret; +} + +struct tcp_sock_queue { + uint32_t seq; + void *buf; +}; + +struct tcp_sock_state { + struct tcp_info info; + struct tcp_repair_window trw; + struct tcp_sock_queue out; + int outq_len; /* output queue size (not sent + not acked) */ + int outq_nsd_len; /* output queue size (not sent only) */ + struct tcp_sock_queue in; + int inq_len; + int mss; + int timestamp; +}; + +extern void __test_sock_checkpoint(int sk, struct tcp_sock_state *state, + void *addr, size_t addr_size); +static inline void test_sock_checkpoint(int sk, struct tcp_sock_state *state, + sockaddr_af *saddr) +{ + __test_sock_checkpoint(sk, state, saddr, sizeof(*saddr)); +} +extern void test_ao_checkpoint(int sk, struct tcp_ao_repair *state); +extern void __test_sock_restore(int sk, const char *device, + struct tcp_sock_state *state, + void *saddr, void *daddr, size_t addr_size); +static inline void test_sock_restore(int sk, struct tcp_sock_state *state, + sockaddr_af *saddr, + const union tcp_addr daddr, + unsigned int dport) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &daddr, htons(dport)); + __test_sock_restore(sk, veth_name, state, saddr, &addr, sizeof(addr)); +} +extern void test_ao_restore(int sk, struct tcp_ao_repair *state); +extern void test_sock_state_free(struct tcp_sock_state *state); +extern void test_enable_repair(int sk); +extern void test_disable_repair(int sk); +extern void test_kill_sk(int sk); +static inline int test_add_repaired_key(int sk, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix, + 0, sndid, rcvid); + if (err) + return err; + + tmp.set_current = 1; + tmp.set_rnext = 1; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +#define DEFAULT_FTRACE_BUFFER_KB 10000 +#define DEFAULT_TRACER_LINES_ARR 200 +struct test_ftracer; +extern uint64_t ns_cookie1, ns_cookie2; + +enum ftracer_op { + FTRACER_LINE_DISCARD = 0, + FTRACER_LINE_PRESERVE, + FTRACER_EXIT, +}; + +extern struct test_ftracer *create_ftracer(const char *name, + enum ftracer_op (*process_line)(const char *line), + void (*destructor)(struct test_ftracer *tracer), + bool (*expecting_more)(void), + size_t lines_buf_sz, size_t buffer_size_kb); +extern int setup_trace_event(struct test_ftracer *tracer, + const char *event, const char *filter); +extern void destroy_ftracer(struct test_ftracer *tracer); +extern const size_t tracer_get_savedlines_nr(struct test_ftracer *tracer); +extern const char **tracer_get_savedlines(struct test_ftracer *tracer); + +enum trace_events { + /* TCP_HASH_EVENT */ + TCP_HASH_BAD_HEADER = 0, + TCP_HASH_MD5_REQUIRED, + TCP_HASH_MD5_UNEXPECTED, + TCP_HASH_MD5_MISMATCH, + TCP_HASH_AO_REQUIRED, + /* TCP_AO_EVENT */ + TCP_AO_HANDSHAKE_FAILURE, + TCP_AO_WRONG_MACLEN, + TCP_AO_MISMATCH, + TCP_AO_KEY_NOT_FOUND, + TCP_AO_RNEXT_REQUEST, + /* TCP_AO_EVENT_SK */ + TCP_AO_SYNACK_NO_KEY, + /* TCP_AO_EVENT_SNE */ + TCP_AO_SND_SNE_UPDATE, + TCP_AO_RCV_SNE_UPDATE, + __MAX_TRACE_EVENTS +}; + +extern int __trace_event_expect(enum trace_events type, int family, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen, int sne); + +static inline void trace_hash_event_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, L3index, + fin, syn, rst, psh, ack, + -1, -1, -1, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, L3index, + fin, syn, rst, psh, ack, + keyid, rnext, maclen, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_sk_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, + int keyid, int rnext) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, -1, + -1, -1, -1, -1, -1, + keyid, rnext, -1, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_sne_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int sne) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, -1, + -1, -1, -1, -1, -1, + -1, -1, -1, sne); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +extern int setup_aolib_ftracer(void); + +#endif /* _AOLIB_H_ */ diff --git a/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c b/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c new file mode 100644 index 000000000000..27403f875054 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include "aolib.h" + +static const char *trace_event_names[__MAX_TRACE_EVENTS] = { + /* TCP_HASH_EVENT */ + "tcp_hash_bad_header", + "tcp_hash_md5_required", + "tcp_hash_md5_unexpected", + "tcp_hash_md5_mismatch", + "tcp_hash_ao_required", + /* TCP_AO_EVENT */ + "tcp_ao_handshake_failure", + "tcp_ao_wrong_maclen", + "tcp_ao_mismatch", + "tcp_ao_key_not_found", + "tcp_ao_rnext_request", + /* TCP_AO_EVENT_SK */ + "tcp_ao_synack_no_key", + /* TCP_AO_EVENT_SNE */ + "tcp_ao_snd_sne_update", + "tcp_ao_rcv_sne_update" +}; + +struct expected_trace_point { + /* required */ + enum trace_events type; + int family; + union tcp_addr src; + union tcp_addr dst; + + /* optional */ + int src_port; + int dst_port; + int L3index; + + int fin; + int syn; + int rst; + int psh; + int ack; + + int keyid; + int rnext; + int maclen; + int sne; + + size_t matched; +}; + +static struct expected_trace_point *exp_tps; +static size_t exp_tps_nr; +static size_t exp_tps_size; +static pthread_mutex_t exp_tps_mutex = PTHREAD_MUTEX_INITIALIZER; + +int __trace_event_expect(enum trace_events type, int family, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen, int sne) +{ + struct expected_trace_point new_tp = { + .type = type, + .family = family, + .src = src, + .dst = dst, + .src_port = src_port, + .dst_port = dst_port, + .L3index = L3index, + .fin = fin, + .syn = syn, + .rst = rst, + .psh = psh, + .ack = ack, + .keyid = keyid, + .rnext = rnext, + .maclen = maclen, + .sne = sne, + .matched = 0, + }; + int ret = 0; + + if (!kernel_config_has(KCONFIG_FTRACE)) + return 0; + + pthread_mutex_lock(&exp_tps_mutex); + if (exp_tps_nr == exp_tps_size) { + struct expected_trace_point *tmp; + + if (exp_tps_size == 0) + exp_tps_size = 10; + else + exp_tps_size = exp_tps_size * 1.6; + + tmp = reallocarray(exp_tps, exp_tps_size, sizeof(exp_tps[0])); + if (!tmp) { + ret = -ENOMEM; + goto out; + } + exp_tps = tmp; + } + exp_tps[exp_tps_nr] = new_tp; + exp_tps_nr++; +out: + pthread_mutex_unlock(&exp_tps_mutex); + return ret; +} + +static void free_expected_events(void) +{ + /* We're from the process destructor - not taking the mutex */ + exp_tps_size = 0; + exp_tps = NULL; + free(exp_tps); +} + +struct trace_point { + int family; + union tcp_addr src; + union tcp_addr dst; + unsigned int src_port; + unsigned int dst_port; + int L3index; + unsigned int fin:1, + syn:1, + rst:1, + psh:1, + ack:1; + + unsigned int keyid; + unsigned int rnext; + unsigned int maclen; + + unsigned int sne; +}; + +static bool lookup_expected_event(int event_type, struct trace_point *e) +{ + size_t i; + + pthread_mutex_lock(&exp_tps_mutex); + for (i = 0; i < exp_tps_nr; i++) { + struct expected_trace_point *p = &exp_tps[i]; + size_t sk_size; + + if (p->type != event_type) + continue; + if (p->family != e->family) + continue; + if (p->family == AF_INET) + sk_size = sizeof(p->src.a4); + else + sk_size = sizeof(p->src.a6); + if (memcmp(&p->src, &e->src, sk_size)) + continue; + if (memcmp(&p->dst, &e->dst, sk_size)) + continue; + if (p->src_port >= 0 && p->src_port != e->src_port) + continue; + if (p->dst_port >= 0 && p->dst_port != e->dst_port) + continue; + if (p->L3index >= 0 && p->L3index != e->L3index) + continue; + + if (p->fin >= 0 && p->fin != e->fin) + continue; + if (p->syn >= 0 && p->syn != e->syn) + continue; + if (p->rst >= 0 && p->rst != e->rst) + continue; + if (p->psh >= 0 && p->psh != e->psh) + continue; + if (p->ack >= 0 && p->ack != e->ack) + continue; + + if (p->keyid >= 0 && p->keyid != e->keyid) + continue; + if (p->rnext >= 0 && p->rnext != e->rnext) + continue; + if (p->maclen >= 0 && p->maclen != e->maclen) + continue; + if (p->sne >= 0 && p->sne != e->sne) + continue; + p->matched++; + pthread_mutex_unlock(&exp_tps_mutex); + return true; + } + pthread_mutex_unlock(&exp_tps_mutex); + return false; +} + +static int check_event_type(const char *line) +{ + size_t i; + + /* + * This should have been a set or hashmap, but it's a selftest, + * so... KISS. + */ + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + if (!strncmp(trace_event_names[i], line, strlen(trace_event_names[i]))) + return i; + } + return -1; +} + +static bool event_has_flags(enum trace_events event) +{ + switch (event) { + case TCP_HASH_BAD_HEADER: + case TCP_HASH_MD5_REQUIRED: + case TCP_HASH_MD5_UNEXPECTED: + case TCP_HASH_MD5_MISMATCH: + case TCP_HASH_AO_REQUIRED: + case TCP_AO_HANDSHAKE_FAILURE: + case TCP_AO_WRONG_MACLEN: + case TCP_AO_MISMATCH: + case TCP_AO_KEY_NOT_FOUND: + case TCP_AO_RNEXT_REQUEST: + return true; + default: + return false; + } +} + +static int tracer_ip_split(int family, char *src, char **addr, char **port) +{ + char *p; + + if (family == AF_INET) { + /* fomat is <addr>:port, i.e.: 10.0.254.1:7015 */ + *addr = src; + p = strchr(src, ':'); + if (!p) { + test_print("Couldn't parse trace event addr:port %s", src); + return -EINVAL; + } + *p++ = '\0'; + *port = p; + return 0; + } + if (family != AF_INET6) + return -EAFNOSUPPORT; + + /* format is [<addr>]:port, i.e.: [2001:db8:254::1]:7013 */ + *addr = strchr(src, '['); + p = strchr(src, ']'); + + if (!p || !*addr) { + test_print("Couldn't parse trace event [addr]:port %s", src); + return -EINVAL; + } + + *addr = *addr + 1; /* '[' */ + *p++ = '\0'; /* ']' */ + if (*p != ':') { + test_print("Couldn't parse trace event :port %s", p); + return -EINVAL; + } + *p++ = '\0'; /* ':' */ + *port = p; + return 0; +} + +static int tracer_scan_address(int family, char *src, + union tcp_addr *dst, unsigned int *port) +{ + char *addr, *port_str; + int ret; + + ret = tracer_ip_split(family, src, &addr, &port_str); + if (ret) + return ret; + + if (inet_pton(family, addr, dst) != 1) { + test_print("Couldn't parse trace event addr %s", addr); + return -EINVAL; + } + errno = 0; + *port = (unsigned int)strtoul(port_str, NULL, 10); + if (errno != 0) { + test_print("Couldn't parse trace event port %s", port_str); + return -errno; + } + return 0; +} + +static int tracer_scan_event(const char *line, enum trace_events event, + struct trace_point *out) +{ + char *src = NULL, *dst = NULL, *family = NULL; + char fin, syn, rst, psh, ack; + int nr_matched, ret = 0; + uint64_t netns_cookie; + + switch (event) { + case TCP_HASH_BAD_HEADER: + case TCP_HASH_MD5_REQUIRED: + case TCP_HASH_MD5_UNEXPECTED: + case TCP_HASH_MD5_MISMATCH: + case TCP_HASH_AO_REQUIRED: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms L3index=%d [%c%c%c%c%c]", + &netns_cookie, &family, + &src, &dst, &out->L3index, + &fin, &syn, &rst, &psh, &ack); + if (nr_matched != 10) + test_print("Couldn't parse trace event, matched = %d/10", + nr_matched); + break; + } + case TCP_AO_HANDSHAKE_FAILURE: + case TCP_AO_WRONG_MACLEN: + case TCP_AO_MISMATCH: + case TCP_AO_KEY_NOT_FOUND: + case TCP_AO_RNEXT_REQUEST: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms L3index=%d [%c%c%c%c%c] keyid=%u rnext=%u maclen=%u", + &netns_cookie, &family, + &src, &dst, &out->L3index, + &fin, &syn, &rst, &psh, &ack, + &out->keyid, &out->rnext, &out->maclen); + if (nr_matched != 13) + test_print("Couldn't parse trace event, matched = %d/13", + nr_matched); + break; + } + case TCP_AO_SYNACK_NO_KEY: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms keyid=%u rnext=%u", + &netns_cookie, &family, + &src, &dst, &out->keyid, &out->rnext); + if (nr_matched != 6) + test_print("Couldn't parse trace event, matched = %d/6", + nr_matched); + break; + } + case TCP_AO_SND_SNE_UPDATE: + case TCP_AO_RCV_SNE_UPDATE: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms sne=%u", + &netns_cookie, &family, + &src, &dst, &out->sne); + if (nr_matched != 5) + test_print("Couldn't parse trace event, matched = %d/5", + nr_matched); + break; + } + default: + return -1; + } + + if (family) { + if (!strcmp(family, "AF_INET")) { + out->family = AF_INET; + } else if (!strcmp(family, "AF_INET6")) { + out->family = AF_INET6; + } else { + test_print("Couldn't parse trace event family %s", family); + ret = -EINVAL; + goto out_free; + } + } + + if (event_has_flags(event)) { + out->fin = (fin == 'F'); + out->syn = (syn == 'S'); + out->rst = (rst == 'R'); + out->psh = (psh == 'P'); + out->ack = (ack == '.'); + + if ((fin != 'F' && fin != ' ') || + (syn != 'S' && syn != ' ') || + (rst != 'R' && rst != ' ') || + (psh != 'P' && psh != ' ') || + (ack != '.' && ack != ' ')) { + test_print("Couldn't parse trace event flags %c%c%c%c%c", + fin, syn, rst, psh, ack); + ret = -EINVAL; + goto out_free; + } + } + + if (src && tracer_scan_address(out->family, src, &out->src, &out->src_port)) { + ret = -EINVAL; + goto out_free; + } + + if (dst && tracer_scan_address(out->family, dst, &out->dst, &out->dst_port)) { + ret = -EINVAL; + goto out_free; + } + + if (netns_cookie != ns_cookie1 && netns_cookie != ns_cookie2) { + test_print("Net namespace filter for trace event didn't work: %" PRIu64 " != %" PRIu64 " OR %" PRIu64, + netns_cookie, ns_cookie1, ns_cookie2); + ret = -EINVAL; + } + +out_free: + free(src); + free(dst); + free(family); + return ret; +} + +static enum ftracer_op aolib_tracer_process_event(const char *line) +{ + int event_type = check_event_type(line); + struct trace_point tmp = {}; + + if (event_type < 0) + return FTRACER_LINE_PRESERVE; + + if (tracer_scan_event(line, event_type, &tmp)) + return FTRACER_LINE_PRESERVE; + + return lookup_expected_event(event_type, &tmp) ? + FTRACER_LINE_DISCARD : FTRACER_LINE_PRESERVE; +} + +static void dump_trace_event(struct expected_trace_point *e) +{ + char src[INET6_ADDRSTRLEN], dst[INET6_ADDRSTRLEN]; + + if (!inet_ntop(e->family, &e->src, src, INET6_ADDRSTRLEN)) + test_error("inet_ntop()"); + if (!inet_ntop(e->family, &e->dst, dst, INET6_ADDRSTRLEN)) + test_error("inet_ntop()"); + test_print("trace event filter %s [%s:%d => %s:%d, L3index %d, flags: %s%s%s%s%s, keyid: %d, rnext: %d, maclen: %d, sne: %d] = %zu", + trace_event_names[e->type], + src, e->src_port, dst, e->dst_port, e->L3index, + e->fin ? "F" : "", e->syn ? "S" : "", e->rst ? "R" : "", + e->psh ? "P" : "", e->ack ? "." : "", + e->keyid, e->rnext, e->maclen, e->sne, e->matched); +} + +static void print_match_stats(bool unexpected_events) +{ + size_t matches_per_type[__MAX_TRACE_EVENTS] = {}; + bool expected_but_none = false; + size_t i, total_matched = 0; + char *stat_line = NULL; + + for (i = 0; i < exp_tps_nr; i++) { + struct expected_trace_point *e = &exp_tps[i]; + + total_matched += e->matched; + matches_per_type[e->type] += e->matched; + if (!e->matched) + expected_but_none = true; + } + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + if (!matches_per_type[i]) + continue; + stat_line = test_sprintf("%s%s[%zu] ", stat_line ?: "", + trace_event_names[i], + matches_per_type[i]); + if (!stat_line) + test_error("test_sprintf()"); + } + + if (unexpected_events || expected_but_none) { + for (i = 0; i < exp_tps_nr; i++) + dump_trace_event(&exp_tps[i]); + } + + if (unexpected_events) + return; + + if (expected_but_none) + test_fail("Some trace events were expected, but didn't occur"); + else if (total_matched) + test_ok("Trace events matched expectations: %zu %s", + total_matched, stat_line); + else + test_ok("No unexpected trace events during the test run"); +} + +#define dump_events(fmt, ...) \ + __test_print(__test_msg, fmt, ##__VA_ARGS__) +static void check_free_events(struct test_ftracer *tracer) +{ + const char **lines; + size_t nr; + + if (!kernel_config_has(KCONFIG_FTRACE)) { + test_skip("kernel config doesn't have ftrace - no checks"); + return; + } + + nr = tracer_get_savedlines_nr(tracer); + lines = tracer_get_savedlines(tracer); + print_match_stats(!!nr); + if (!nr) + return; + + errno = 0; + test_xfail("Trace events [%zu] were not expected:", nr); + while (nr) + dump_events("\t%s", lines[--nr]); +} + +static int setup_tcp_trace_events(struct test_ftracer *tracer) +{ + char *filter; + size_t i; + int ret; + + filter = test_sprintf("net_cookie == %zu || net_cookie == %zu", + ns_cookie1, ns_cookie2); + if (!filter) + return -ENOMEM; + + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + char *event_name = test_sprintf("tcp/%s", trace_event_names[i]); + + if (!event_name) { + ret = -ENOMEM; + break; + } + ret = setup_trace_event(tracer, event_name, filter); + free(event_name); + if (ret) + break; + } + + free(filter); + return ret; +} + +static void aolib_tracer_destroy(struct test_ftracer *tracer) +{ + check_free_events(tracer); + free_expected_events(); +} + +static bool aolib_tracer_expecting_more(void) +{ + size_t i; + + for (i = 0; i < exp_tps_nr; i++) + if (!exp_tps[i].matched) + return true; + return false; +} + +int setup_aolib_ftracer(void) +{ + struct test_ftracer *f; + + f = create_ftracer("aolib", aolib_tracer_process_event, + aolib_tracer_destroy, aolib_tracer_expecting_more, + DEFAULT_FTRACE_BUFFER_KB, DEFAULT_TRACER_LINES_ARR); + if (!f) + return -1; + + return setup_tcp_trace_events(f); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/ftrace.c b/tools/testing/selftests/net/tcp_ao/lib/ftrace.c new file mode 100644 index 000000000000..e4d0b173bc94 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/ftrace.c @@ -0,0 +1,543 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <sys/time.h> +#include <unistd.h> +#include "../../../../../include/linux/kernel.h" +#include "aolib.h" + +static char ftrace_path[] = "ksft-ftrace-XXXXXX"; +static bool ftrace_mounted; +uint64_t ns_cookie1, ns_cookie2; + +struct test_ftracer { + pthread_t tracer_thread; + int error; + char *instance_path; + FILE *trace_pipe; + + enum ftracer_op (*process_line)(const char *line); + void (*destructor)(struct test_ftracer *tracer); + bool (*expecting_more)(void); + + char **saved_lines; + size_t saved_lines_size; + size_t next_line_ind; + + pthread_cond_t met_all_expected; + pthread_mutex_t met_all_expected_lock; + + struct test_ftracer *next; +}; + +static struct test_ftracer *ftracers; +static pthread_mutex_t ftracers_lock = PTHREAD_MUTEX_INITIALIZER; + +static int mount_ftrace(void) +{ + if (!mkdtemp(ftrace_path)) + test_error("Can't create temp dir"); + + if (mount("tracefs", ftrace_path, "tracefs", 0, "rw")) + return -errno; + + ftrace_mounted = true; + + return 0; +} + +static void unmount_ftrace(void) +{ + if (ftrace_mounted && umount(ftrace_path)) + test_print("Failed on cleanup: can't unmount tracefs: %m"); + + if (rmdir(ftrace_path)) + test_error("Failed on cleanup: can't remove ftrace dir %s", + ftrace_path); +} + +struct opts_list_t { + char *opt_name; + struct opts_list_t *next; +}; + +static int disable_trace_options(const char *ftrace_path) +{ + struct opts_list_t *opts_list = NULL; + char *fopts, *line = NULL; + size_t buf_len = 0; + ssize_t line_len; + int ret = 0; + FILE *opts; + + fopts = test_sprintf("%s/%s", ftrace_path, "trace_options"); + if (!fopts) + return -ENOMEM; + + opts = fopen(fopts, "r+"); + if (!opts) { + ret = -errno; + goto out_free; + } + + while ((line_len = getline(&line, &buf_len, opts)) != -1) { + struct opts_list_t *tmp; + + if (!strncmp(line, "no", 2)) + continue; + + tmp = malloc(sizeof(*tmp)); + if (!tmp) { + ret = -ENOMEM; + goto out_free_opts_list; + } + tmp->next = opts_list; + tmp->opt_name = test_sprintf("no%s", line); + if (!tmp->opt_name) { + ret = -ENOMEM; + free(tmp); + goto out_free_opts_list; + } + opts_list = tmp; + } + + while (opts_list) { + struct opts_list_t *tmp = opts_list; + + fseek(opts, 0, SEEK_SET); + fwrite(tmp->opt_name, 1, strlen(tmp->opt_name), opts); + + opts_list = opts_list->next; + free(tmp->opt_name); + free(tmp); + } + +out_free_opts_list: + while (opts_list) { + struct opts_list_t *tmp = opts_list; + + opts_list = opts_list->next; + free(tmp->opt_name); + free(tmp); + } + free(line); + fclose(opts); +out_free: + free(fopts); + return ret; +} + +static int setup_buffer_size(const char *ftrace_path, size_t sz) +{ + char *fbuf_size = test_sprintf("%s/buffer_size_kb", ftrace_path); + int ret; + + if (!fbuf_size) + return -1; + + ret = test_echo(fbuf_size, 0, "%zu", sz); + free(fbuf_size); + return ret; +} + +static int setup_ftrace_instance(struct test_ftracer *tracer, const char *name) +{ + char *tmp; + + tmp = test_sprintf("%s/instances/ksft-%s-XXXXXX", ftrace_path, name); + if (!tmp) + return -ENOMEM; + + tracer->instance_path = mkdtemp(tmp); + if (!tracer->instance_path) { + free(tmp); + return -errno; + } + + return 0; +} + +static void remove_ftrace_instance(struct test_ftracer *tracer) +{ + if (rmdir(tracer->instance_path)) + test_print("Failed on cleanup: can't remove ftrace instance %s", + tracer->instance_path); + free(tracer->instance_path); +} + +static void tracer_cleanup(void *arg) +{ + struct test_ftracer *tracer = arg; + + fclose(tracer->trace_pipe); +} + +static void tracer_set_error(struct test_ftracer *tracer, int error) +{ + if (!tracer->error) + tracer->error = error; +} + +const size_t tracer_get_savedlines_nr(struct test_ftracer *tracer) +{ + return tracer->next_line_ind; +} + +const char **tracer_get_savedlines(struct test_ftracer *tracer) +{ + return (const char **)tracer->saved_lines; +} + +static void *tracer_thread_func(void *arg) +{ + struct test_ftracer *tracer = arg; + + pthread_cleanup_push(tracer_cleanup, arg); + + while (tracer->next_line_ind < tracer->saved_lines_size) { + char **lp = &tracer->saved_lines[tracer->next_line_ind]; + enum ftracer_op op; + size_t buf_len = 0; + ssize_t line_len; + + line_len = getline(lp, &buf_len, tracer->trace_pipe); + if (line_len == -1) + break; + + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, NULL); + op = tracer->process_line(*lp); + pthread_setcancelstate(PTHREAD_CANCEL_ENABLE, NULL); + + if (tracer->expecting_more) { + pthread_mutex_lock(&tracer->met_all_expected_lock); + if (!tracer->expecting_more()) + pthread_cond_signal(&tracer->met_all_expected); + pthread_mutex_unlock(&tracer->met_all_expected_lock); + } + + if (op == FTRACER_LINE_DISCARD) + continue; + if (op == FTRACER_EXIT) + break; + if (op != FTRACER_LINE_PRESERVE) + test_error("unexpected tracer command %d", op); + + tracer->next_line_ind++; + buf_len = 0; + } + test_print("too many lines in ftracer buffer %zu, exiting tracer", + tracer->next_line_ind); + + pthread_cleanup_pop(1); + return NULL; +} + +static int setup_trace_thread(struct test_ftracer *tracer) +{ + int ret = 0; + char *path; + + path = test_sprintf("%s/trace_pipe", tracer->instance_path); + if (!path) + return -ENOMEM; + + tracer->trace_pipe = fopen(path, "r"); + if (!tracer->trace_pipe) { + ret = -errno; + goto out_free; + } + + if (pthread_create(&tracer->tracer_thread, NULL, + tracer_thread_func, (void *)tracer)) { + ret = -errno; + fclose(tracer->trace_pipe); + } + +out_free: + free(path); + return ret; +} + +static void stop_trace_thread(struct test_ftracer *tracer) +{ + void *res; + + if (pthread_cancel(tracer->tracer_thread)) { + test_print("Can't stop tracer pthread: %m"); + tracer_set_error(tracer, -errno); + } + if (pthread_join(tracer->tracer_thread, &res)) { + test_print("Can't join tracer pthread: %m"); + tracer_set_error(tracer, -errno); + } + if (res != PTHREAD_CANCELED) { + test_print("Tracer thread wasn't canceled"); + tracer_set_error(tracer, -errno); + } + if (tracer->error) + test_fail("tracer errored by %s", strerror(tracer->error)); +} + +static void final_wait_for_events(struct test_ftracer *tracer, + unsigned timeout_sec) +{ + struct timespec timeout; + struct timeval now; + int ret = 0; + + if (!tracer->expecting_more) + return; + + pthread_mutex_lock(&tracer->met_all_expected_lock); + gettimeofday(&now, NULL); + timeout.tv_sec = now.tv_sec + timeout_sec; + timeout.tv_nsec = now.tv_usec * 1000; + + while (tracer->expecting_more() && ret != ETIMEDOUT) + ret = pthread_cond_timedwait(&tracer->met_all_expected, + &tracer->met_all_expected_lock, &timeout); + pthread_mutex_unlock(&tracer->met_all_expected_lock); +} + +int setup_trace_event(struct test_ftracer *tracer, + const char *event, const char *filter) +{ + char *enable_path, *filter_path, *instance = tracer->instance_path; + int ret; + + enable_path = test_sprintf("%s/events/%s/enable", instance, event); + if (!enable_path) + return -ENOMEM; + + filter_path = test_sprintf("%s/events/%s/filter", instance, event); + if (!filter_path) { + ret = -ENOMEM; + goto out_free; + } + + ret = test_echo(filter_path, 0, "%s", filter); + if (!ret) + ret = test_echo(enable_path, 0, "1"); + +out_free: + free(filter_path); + free(enable_path); + return ret; +} + +struct test_ftracer *create_ftracer(const char *name, + enum ftracer_op (*process_line)(const char *line), + void (*destructor)(struct test_ftracer *tracer), + bool (*expecting_more)(void), + size_t lines_buf_sz, size_t buffer_size_kb) +{ + struct test_ftracer *tracer; + int err; + + /* XXX: separate __create_ftracer() helper and do here + * if (!kernel_config_has(KCONFIG_FTRACE)) + * return NULL; + */ + + tracer = malloc(sizeof(*tracer)); + if (!tracer) { + test_print("malloc()"); + return NULL; + } + + memset(tracer, 0, sizeof(*tracer)); + + err = setup_ftrace_instance(tracer, name); + if (err) { + test_print("setup_ftrace_instance(): %d", err); + goto err_free; + } + + err = disable_trace_options(tracer->instance_path); + if (err) { + test_print("disable_trace_options(): %d", err); + goto err_remove; + } + + err = setup_buffer_size(tracer->instance_path, buffer_size_kb); + if (err) { + test_print("disable_trace_options(): %d", err); + goto err_remove; + } + + tracer->saved_lines = calloc(lines_buf_sz, sizeof(tracer->saved_lines[0])); + if (!tracer->saved_lines) { + test_print("calloc()"); + goto err_remove; + } + tracer->saved_lines_size = lines_buf_sz; + + tracer->process_line = process_line; + tracer->destructor = destructor; + tracer->expecting_more = expecting_more; + + err = pthread_cond_init(&tracer->met_all_expected, NULL); + if (err) { + test_print("pthread_cond_init(): %d", err); + goto err_free_lines; + } + + err = pthread_mutex_init(&tracer->met_all_expected_lock, NULL); + if (err) { + test_print("pthread_mutex_init(): %d", err); + goto err_cond_destroy; + } + + err = setup_trace_thread(tracer); + if (err) { + test_print("setup_trace_thread(): %d", err); + goto err_mutex_destroy; + } + + pthread_mutex_lock(&ftracers_lock); + tracer->next = ftracers; + ftracers = tracer; + pthread_mutex_unlock(&ftracers_lock); + + return tracer; + +err_mutex_destroy: + pthread_mutex_destroy(&tracer->met_all_expected_lock); +err_cond_destroy: + pthread_cond_destroy(&tracer->met_all_expected); +err_free_lines: + free(tracer->saved_lines); +err_remove: + remove_ftrace_instance(tracer); +err_free: + free(tracer); + return NULL; +} + +static void __destroy_ftracer(struct test_ftracer *tracer) +{ + size_t i; + + final_wait_for_events(tracer, TEST_TIMEOUT_SEC); + stop_trace_thread(tracer); + remove_ftrace_instance(tracer); + if (tracer->destructor) + tracer->destructor(tracer); + for (i = 0; i < tracer->saved_lines_size; i++) + free(tracer->saved_lines[i]); + pthread_cond_destroy(&tracer->met_all_expected); + pthread_mutex_destroy(&tracer->met_all_expected_lock); + free(tracer); +} + +void destroy_ftracer(struct test_ftracer *tracer) +{ + pthread_mutex_lock(&ftracers_lock); + if (tracer == ftracers) { + ftracers = tracer->next; + } else { + struct test_ftracer *f = ftracers; + + while (f->next != tracer) { + if (!f->next) + test_error("tracers list corruption or double free %p", tracer); + f = f->next; + } + f->next = tracer->next; + } + tracer->next = NULL; + pthread_mutex_unlock(&ftracers_lock); + __destroy_ftracer(tracer); +} + +static void destroy_all_ftracers(void) +{ + struct test_ftracer *f; + + pthread_mutex_lock(&ftracers_lock); + f = ftracers; + ftracers = NULL; + pthread_mutex_unlock(&ftracers_lock); + + while (f) { + struct test_ftracer *n = f->next; + + f->next = NULL; + __destroy_ftracer(f); + f = n; + } +} + +static void test_unset_tracing(void) +{ + destroy_all_ftracers(); + unmount_ftrace(); +} + +int test_setup_tracing(void) +{ + /* + * Just a basic protection - this should be called only once from + * lib/kconfig. Not thread safe, which is fine as it's early, before + * threads are created. + */ + static int already_set; + int err; + + if (already_set) + return -1; + + /* Needs net-namespace cookies for filters */ + if (ns_cookie1 == ns_cookie2) { + test_print("net-namespace cookies: %" PRIu64 " == %" PRIu64 ", can't set up tracing", + ns_cookie1, ns_cookie2); + return -1; + } + + already_set = 1; + + test_add_destructor(test_unset_tracing); + + err = mount_ftrace(); + if (err) { + test_print("failed to mount_ftrace(): %d", err); + return err; + } + + return setup_aolib_ftracer(); +} + +static int get_ns_cookie(int nsfd, uint64_t *out) +{ + int old_ns = switch_save_ns(nsfd); + socklen_t size = sizeof(*out); + int sk; + + sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + if (getsockopt(sk, SOL_SOCKET, SO_NETNS_COOKIE, out, &size)) { + test_print("getsockopt(SO_NETNS_COOKIE): %m"); + close(sk); + return -errno; + } + + close(sk); + switch_close_ns(old_ns); + return 0; +} + +void test_init_ftrace(int nsfd1, int nsfd2) +{ + get_ns_cookie(nsfd1, &ns_cookie1); + get_ns_cookie(nsfd2, &ns_cookie2); + /* Populate kernel config state */ + kernel_config_has(KCONFIG_FTRACE); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/kconfig.c b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c new file mode 100644 index 000000000000..9f1c175846f8 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Check what features does the kernel support (where the selftest is running). + * Somewhat inspired by CRIU kerndat/kdat kernel features detector. + */ +#include <pthread.h> +#include "aolib.h" + +struct kconfig_t { + int _error; /* negative errno if not supported */ + int (*check_kconfig)(int *error); +}; + +static int has_net_ns(int *err) +{ + if (access("/proc/self/ns/net", F_OK) < 0) { + *err = errno; + if (errno == ENOENT) + return 0; + test_print("Unable to access /proc/self/ns/net: %m"); + return -errno; + } + return *err = errno = 0; +} + +static int has_veth(int *err) +{ + int orig_netns, ns_a, ns_b; + + orig_netns = open_netns(); + ns_a = unshare_open_netns(); + ns_b = unshare_open_netns(); + + *err = add_veth("check_veth", ns_a, ns_b); + + switch_ns(orig_netns); + close(orig_netns); + close(ns_a); + close(ns_b); + return 0; +} + +static int has_tcp_ao(int *err) +{ + struct sockaddr_in addr = { + .sin_family = test_family, + }; + struct tcp_ao_add tmp = {}; + const char *password = DEFAULT_TEST_PASSWORD; + int sk, ret = 0; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + tmp.sndid = 100; + tmp.rcvid = 100; + tmp.keylen = strlen(password); + memcpy(tmp.key, password, strlen(password)); + strcpy(tmp.alg_name, "hmac(sha1)"); + memcpy(&tmp.addr, &addr, sizeof(addr)); + *err = 0; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) { + *err = -errno; + if (errno != ENOPROTOOPT) + ret = -errno; + } + close(sk); + return ret; +} + +static int has_tcp_md5(int *err) +{ + union tcp_addr addr_any = {}; + int sk, ret = 0; + + sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + /* + * Under CONFIG_CRYPTO_FIPS=y it fails with ENOMEM, rather with + * anything more descriptive. Oh well. + */ + *err = 0; + if (test_set_md5(sk, addr_any, 0, -1, DEFAULT_TEST_PASSWORD)) { + *err = -errno; + if (errno != ENOPROTOOPT && errno == ENOMEM) { + test_print("setsockopt(TCP_MD5SIG_EXT): %m"); + ret = -errno; + } + } + close(sk); + return ret; +} + +static int has_vrfs(int *err) +{ + int orig_netns, ns_test, ret = 0; + + orig_netns = open_netns(); + ns_test = unshare_open_netns(); + + *err = add_vrf("ksft-check", 55, 101, ns_test); + if (*err && *err != -EOPNOTSUPP) { + test_print("Failed to add a VRF: %d", *err); + ret = *err; + } + + switch_ns(orig_netns); + close(orig_netns); + close(ns_test); + return ret; +} + +static int has_ftrace(int *err) +{ + *err = test_setup_tracing(); + return 0; +} + +#define KCONFIG_UNKNOWN 1 +static pthread_mutex_t kconfig_lock = PTHREAD_MUTEX_INITIALIZER; +static struct kconfig_t kconfig[__KCONFIG_LAST__] = { + { KCONFIG_UNKNOWN, has_net_ns }, + { KCONFIG_UNKNOWN, has_veth }, + { KCONFIG_UNKNOWN, has_tcp_ao }, + { KCONFIG_UNKNOWN, has_tcp_md5 }, + { KCONFIG_UNKNOWN, has_vrfs }, + { KCONFIG_UNKNOWN, has_ftrace }, +}; + +const char *tests_skip_reason[__KCONFIG_LAST__] = { + "Tests require network namespaces support (CONFIG_NET_NS)", + "Tests require veth support (CONFIG_VETH)", + "Tests require TCP-AO support (CONFIG_TCP_AO)", + "setsockopt(TCP_MD5SIG_EXT) is not supported (CONFIG_TCP_MD5)", + "VRFs are not supported (CONFIG_NET_VRF)", + "Ftrace points are not supported (CONFIG_TRACEPOINTS)", +}; + +bool kernel_config_has(enum test_needs_kconfig k) +{ + bool ret; + + pthread_mutex_lock(&kconfig_lock); + if (kconfig[k]._error == KCONFIG_UNKNOWN) { + if (kconfig[k].check_kconfig(&kconfig[k]._error)) + test_error("Failed to initialize kconfig %u", k); + } + ret = kconfig[k]._error == 0; + pthread_mutex_unlock(&kconfig_lock); + return ret; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/netlink.c b/tools/testing/selftests/net/tcp_ao/lib/netlink.c new file mode 100644 index 000000000000..7f108493a29a --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/netlink.c @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Original from tools/testing/selftests/net/ipsec.c */ +#include <linux/netlink.h> +#include <linux/random.h> +#include <linux/rtnetlink.h> +#include <linux/veth.h> +#include <net/if.h> +#include <stdint.h> +#include <string.h> +#include <sys/socket.h> + +#include "aolib.h" + +#define MAX_PAYLOAD 2048 + +static int netlink_sock(int *sock, uint32_t *seq_nr, int proto) +{ + if (*sock > 0) { + seq_nr++; + return 0; + } + + *sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto); + if (*sock < 0) { + test_print("socket(AF_NETLINK)"); + return -1; + } + + randomize_buffer(seq_nr, sizeof(*seq_nr)); + + return 0; +} + +static int netlink_check_answer(int sock, bool quite) +{ + struct nlmsgerror { + struct nlmsghdr hdr; + int error; + struct nlmsghdr orig_msg; + } answer; + + if (recv(sock, &answer, sizeof(answer), 0) < 0) { + test_print("recv()"); + return -1; + } else if (answer.hdr.nlmsg_type != NLMSG_ERROR) { + test_print("expected NLMSG_ERROR, got %d", + (int)answer.hdr.nlmsg_type); + return -1; + } else if (answer.error) { + if (!quite) { + test_print("NLMSG_ERROR: %d: %s", + answer.error, strerror(-answer.error)); + } + return answer.error; + } + + return 0; +} + +static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh) +{ + return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len)); +} + +static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type, const void *payload, size_t size) +{ + /* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */ + struct rtattr *attr = rtattr_hdr(nh); + size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size); + + if (req_sz < nl_size) { + test_print("req buf is too small: %zu < %zu", req_sz, nl_size); + return -1; + } + nh->nlmsg_len = nl_size; + + attr->rta_len = RTA_LENGTH(size); + attr->rta_type = rta_type; + memcpy(RTA_DATA(attr), payload, size); + + return 0; +} + +static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type, const void *payload, size_t size) +{ + struct rtattr *ret = rtattr_hdr(nh); + + if (rtattr_pack(nh, req_sz, rta_type, payload, size)) + return 0; + + return ret; +} + +static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type) +{ + return _rtattr_begin(nh, req_sz, rta_type, 0, 0); +} + +static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr) +{ + char *nlmsg_end = (char *)nh + nh->nlmsg_len; + + attr->rta_len = nlmsg_end - (char *)attr; +} + +static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz, + const char *peer, int ns) +{ + struct ifinfomsg pi; + struct rtattr *peer_attr; + + memset(&pi, 0, sizeof(pi)); + pi.ifi_family = AF_UNSPEC; + pi.ifi_change = 0xFFFFFFFF; + + peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi)); + if (!peer_attr) + return -1; + + if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer))) + return -1; + + if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns))) + return -1; + + rtattr_end(nh, peer_attr); + + return 0; +} + +static int __add_veth(int sock, uint32_t seq, const char *name, + int ns_a, int ns_b) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + static const char veth_type[] = "veth"; + struct rtattr *link_info, *info_data; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name))) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a))) + return -1; + + link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO); + if (!link_info) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type))) + return -1; + + info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA); + if (!info_data) + return -1; + + if (veth_pack_peerb(&req.nh, sizeof(req), name, ns_b)) + return -1; + + rtattr_end(&req.nh, info_data); + rtattr_end(&req.nh, link_info); + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, false); +} + +int add_veth(const char *name, int nsfda, int nsfdb) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __add_veth(route_sock, route_seq++, name, nsfda, nsfdb); + close(route_sock); + return ret; +} + +static int __ip_addr_add(int sock, uint32_t seq, const char *intf, + int family, union tcp_addr addr, uint8_t prefix) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifaddrmsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) : + sizeof(struct in6_addr); + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWADDR; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifa_family = family; + req.info.ifa_prefixlen = prefix; + req.info.ifa_index = if_nametoindex(intf); + req.info.ifa_flags = IFA_F_NODAD; + + if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, addr_len)) + return -1; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, true); +} + +int ip_addr_add(const char *intf, int family, + union tcp_addr addr, uint8_t prefix) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __ip_addr_add(route_sock, route_seq++, intf, + family, addr, prefix); + + close(route_sock); + return ret; +} + +static int __ip_route_add(int sock, uint32_t seq, const char *intf, int family, + union tcp_addr src, union tcp_addr dst, uint8_t vrf) +{ + struct { + struct nlmsghdr nh; + struct rtmsg rt; + char attrbuf[MAX_PAYLOAD]; + } req; + unsigned int index = if_nametoindex(intf); + size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) : + sizeof(struct in6_addr); + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.rt)); + req.nh.nlmsg_type = RTM_NEWROUTE; + req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE; + req.nh.nlmsg_seq = seq; + req.rt.rtm_family = family; + req.rt.rtm_dst_len = (family == AF_INET) ? 32 : 128; + req.rt.rtm_table = vrf; + req.rt.rtm_protocol = RTPROT_BOOT; + req.rt.rtm_scope = RT_SCOPE_UNIVERSE; + req.rt.rtm_type = RTN_UNICAST; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, addr_len)) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, addr_len)) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index))) + return -1; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + + return netlink_check_answer(sock, true); +} + +int ip_route_add_vrf(const char *intf, int family, + union tcp_addr src, union tcp_addr dst, uint8_t vrf) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __ip_route_add(route_sock, route_seq++, intf, + family, src, dst, vrf); + + close(route_sock); + return ret; +} + +int ip_route_add(const char *intf, int family, + union tcp_addr src, union tcp_addr dst) +{ + return ip_route_add_vrf(intf, family, src, dst, RT_TABLE_MAIN); +} + +static int __link_set_up(int sock, uint32_t seq, const char *intf) +{ + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + req.info.ifi_index = if_nametoindex(intf); + req.info.ifi_flags = IFF_UP; + req.info.ifi_change = IFF_UP; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, false); +} + +int link_set_up(const char *intf) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __link_set_up(route_sock, route_seq++, intf); + + close(route_sock); + return ret; +} + +static int __add_vrf(int sock, uint32_t seq, const char *name, + uint32_t tabid, int ifindex, int nsfd) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + static const char vrf_type[] = "vrf"; + struct rtattr *link_info, *info_data; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + req.info.ifi_index = ifindex; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name))) + return -1; + + if (nsfd >= 0) + if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, + &nsfd, sizeof(nsfd))) + return -1; + + link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO); + if (!link_info) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, vrf_type, sizeof(vrf_type))) + return -1; + + info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA); + if (!info_data) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_VRF_TABLE, + &tabid, sizeof(tabid))) + return -1; + + rtattr_end(&req.nh, info_data); + rtattr_end(&req.nh, link_info); + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, true); +} + +int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __add_vrf(route_sock, route_seq++, name, tabid, ifindex, nsfd); + close(route_sock); + return ret; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/proc.c b/tools/testing/selftests/net/tcp_ao/lib/proc.c new file mode 100644 index 000000000000..8b984fa04286 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/proc.c @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include <stdio.h> +#include "../../../../../include/linux/compiler.h" +#include "../../../../../include/linux/kernel.h" +#include "aolib.h" + +struct netstat_counter { + uint64_t val; + char *name; +}; + +struct netstat { + char *header_name; + struct netstat *next; + size_t counters_nr; + struct netstat_counter *counters; +}; + +static struct netstat *lookup_type(struct netstat *ns, + const char *type, size_t len) +{ + while (ns != NULL) { + size_t cmp = max(len, strlen(ns->header_name)); + + if (!strncmp(ns->header_name, type, cmp)) + return ns; + ns = ns->next; + } + return NULL; +} + +static struct netstat *lookup_get(struct netstat *ns, + const char *type, const size_t len) +{ + struct netstat *ret; + + ret = lookup_type(ns, type, len); + if (ret != NULL) + return ret; + + ret = malloc(sizeof(struct netstat)); + if (!ret) + test_error("malloc()"); + + ret->header_name = strndup(type, len); + if (ret->header_name == NULL) + test_error("strndup()"); + ret->next = ns; + ret->counters_nr = 0; + ret->counters = NULL; + + return ret; +} + +static struct netstat *lookup_get_column(struct netstat *ns, const char *line) +{ + char *column; + + column = strchr(line, ':'); + if (!column) + test_error("can't parse netstat file"); + + return lookup_get(ns, line, column - line); +} + +static void netstat_read_type(FILE *fnetstat, struct netstat **dest, char *line) +{ + struct netstat *type = lookup_get_column(*dest, line); + const char *pos = line; + size_t i, nr_elems = 0; + char tmp; + + while ((pos = strchr(pos, ' '))) { + nr_elems++; + pos++; + } + + *dest = type; + type->counters = reallocarray(type->counters, + type->counters_nr + nr_elems, + sizeof(struct netstat_counter)); + if (!type->counters) + test_error("reallocarray()"); + + pos = strchr(line, ' ') + 1; + + if (fscanf(fnetstat, "%[^ :]", type->header_name) == EOF) + test_error("fscanf(%s)", type->header_name); + if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != ':') + test_error("Unexpected netstat format (%c)", tmp); + + for (i = type->counters_nr; i < type->counters_nr + nr_elems; i++) { + struct netstat_counter *nc = &type->counters[i]; + const char *new_pos = strchr(pos, ' '); + const char *fmt = " %" PRIu64; + + if (new_pos == NULL) + new_pos = strchr(pos, '\n'); + + nc->name = strndup(pos, new_pos - pos); + if (nc->name == NULL) + test_error("strndup()"); + + if (unlikely(!strcmp(nc->name, "MaxConn"))) + fmt = " %" PRId64; /* MaxConn is signed, RFC 2012 */ + if (fscanf(fnetstat, fmt, &nc->val) != 1) + test_error("fscanf(%s)", nc->name); + pos = new_pos + 1; + } + type->counters_nr += nr_elems; + + if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != '\n') + test_error("Unexpected netstat format"); +} + +static const char *snmp6_name = "Snmp6"; +static void snmp6_read(FILE *fnetstat, struct netstat **dest) +{ + struct netstat *type = lookup_get(*dest, snmp6_name, strlen(snmp6_name)); + char *counter_name; + size_t i; + + for (i = type->counters_nr;; i++) { + struct netstat_counter *nc; + uint64_t counter; + + if (fscanf(fnetstat, "%ms", &counter_name) == EOF) + break; + if (fscanf(fnetstat, "%" PRIu64, &counter) == EOF) + test_error("Unexpected snmp6 format"); + type->counters = reallocarray(type->counters, i + 1, + sizeof(struct netstat_counter)); + if (!type->counters) + test_error("reallocarray()"); + nc = &type->counters[i]; + nc->name = counter_name; + nc->val = counter; + } + type->counters_nr = i; + *dest = type; +} + +struct netstat *netstat_read(void) +{ + struct netstat *ret = 0; + size_t line_sz = 0; + char *line = NULL; + FILE *fnetstat; + + /* + * Opening thread-self instead of /proc/net/... as the latter + * points to /proc/self/net/ which instantiates thread-leader's + * net-ns, see: + * commit 155134fef2b6 ("Revert "proc: Point /proc/{mounts,net} at..") + */ + errno = 0; + fnetstat = fopen("/proc/thread-self/net/netstat", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/netstat"); + + while (getline(&line, &line_sz, fnetstat) != -1) + netstat_read_type(fnetstat, &ret, line); + fclose(fnetstat); + + errno = 0; + fnetstat = fopen("/proc/thread-self/net/snmp", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/snmp"); + + while (getline(&line, &line_sz, fnetstat) != -1) + netstat_read_type(fnetstat, &ret, line); + fclose(fnetstat); + + errno = 0; + fnetstat = fopen("/proc/thread-self/net/snmp6", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/snmp6"); + + snmp6_read(fnetstat, &ret); + fclose(fnetstat); + + free(line); + return ret; +} + +void netstat_free(struct netstat *ns) +{ + while (ns != NULL) { + struct netstat *prev = ns; + size_t i; + + free(ns->header_name); + for (i = 0; i < ns->counters_nr; i++) + free(ns->counters[i].name); + free(ns->counters); + ns = ns->next; + free(prev); + } +} + +static inline void +__netstat_print_diff(uint64_t a, struct netstat *nsb, size_t i) +{ + if (unlikely(!strcmp(nsb->header_name, "MaxConn"))) { + test_print("%8s %25s: %" PRId64 " => %" PRId64, + nsb->header_name, nsb->counters[i].name, + a, nsb->counters[i].val); + return; + } + + test_print("%8s %25s: %" PRIu64 " => %" PRIu64, nsb->header_name, + nsb->counters[i].name, a, nsb->counters[i].val); +} + +void netstat_print_diff(struct netstat *nsa, struct netstat *nsb) +{ + size_t i, j; + + while (nsb != NULL) { + if (unlikely(strcmp(nsb->header_name, nsa->header_name))) { + for (i = 0; i < nsb->counters_nr; i++) + __netstat_print_diff(0, nsb, i); + nsb = nsb->next; + continue; + } + + if (nsb->counters_nr < nsa->counters_nr) + test_error("Unexpected: some counters disappeared!"); + + for (j = 0, i = 0; i < nsb->counters_nr; i++) { + if (strcmp(nsb->counters[i].name, nsa->counters[j].name)) { + __netstat_print_diff(0, nsb, i); + continue; + } + + if (nsa->counters[j].val == nsb->counters[i].val) { + j++; + continue; + } + + __netstat_print_diff(nsa->counters[j].val, nsb, i); + j++; + } + if (j != nsa->counters_nr) + test_error("Unexpected: some counters disappeared!"); + + nsb = nsb->next; + nsa = nsa->next; + } +} + +uint64_t netstat_get(struct netstat *ns, const char *name, bool *not_found) +{ + if (not_found) + *not_found = false; + + while (ns != NULL) { + size_t i; + + for (i = 0; i < ns->counters_nr; i++) { + if (!strcmp(name, ns->counters[i].name)) + return ns->counters[i].val; + } + + ns = ns->next; + } + + if (not_found) + *not_found = true; + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/repair.c b/tools/testing/selftests/net/tcp_ao/lib/repair.c new file mode 100644 index 000000000000..9893b3ba69f5 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/repair.c @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: GPL-2.0 +/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets + * It tests that TCP-AO enabled connection can be restored. + * For the proper socket repair see: + * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h + */ +#include <fcntl.h> +#include <linux/sockios.h> +#include <sys/ioctl.h> +#include "aolib.h" + +#ifndef TCPOPT_MAXSEG +# define TCPOPT_MAXSEG 2 +#endif +#ifndef TCPOPT_WINDOW +# define TCPOPT_WINDOW 3 +#endif +#ifndef TCPOPT_SACK_PERMITTED +# define TCPOPT_SACK_PERMITTED 4 +#endif +#ifndef TCPOPT_TIMESTAMP +# define TCPOPT_TIMESTAMP 8 +#endif + +enum { + TCP_ESTABLISHED = 1, + TCP_SYN_SENT, + TCP_SYN_RECV, + TCP_FIN_WAIT1, + TCP_FIN_WAIT2, + TCP_TIME_WAIT, + TCP_CLOSE, + TCP_CLOSE_WAIT, + TCP_LAST_ACK, + TCP_LISTEN, + TCP_CLOSING, /* Now a valid state */ + TCP_NEW_SYN_RECV, + + TCP_MAX_STATES /* Leave at the end! */ +}; + +static void test_sock_checkpoint_queue(int sk, int queue, int qlen, + struct tcp_sock_queue *q) +{ + socklen_t len; + int ret; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + len = sizeof(q->seq); + ret = getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &q->seq, &len); + if (ret || len != sizeof(q->seq)) + test_error("getsockopt(TCP_QUEUE_SEQ): %d", (int)len); + + if (!qlen) { + q->buf = NULL; + return; + } + + q->buf = malloc(qlen); + if (q->buf == NULL) + test_error("malloc()"); + ret = recv(sk, q->buf, qlen, MSG_PEEK | MSG_DONTWAIT); + if (ret != qlen) + test_error("recv(%d): %d", qlen, ret); +} + +void __test_sock_checkpoint(int sk, struct tcp_sock_state *state, + void *addr, size_t addr_size) +{ + socklen_t len = sizeof(state->info); + int ret; + + memset(state, 0, sizeof(*state)); + + ret = getsockopt(sk, SOL_TCP, TCP_INFO, &state->info, &len); + if (ret || len != sizeof(state->info)) + test_error("getsockopt(TCP_INFO): %d", (int)len); + + len = addr_size; + if (getsockname(sk, addr, &len) || len != addr_size) + test_error("getsockname(): %d", (int)len); + + len = sizeof(state->trw); + ret = getsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, &len); + if (ret || len != sizeof(state->trw)) + test_error("getsockopt(TCP_REPAIR_WINDOW): %d", (int)len); + + if (ioctl(sk, SIOCOUTQ, &state->outq_len)) + test_error("ioctl(SIOCOUTQ)"); + + if (ioctl(sk, SIOCOUTQNSD, &state->outq_nsd_len)) + test_error("ioctl(SIOCOUTQNSD)"); + test_sock_checkpoint_queue(sk, TCP_SEND_QUEUE, state->outq_len, &state->out); + + if (ioctl(sk, SIOCINQ, &state->inq_len)) + test_error("ioctl(SIOCINQ)"); + test_sock_checkpoint_queue(sk, TCP_RECV_QUEUE, state->inq_len, &state->in); + + if (state->info.tcpi_state == TCP_CLOSE) + state->outq_len = state->outq_nsd_len = 0; + + len = sizeof(state->mss); + ret = getsockopt(sk, SOL_TCP, TCP_MAXSEG, &state->mss, &len); + if (ret || len != sizeof(state->mss)) + test_error("getsockopt(TCP_MAXSEG): %d", (int)len); + + len = sizeof(state->timestamp); + ret = getsockopt(sk, SOL_TCP, TCP_TIMESTAMP, &state->timestamp, &len); + if (ret || len != sizeof(state->timestamp)) + test_error("getsockopt(TCP_TIMESTAMP): %d", (int)len); +} + +void test_ao_checkpoint(int sk, struct tcp_ao_repair *state) +{ + socklen_t len = sizeof(*state); + int ret; + + memset(state, 0, sizeof(*state)); + + ret = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, &len); + if (ret || len != sizeof(*state)) + test_error("getsockopt(TCP_AO_REPAIR): %d", (int)len); +} + +static void test_sock_restore_seq(int sk, int queue, uint32_t seq) +{ + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + if (setsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &seq, sizeof(seq))) + test_error("setsockopt(TCP_QUEUE_SEQ)"); +} + +static void test_sock_restore_queue(int sk, int queue, void *buf, int len) +{ + int chunk = len; + size_t off = 0; + + if (len == 0) + return; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + do { + int ret; + + ret = send(sk, buf + off, chunk, 0); + if (ret <= 0) { + if (chunk > 1024) { + chunk >>= 1; + continue; + } + test_error("send()"); + } + off += ret; + len -= ret; + } while (len > 0); +} + +void __test_sock_restore(int sk, const char *device, + struct tcp_sock_state *state, + void *saddr, void *daddr, size_t addr_size) +{ + struct tcp_repair_opt opts[4]; + unsigned int opt_nr = 0; + long flags; + + if (bind(sk, saddr, addr_size)) + test_error("bind()"); + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + test_sock_restore_seq(sk, TCP_RECV_QUEUE, state->in.seq - state->inq_len); + test_sock_restore_seq(sk, TCP_SEND_QUEUE, state->out.seq - state->outq_len); + + if (device != NULL && setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, + device, strlen(device) + 1)) + test_error("setsockopt(SO_BINDTODEVICE, %s)", device); + + if (connect(sk, daddr, addr_size)) + test_error("connect()"); + + if (state->info.tcpi_options & TCPI_OPT_SACK) { + opts[opt_nr].opt_code = TCPOPT_SACK_PERMITTED; + opts[opt_nr].opt_val = 0; + opt_nr++; + } + if (state->info.tcpi_options & TCPI_OPT_WSCALE) { + opts[opt_nr].opt_code = TCPOPT_WINDOW; + opts[opt_nr].opt_val = state->info.tcpi_snd_wscale + + (state->info.tcpi_rcv_wscale << 16); + opt_nr++; + } + if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) { + opts[opt_nr].opt_code = TCPOPT_TIMESTAMP; + opts[opt_nr].opt_val = 0; + opt_nr++; + } + opts[opt_nr].opt_code = TCPOPT_MAXSEG; + opts[opt_nr].opt_val = state->mss; + opt_nr++; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_OPTIONS, opts, opt_nr * sizeof(opts[0]))) + test_error("setsockopt(TCP_REPAIR_OPTIONS)"); + + if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) { + if (setsockopt(sk, SOL_TCP, TCP_TIMESTAMP, + &state->timestamp, opt_nr * sizeof(opts[0]))) + test_error("setsockopt(TCP_TIMESTAMP)"); + } + test_sock_restore_queue(sk, TCP_RECV_QUEUE, state->in.buf, state->inq_len); + test_sock_restore_queue(sk, TCP_SEND_QUEUE, state->out.buf, state->outq_len); + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, sizeof(state->trw))) + test_error("setsockopt(TCP_REPAIR_WINDOW)"); +} + +void test_ao_restore(int sk, struct tcp_ao_repair *state) +{ + if (setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, sizeof(*state))) + test_error("setsockopt(TCP_AO_REPAIR)"); +} + +void test_sock_state_free(struct tcp_sock_state *state) +{ + free(state->out.buf); + free(state->in.buf); +} + +void test_enable_repair(int sk) +{ + int val = TCP_REPAIR_ON; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); +} + +void test_disable_repair(int sk) +{ + int val = TCP_REPAIR_OFF_NO_WP; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); +} + +void test_kill_sk(int sk) +{ + test_enable_repair(sk); + close(sk); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/setup.c b/tools/testing/selftests/net/tcp_ao/lib/setup.c new file mode 100644 index 000000000000..a27cc03c9fbd --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/setup.c @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <fcntl.h> +#include <pthread.h> +#include <sched.h> +#include <signal.h> +#include "aolib.h" + +/* + * Can't be included in the header: it defines static variables which + * will be unique to every object. Let's include it only once here. + */ +#include "../../../kselftest.h" + +/* Prevent overriding of one thread's output by another */ +static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER; + +void __test_msg(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_print_msg("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_ok(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_pass("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_fail(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_fail("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_xfail(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_xfail("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_error(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_error("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_skip(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_skip("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} + +static volatile int failed; +static volatile int skipped; + +void test_failed(void) +{ + failed = 1; +} + +static void test_exit(void) +{ + if (failed) { + ksft_exit_fail(); + } else if (skipped) { + /* ksft_exit_skip() is different from ksft_exit_*() */ + ksft_print_cnts(); + exit(KSFT_SKIP); + } else { + ksft_exit_pass(); + } +} + +struct dlist_t { + void (*destruct)(void); + struct dlist_t *next; +}; +static struct dlist_t *destructors_list; + +void test_add_destructor(void (*d)(void)) +{ + struct dlist_t *p; + + p = malloc(sizeof(struct dlist_t)); + if (p == NULL) + test_error("malloc() failed"); + + p->next = destructors_list; + p->destruct = d; + destructors_list = p; +} + +static void test_destructor(void) __attribute__((destructor)); +static void test_destructor(void) +{ + while (destructors_list) { + struct dlist_t *p = destructors_list->next; + + destructors_list->destruct(); + free(destructors_list); + destructors_list = p; + } + test_exit(); +} + +static void sig_int(int signo) +{ + test_error("Caught SIGINT - exiting"); +} + +int open_netns(void) +{ + const char *netns_path = "/proc/thread-self/ns/net"; + int fd; + + fd = open(netns_path, O_RDONLY); + if (fd < 0) + test_error("open(%s)", netns_path); + return fd; +} + +int unshare_open_netns(void) +{ + if (unshare(CLONE_NEWNET) != 0) + test_error("unshare()"); + + return open_netns(); +} + +void switch_ns(int fd) +{ + if (setns(fd, CLONE_NEWNET)) + test_error("setns()"); +} + +int switch_save_ns(int new_ns) +{ + int ret = open_netns(); + + switch_ns(new_ns); + return ret; +} + +void switch_close_ns(int fd) +{ + if (setns(fd, CLONE_NEWNET)) + test_error("setns()"); + close(fd); +} + +static int nsfd_outside = -1; +static int nsfd_parent = -1; +static int nsfd_child = -1; +const char veth_name[] = "ktst-veth"; + +static void init_namespaces(void) +{ + nsfd_outside = open_netns(); + nsfd_parent = unshare_open_netns(); + nsfd_child = unshare_open_netns(); +} + +static void link_init(const char *veth, int family, uint8_t prefix, + union tcp_addr addr, union tcp_addr dest) +{ + if (link_set_up(veth)) + test_error("Failed to set link up"); + if (ip_addr_add(veth, family, addr, prefix)) + test_error("Failed to add ip address"); + if (ip_route_add(veth, family, addr, dest)) + test_error("Failed to add route"); +} + +static unsigned int nr_threads = 1; + +static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER; +static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER; +static volatile unsigned int stage_threads[2]; +static volatile unsigned int stage_nr; + +/* synchronize all threads in the same stage */ +void synchronize_threads(void) +{ + unsigned int q = stage_nr; + + pthread_mutex_lock(&sync_lock); + stage_threads[q]++; + if (stage_threads[q] == nr_threads) { + stage_nr ^= 1; + stage_threads[stage_nr] = 0; + pthread_cond_signal(&sync_cond); + } + while (stage_threads[q] < nr_threads) + pthread_cond_wait(&sync_cond, &sync_lock); + pthread_mutex_unlock(&sync_lock); +} + +__thread union tcp_addr this_ip_addr; +__thread union tcp_addr this_ip_dest; +int test_family; + +struct new_pthread_arg { + thread_fn func; + union tcp_addr my_ip; + union tcp_addr dest_ip; +}; +static void *new_pthread_entry(void *arg) +{ + struct new_pthread_arg *p = arg; + + this_ip_addr = p->my_ip; + this_ip_dest = p->dest_ip; + p->func(NULL); /* shouldn't return */ + exit(KSFT_FAIL); +} + +static void __test_skip_all(const char *msg) +{ + ksft_set_plan(1); + ksft_print_header(); + skipped = 1; + test_skip("%s", msg); + exit(KSFT_SKIP); +} + +void __test_init(unsigned int ntests, int family, unsigned int prefix, + union tcp_addr addr1, union tcp_addr addr2, + thread_fn peer1, thread_fn peer2) +{ + struct sigaction sa = { + .sa_handler = sig_int, + .sa_flags = SA_RESTART, + }; + time_t seed = time(NULL); + + sigemptyset(&sa.sa_mask); + if (sigaction(SIGINT, &sa, NULL)) + test_error("Can't set SIGINT handler"); + + test_family = family; + if (!kernel_config_has(KCONFIG_NET_NS)) + __test_skip_all(tests_skip_reason[KCONFIG_NET_NS]); + if (!kernel_config_has(KCONFIG_VETH)) + __test_skip_all(tests_skip_reason[KCONFIG_VETH]); + if (!kernel_config_has(KCONFIG_TCP_AO)) + __test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]); + + ksft_set_plan(ntests); + test_print("rand seed %u", (unsigned int)seed); + srand(seed); + + ksft_print_header(); + init_namespaces(); + test_init_ftrace(nsfd_parent, nsfd_child); + + if (add_veth(veth_name, nsfd_parent, nsfd_child)) + test_error("Failed to add veth"); + + switch_ns(nsfd_child); + link_init(veth_name, family, prefix, addr2, addr1); + if (peer2) { + struct new_pthread_arg targ; + pthread_t t; + + targ.my_ip = addr2; + targ.dest_ip = addr1; + targ.func = peer2; + nr_threads++; + if (pthread_create(&t, NULL, new_pthread_entry, &targ)) + test_error("Failed to create pthread"); + } + switch_ns(nsfd_parent); + link_init(veth_name, family, prefix, addr1, addr2); + + this_ip_addr = addr1; + this_ip_dest = addr2; + peer1(NULL); + if (failed) + exit(KSFT_FAIL); + else + exit(KSFT_PASS); +} + +/* /proc/sys/net/core/optmem_max artifically limits the amount of memory + * that can be allocated with sock_kmalloc() on each socket in the system. + * It is not virtualized in v6.7, so it has to written outside test + * namespaces. To be nice a test will revert optmem back to the old value. + * Keeping it simple without any file lock, which means the tests that + * need to set/increase optmem value shouldn't run in parallel. + * Also, not re-entrant. + * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max") + * it is per-namespace, keeping logic for non-virtualized optmem_max + * for v6.7, which supports TCP-AO. + */ +static const char *optmem_file = "/proc/sys/net/core/optmem_max"; +static size_t saved_optmem; +static int optmem_ns = -1; + +static bool is_optmem_namespaced(void) +{ + if (optmem_ns == -1) { + int old_ns = switch_save_ns(nsfd_child); + + optmem_ns = !access(optmem_file, F_OK); + switch_close_ns(old_ns); + } + return !!optmem_ns; +} + +size_t test_get_optmem(void) +{ + int old_ns = 0; + FILE *foptmem; + size_t ret; + + if (!is_optmem_namespaced()) + old_ns = switch_save_ns(nsfd_outside); + foptmem = fopen(optmem_file, "r"); + if (!foptmem) + test_error("failed to open %s", optmem_file); + + if (fscanf(foptmem, "%zu", &ret) != 1) + test_error("can't read from %s", optmem_file); + fclose(foptmem); + if (!is_optmem_namespaced()) + switch_close_ns(old_ns); + return ret; +} + +static void __test_set_optmem(size_t new, size_t *old) +{ + int old_ns = 0; + FILE *foptmem; + + if (old != NULL) + *old = test_get_optmem(); + + if (!is_optmem_namespaced()) + old_ns = switch_save_ns(nsfd_outside); + foptmem = fopen(optmem_file, "w"); + if (!foptmem) + test_error("failed to open %s", optmem_file); + + if (fprintf(foptmem, "%zu", new) <= 0) + test_error("can't write %zu to %s", new, optmem_file); + fclose(foptmem); + if (!is_optmem_namespaced()) + switch_close_ns(old_ns); +} + +static void test_revert_optmem(void) +{ + if (saved_optmem == 0) + return; + + __test_set_optmem(saved_optmem, NULL); +} + +void test_set_optmem(size_t value) +{ + if (saved_optmem == 0) { + __test_set_optmem(value, &saved_optmem); + test_add_destructor(test_revert_optmem); + } else { + __test_set_optmem(value, NULL); + } +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/sock.c b/tools/testing/selftests/net/tcp_ao/lib/sock.c new file mode 100644 index 000000000000..ef8e9031d47a --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/sock.c @@ -0,0 +1,730 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <alloca.h> +#include <fcntl.h> +#include <inttypes.h> +#include <string.h> +#include "../../../../../include/linux/kernel.h" +#include "../../../../../include/linux/stringify.h" +#include "aolib.h" + +const unsigned int test_server_port = 7010; +int __test_listen_socket(int backlog, void *addr, size_t addr_sz) +{ + int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + long flags; + + if (sk < 0) + test_error("socket()"); + + err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name, + strlen(veth_name) + 1); + if (err < 0) + test_error("setsockopt(SO_BINDTODEVICE)"); + + if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0) + test_error("bind()"); + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + if (listen(sk, backlog)) + test_error("listen()"); + + return sk; +} + +static int __test_wait_fd(int sk, struct timeval *tv, bool write) +{ + fd_set fds, efds; + int ret; + socklen_t slen = sizeof(ret); + + FD_ZERO(&fds); + FD_SET(sk, &fds); + FD_ZERO(&efds); + FD_SET(sk, &efds); + + errno = 0; + if (write) + ret = select(sk + 1, NULL, &fds, &efds, tv); + else + ret = select(sk + 1, &fds, NULL, &efds, tv); + if (ret < 0) + return -errno; + if (ret == 0) { + errno = ETIMEDOUT; + return -ETIMEDOUT; + } + + if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen)) + return -errno; + if (ret) + return -ret; + return 0; +} + +int test_wait_fd(int sk, time_t sec, bool write) +{ + struct timeval tv = { .tv_sec = sec, }; + + return __test_wait_fd(sk, sec ? &tv : NULL, write); +} + +static bool __skpair_poll_should_stop(int sk, struct tcp_counters *c, + test_cnt condition) +{ + struct tcp_counters c2; + test_cnt diff; + + if (test_get_tcp_counters(sk, &c2)) + test_error("test_get_tcp_counters()"); + + diff = test_cmp_counters(c, &c2); + test_tcp_counters_free(&c2); + return (diff & condition) == condition; +} + +/* How often wake up and check netns counters & paired (*err) */ +#define POLL_USEC 150 +static int __test_skpair_poll(int sk, bool write, uint64_t timeout, + struct tcp_counters *c, test_cnt cond, + volatile int *err) +{ + uint64_t t; + + for (t = 0; t <= timeout * 1000000; t += POLL_USEC) { + struct timeval tv = { .tv_usec = POLL_USEC, }; + int ret; + + ret = __test_wait_fd(sk, &tv, write); + if (ret != -ETIMEDOUT) + return ret; + if (c && cond && __skpair_poll_should_stop(sk, c, cond)) + break; + if (err && *err) + return *err; + } + if (err) + *err = -ETIMEDOUT; + return -ETIMEDOUT; +} + +int __test_connect_socket(int sk, const char *device, + void *addr, size_t addr_sz, bool async) +{ + long flags; + int err; + + if (device != NULL) { + err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device, + strlen(device) + 1); + if (err < 0) + test_error("setsockopt(SO_BINDTODEVICE, %s)", device); + } + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + if (connect(sk, addr, addr_sz) < 0) { + if (errno != EINPROGRESS) { + err = -errno; + goto out; + } + if (async) + return sk; + err = test_wait_fd(sk, TEST_TIMEOUT_SEC, 1); + if (err) + goto out; + } + return sk; + +out: + close(sk); + return err; +} + +int test_skpair_wait_poll(int sk, bool write, + test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + int ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = __test_skpair_poll(sk, write, TEST_TIMEOUT_SEC, &c, cond, err); + test_tcp_counters_free(&c); + return ret; +} + +int _test_skpair_connect_poll(int sk, const char *device, + void *addr, size_t addr_sz, + test_cnt condition, volatile int *err) +{ + struct tcp_counters c; + int ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + ret = __test_connect_socket(sk, device, addr, addr_sz, true); + if (ret < 0) { + test_tcp_counters_free(&c); + return (*err = ret); + } + ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, &c, condition, err); + if (ret < 0) + close(sk); + test_tcp_counters_free(&c); + return ret; +} + +int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix, + int vrf, const char *password) +{ + size_t pwd_len = strlen(password); + struct tcp_md5sig md5sig = {}; + + md5sig.tcpm_keylen = pwd_len; + memcpy(md5sig.tcpm_key, password, pwd_len); + md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX; + md5sig.tcpm_prefixlen = prefix; + if (vrf >= 0) { + md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX; + md5sig.tcpm_ifindex = (uint8_t)vrf; + } + memcpy(&md5sig.tcpm_addr, addr, addr_sz); + + errno = 0; + return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT, + &md5sig, sizeof(md5sig)); +} + + +int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, + void *addr, size_t addr_sz, bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid, + uint8_t maclen, uint8_t keyflags, + uint8_t keylen, const char *key) +{ + memset(ao, 0, sizeof(struct tcp_ao_add)); + + ao->set_current = !!set_current; + ao->set_rnext = !!set_rnext; + ao->prefix = prefix; + ao->sndid = sndid; + ao->rcvid = rcvid; + ao->maclen = maclen; + ao->keyflags = keyflags; + ao->keylen = keylen; + ao->ifindex = vrf; + + memcpy(&ao->addr, addr, addr_sz); + + if (strlen(alg) > 64) + return -ENOBUFS; + strncpy(ao->alg_name, alg, 64); + + memcpy(ao->key, key, + (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen); + return 0; +} + +static int test_get_ao_keys_nr(int sk) +{ + struct tcp_ao_getsockopt tmp = {}; + socklen_t tmp_sz = sizeof(tmp); + int ret; + + tmp.nkeys = 1; + tmp.get_all = 1; + + ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); + if (ret) + return -errno; + return (int)tmp.nkeys; +} + +int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, + void *addr, size_t addr_sz, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_getsockopt tmp = {}; + socklen_t tmp_sz = sizeof(tmp); + int ret; + + memcpy(&tmp.addr, addr, addr_sz); + tmp.prefix = prefix; + tmp.sndid = sndid; + tmp.rcvid = rcvid; + tmp.nkeys = 1; + + ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); + if (ret) + return ret; + if (tmp.nkeys != 1) + return -E2BIG; + *out = tmp; + return 0; +} + +int test_get_ao_info(int sk, struct tcp_ao_info_opt *out) +{ + socklen_t sz = sizeof(*out); + + out->reserved = 0; + out->reserved2 = 0; + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz)) + return -errno; + if (sz != sizeof(*out)) + return -EMSGSIZE; + return 0; +} + +int test_set_ao_info(int sk, struct tcp_ao_info_opt *in) +{ + socklen_t sz = sizeof(*in); + + in->reserved = 0; + in->reserved2 = 0; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz)) + return -errno; + return 0; +} + +int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, + const struct tcp_ao_getsockopt *b) +{ + bool is_kdf_aes_128_cmac = false; + bool is_cmac_aes = false; + + if (!strcmp("cmac(aes128)", a->alg_name)) { + is_kdf_aes_128_cmac = (a->keylen != 16); + is_cmac_aes = true; + } + +#define __cmp_ao(member) \ +do { \ + if (b->member != a->member) { \ + test_fail("getsockopt(): " __stringify(member) " %u != %u", \ + b->member, a->member); \ + return -1; \ + } \ +} while(0) + __cmp_ao(sndid); + __cmp_ao(rcvid); + __cmp_ao(prefix); + __cmp_ao(keyflags); + __cmp_ao(ifindex); + if (a->maclen) { + __cmp_ao(maclen); + } else if (b->maclen != 12) { + test_fail("getsockopt(): expected default maclen 12, but it's %u", + b->maclen); + return -1; + } + if (!is_kdf_aes_128_cmac) { + __cmp_ao(keylen); + } else if (b->keylen != 16) { + test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u", + b->keylen); + return -1; + } +#undef __cmp_ao + if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) { + test_fail("getsockopt(): returned key is different `%s' != `%s'", + b->key, a->key); + return -1; + } + if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) { + test_fail("getsockopt(): returned address is different"); + return -1; + } + if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) { + test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name); + return -1; + } + if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) { + test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name); + return -1; + } + /* For a established key rotation test don't add a key with + * set_current = 1, as it's likely to change by peer's request; + * rather use setsockopt(TCP_AO_INFO) + */ + if (a->set_current != b->is_current) { + test_fail("getsockopt(): returned key is not Current_key"); + return -1; + } + if (a->set_rnext != b->is_rnext) { + test_fail("getsockopt(): returned key is not RNext_key"); + return -1; + } + + return 0; +} + +int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, + const struct tcp_ao_info_opt *b) +{ + /* No check for ::current_key, as it may change by the peer */ + if (a->ao_required != b->ao_required) { + test_fail("getsockopt(): returned ao doesn't have ao_required"); + return -1; + } + if (a->accept_icmps != b->accept_icmps) { + test_fail("getsockopt(): returned ao doesn't accept ICMPs"); + return -1; + } + if (a->set_rnext && a->rnext != b->rnext) { + test_fail("getsockopt(): RNext KeyID has changed"); + return -1; + } +#define __cmp_cnt(member) \ +do { \ + if (b->member != a->member) { \ + test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \ + b->member, a->member); \ + return -1; \ + } \ +} while(0) + if (a->set_counters) { + __cmp_cnt(pkt_good); + __cmp_cnt(pkt_bad); + __cmp_cnt(pkt_key_not_found); + __cmp_cnt(pkt_ao_required); + __cmp_cnt(pkt_dropped_icmp); + } +#undef __cmp_cnt + return 0; +} + +int test_get_tcp_counters(int sk, struct tcp_counters *out) +{ + struct tcp_ao_getsockopt *key_dump; + socklen_t key_dump_sz = sizeof(*key_dump); + struct tcp_ao_info_opt info = {}; + bool c1, c2, c3, c4, c5, c6, c7, c8; + struct netstat *ns; + int err, nr_keys; + + memset(out, 0, sizeof(*out)); + + /* per-netns */ + ns = netstat_read(); + out->ao.netns_ao_good = netstat_get(ns, "TCPAOGood", &c1); + out->ao.netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2); + out->ao.netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3); + out->ao.netns_ao_required = netstat_get(ns, "TCPAORequired", &c4); + out->ao.netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5); + out->netns_md5_notfound = netstat_get(ns, "TCPMD5NotFound", &c6); + out->netns_md5_unexpected = netstat_get(ns, "TCPMD5Unexpected", &c7); + out->netns_md5_failure = netstat_get(ns, "TCPMD5Failure", &c8); + netstat_free(ns); + if (c1 || c2 || c3 || c4 || c5 || c6 || c7 || c8) + return -EOPNOTSUPP; + + err = test_get_ao_info(sk, &info); + if (err == -ENOENT) + return 0; + if (err) + return err; + + /* per-socket */ + out->ao.ao_info_pkt_good = info.pkt_good; + out->ao.ao_info_pkt_bad = info.pkt_bad; + out->ao.ao_info_pkt_key_not_found = info.pkt_key_not_found; + out->ao.ao_info_pkt_ao_required = info.pkt_ao_required; + out->ao.ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp; + + /* per-key */ + nr_keys = test_get_ao_keys_nr(sk); + if (nr_keys < 0) + return nr_keys; + if (nr_keys == 0) + test_error("test_get_ao_keys_nr() == 0"); + out->ao.nr_keys = (size_t)nr_keys; + key_dump = calloc(nr_keys, key_dump_sz); + if (!key_dump) + return -errno; + + key_dump[0].nkeys = nr_keys; + key_dump[0].get_all = 1; + err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, + key_dump, &key_dump_sz); + if (err) { + free(key_dump); + return -errno; + } + + out->ao.key_cnts = calloc(nr_keys, sizeof(out->ao.key_cnts[0])); + if (!out->ao.key_cnts) { + free(key_dump); + return -errno; + } + + while (nr_keys--) { + out->ao.key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid; + out->ao.key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid; + out->ao.key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good; + out->ao.key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad; + } + free(key_dump); + + return 0; +} + +test_cnt test_cmp_counters(struct tcp_counters *before, + struct tcp_counters *after) +{ +#define __cmp(cnt, e_cnt) \ +do { \ + if (before->cnt > after->cnt) \ + test_error("counter " __stringify(cnt) " decreased"); \ + if (before->cnt != after->cnt) \ + ret |= e_cnt; \ +} while (0) + + test_cnt ret = 0; + size_t i; + + if (before->ao.nr_keys != after->ao.nr_keys) + test_error("the number of keys has changed"); + + _for_each_counter(__cmp); + + i = before->ao.nr_keys; + while (i--) { + __cmp(ao.key_cnts[i].pkt_good, TEST_CNT_KEY_GOOD); + __cmp(ao.key_cnts[i].pkt_bad, TEST_CNT_KEY_BAD); + } +#undef __cmp + return ret; +} + +int test_assert_counters_sk(const char *tst_name, + struct tcp_counters *before, + struct tcp_counters *after, + test_cnt expected) +{ +#define __cmp_ao(cnt, e_cnt) \ +do { \ + if (before->cnt > after->cnt) { \ + test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \ + tst_name ?: "", before->cnt, after->cnt); \ + return -1; \ + } \ + if ((before->cnt != after->cnt) != !!(expected & e_cnt)) { \ + test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \ + tst_name ?: "", (expected & e_cnt) ? "" : "not ", \ + before->cnt, after->cnt); \ + return -1; \ + } \ +} while (0) + + errno = 0; + _for_each_counter(__cmp_ao); + return 0; +#undef __cmp_ao +} + +int test_assert_counters_key(const char *tst_name, + struct tcp_ao_counters *before, + struct tcp_ao_counters *after, + test_cnt expected, int sndid, int rcvid) +{ + size_t i; +#define __cmp_ao(i, cnt, e_cnt) \ +do { \ + if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \ + test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \ + tst_name ?: "", before->key_cnts[i].cnt, \ + after->key_cnts[i].cnt, \ + before->key_cnts[i].sndid, \ + before->key_cnts[i].rcvid); \ + return -1; \ + } \ + if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != !!(expected & e_cnt)) { \ + test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \ + tst_name ?: "", (expected & e_cnt) ? "" : "not ",\ + before->key_cnts[i].cnt, \ + after->key_cnts[i].cnt, \ + before->key_cnts[i].sndid, \ + before->key_cnts[i].rcvid); \ + return -1; \ + } \ +} while (0) + + if (before->nr_keys != after->nr_keys) { + test_fail("%s: Keys changed on the socket %zu != %zu", + tst_name, before->nr_keys, after->nr_keys); + return -1; + } + + /* per-key */ + i = before->nr_keys; + while (i--) { + if (sndid >= 0 && before->key_cnts[i].sndid != sndid) + continue; + if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid) + continue; + __cmp_ao(i, pkt_good, TEST_CNT_KEY_GOOD); + __cmp_ao(i, pkt_bad, TEST_CNT_KEY_BAD); + } + return 0; +#undef __cmp_ao +} + +void test_tcp_counters_free(struct tcp_counters *cnts) +{ + free(cnts->ao.key_cnts); +} + +#define TEST_BUF_SIZE 4096 +static ssize_t _test_server_run(int sk, ssize_t quota, struct tcp_counters *c, + test_cnt cond, volatile int *err, + time_t timeout_sec) +{ + ssize_t total = 0; + + do { + char buf[TEST_BUF_SIZE]; + ssize_t bytes, sent; + int ret; + + ret = __test_skpair_poll(sk, 0, timeout_sec, c, cond, err); + if (ret) + return ret; + + bytes = recv(sk, buf, sizeof(buf), 0); + + if (bytes < 0) + test_error("recv(): %zd", bytes); + if (bytes == 0) + break; + + ret = __test_skpair_poll(sk, 1, timeout_sec, c, cond, err); + if (ret) + return ret; + + sent = send(sk, buf, bytes, 0); + if (sent == 0) + break; + if (sent != bytes) + test_error("send()"); + total += bytes; + } while (!quota || total < quota); + + return total; +} + +ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec) +{ + return _test_server_run(sk, quota, NULL, 0, NULL, + timeout_sec ?: TEST_TIMEOUT_SEC); +} + +int test_skpair_server(int sk, ssize_t quota, test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + ssize_t ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = _test_server_run(sk, quota, &c, cond, err, TEST_TIMEOUT_SEC); + test_tcp_counters_free(&c); + return ret; +} + +static ssize_t test_client_loop(int sk, size_t buf_sz, const size_t msg_len, + struct tcp_counters *c, test_cnt cond, + volatile int *err) +{ + char msg[msg_len]; + int nodelay = 1; + char *buf; + size_t i; + + buf = alloca(buf_sz); + if (!buf) + return -ENOMEM; + randomize_buffer(buf, buf_sz); + + if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay))) + test_error("setsockopt(TCP_NODELAY)"); + + for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) { + size_t sent, bytes = min(msg_len, buf_sz - i); + int ret; + + ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, c, cond, err); + if (ret) + return ret; + + sent = send(sk, buf + i, bytes, 0); + if (sent == 0) + break; + if (sent != bytes) + test_error("send()"); + + bytes = 0; + do { + ssize_t got; + + ret = __test_skpair_poll(sk, 0, TEST_TIMEOUT_SEC, + c, cond, err); + if (ret) + return ret; + + got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0); + if (got <= 0) + return i; + bytes += got; + } while (bytes < sent); + if (bytes > sent) + test_error("recv(): %zd > %zd", bytes, sent); + if (memcmp(buf + i, msg, bytes) != 0) { + test_fail("received message differs"); + return -1; + } + } + return i; +} + +int test_client_verify(int sk, const size_t msg_len, const size_t nr) +{ + size_t buf_sz = msg_len * nr; + ssize_t ret; + + ret = test_client_loop(sk, buf_sz, msg_len, NULL, 0, NULL); + if (ret < 0) + return (int)ret; + return ret != buf_sz ? -1 : 0; +} + +int test_skpair_client(int sk, const size_t msg_len, const size_t nr, + test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + size_t buf_sz = msg_len * nr; + ssize_t ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = test_client_loop(sk, buf_sz, msg_len, &c, cond, err); + test_tcp_counters_free(&c); + if (ret < 0) + return (int)ret; + return ret != buf_sz ? -1 : 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/utils.c b/tools/testing/selftests/net/tcp_ao/lib/utils.c new file mode 100644 index 000000000000..bdf5522c9213 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/utils.c @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: GPL-2.0 +#include "aolib.h" +#include <string.h> + +void randomize_buffer(void *buf, size_t buflen) +{ + int *p = (int *)buf; + size_t words = buflen / sizeof(int); + size_t leftover = buflen % sizeof(int); + + if (!buflen) + return; + + while (words--) + *p++ = rand(); + + if (leftover) { + int tmp = rand(); + + memcpy(buf + buflen - leftover, &tmp, leftover); + } +} + +__printf(3, 4) int test_echo(const char *fname, bool append, + const char *fmt, ...) +{ + size_t len, written; + va_list vargs; + char *msg; + FILE *f; + + f = fopen(fname, append ? "a" : "w"); + if (!f) + return -errno; + + va_start(vargs, fmt); + msg = test_snprintf(fmt, vargs); + va_end(vargs); + if (!msg) { + fclose(f); + return -1; + } + len = strlen(msg); + written = fwrite(msg, 1, len, f); + fclose(f); + free(msg); + return written == len ? 0 : -1; +} + +const struct sockaddr_in6 addr_any6 = { + .sin6_family = AF_INET6, +}; + +const struct sockaddr_in addr_any4 = { + .sin_family = AF_INET, +}; diff --git a/tools/testing/selftests/net/tcp_ao/restore.c b/tools/testing/selftests/net/tcp_ao/restore.c new file mode 100644 index 000000000000..9a059b6c4523 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/restore.c @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets + * It tests that TCP-AO enabled connection can be restored. + * For the proper socket repair see: + * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h + */ +#include <inttypes.h> +#include "aolib.h" + +const size_t nr_packets = 20; +const size_t msg_len = 100; +const size_t quota = nr_packets * msg_len; +#define fault(type) (inj == FAULT_ ## type) + +static void try_server_run(const char *tst_name, unsigned int port, + fault_t inj, test_cnt cnt_expected) +{ + test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected; + const char *cnt_name = "TCPAOGood"; + struct tcp_counters cnt1, cnt2; + uint64_t before_cnt, after_cnt; + int sk, lsk, dummy; + ssize_t bytes; + + if (fault(TIMEOUT)) + cnt_name = "TCPAOBad"; + lsk = test_listen_socket(this_ip_addr, port, 1); + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + synchronize_threads(); /* 1: MKT added => connect() */ + + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + synchronize_threads(); /* 2: accepted => send data */ + close(lsk); + + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) { + test_fail("%s: server served: %zd", tst_name, bytes); + goto out; + } + + before_cnt = netstat_get_one(cnt_name, NULL); + if (test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + + bytes = test_skpair_server(sk, quota, poll_cnt, &dummy); + if (fault(TIMEOUT)) { + if (bytes > 0) + test_fail("%s: server served: %zd", tst_name, bytes); + else + test_ok("%s: server couldn't serve", tst_name); + } else { + if (bytes != quota) + test_fail("%s: server served: %zd", tst_name, bytes); + else + test_ok("%s: server alive", tst_name); + } + synchronize_threads(); /* 3: counters checks */ + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + after_cnt = netstat_get_one(cnt_name, NULL); + + test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); + + if (after_cnt <= before_cnt) { + test_fail("%s(server): %s counter did not increase: %" PRIu64 " <= %" PRIu64, + tst_name, cnt_name, after_cnt, before_cnt); + } else { + test_ok("%s(server): counter %s increased %" PRIu64 " => %" PRIu64, + tst_name, cnt_name, before_cnt, after_cnt); + } + + /* + * Before close() as that will send FIN and move the peer in TCP_CLOSE + * and that will prevent reading AO counters from the peer's socket. + */ + synchronize_threads(); /* 4: verified => closed */ +out: + close(sk); +} + +static void *server_fn(void *arg) +{ + unsigned int port = test_server_port; + + try_server_run("TCP-AO migrate to another socket (server)", port++, + 0, TEST_CNT_GOOD); + try_server_run("TCP-AO with wrong send ISN (server)", port++, + FAULT_TIMEOUT, TEST_CNT_BAD); + try_server_run("TCP-AO with wrong receive ISN (server)", port++, + FAULT_TIMEOUT, TEST_CNT_BAD); + try_server_run("TCP-AO with wrong send SEQ ext number (server)", port++, + FAULT_TIMEOUT, TEST_CNT_BAD); + try_server_run("TCP-AO with wrong receive SEQ ext number (server)", + port++, FAULT_TIMEOUT, TEST_CNT_NS_BAD | TEST_CNT_GOOD); + + synchronize_threads(); /* don't race to exit: client exits */ + return NULL; +} + +static void test_get_sk_checkpoint(unsigned int server_port, sockaddr_af *saddr, + struct tcp_sock_state *img, + struct tcp_ao_repair *ao_img) +{ + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* 1: MKT added => connect() */ + if (test_connect_socket(sk, this_ip_dest, server_port) <= 0) + test_error("failed to connect()"); + + synchronize_threads(); /* 2: accepted => send data */ + if (test_client_verify(sk, msg_len, nr_packets)) + test_fail("pre-migrate verify failed"); + + test_enable_repair(sk); + test_sock_checkpoint(sk, img, saddr); + test_ao_checkpoint(sk, ao_img); + test_kill_sk(sk); +} + +static void test_sk_restore(const char *tst_name, unsigned int server_port, + sockaddr_af *saddr, struct tcp_sock_state *img, + struct tcp_ao_repair *ao_img, + fault_t inj, test_cnt cnt_expected) +{ + test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected; + const char *cnt_name = "TCPAOGood"; + struct tcp_counters cnt1, cnt2; + uint64_t before_cnt, after_cnt; + int sk, dummy; + + if (fault(TIMEOUT)) + cnt_name = "TCPAOBad"; + + before_cnt = netstat_get_one(cnt_name, NULL); + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + test_enable_repair(sk); + test_sock_restore(sk, img, saddr, this_ip_dest, server_port); + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + test_ao_restore(sk, ao_img); + + if (test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + + test_disable_repair(sk); + test_sock_state_free(img); + + if (test_skpair_client(sk, msg_len, nr_packets, poll_cnt, &dummy)) { + if (fault(TIMEOUT)) + test_ok("%s: post-migrate connection is broken", tst_name); + else + test_fail("%s: post-migrate connection is working", tst_name); + } else { + if (fault(TIMEOUT)) + test_fail("%s: post-migrate connection is working", tst_name); + else + test_ok("%s: post-migrate connection is alive", tst_name); + } + + synchronize_threads(); /* 3: counters checks */ + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + after_cnt = netstat_get_one(cnt_name, NULL); + + test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); + + if (after_cnt <= before_cnt) { + test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64, + tst_name, cnt_name, after_cnt, before_cnt); + } else { + test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64, + tst_name, cnt_name, before_cnt, after_cnt); + } + synchronize_threads(); /* 4: verified => closed */ + close(sk); +} + +static void *client_fn(void *arg) +{ + unsigned int port = test_server_port; + struct tcp_sock_state tcp_img; + struct tcp_ao_repair ao_img; + sockaddr_af saddr; + + test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img); + test_sk_restore("TCP-AO migrate to another socket (client)", port++, + &saddr, &tcp_img, &ao_img, 0, TEST_CNT_GOOD); + + test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img); + ao_img.snt_isn += 1; + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest, + -1, port, 0, -1, -1, -1, -1, -1, 100, 100, -1); + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_dest, this_ip_addr, + port, -1, 0, -1, -1, -1, -1, -1, 100, 100, -1); + test_sk_restore("TCP-AO with wrong send ISN (client)", port++, + &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD); + + test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img); + ao_img.rcv_isn += 1; + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest, + -1, port, 0, -1, -1, -1, -1, -1, 100, 100, -1); + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_dest, this_ip_addr, + port, -1, 0, -1, -1, -1, -1, -1, 100, 100, -1); + test_sk_restore("TCP-AO with wrong receive ISN (client)", port++, + &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD); + + test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img); + ao_img.snd_sne += 1; + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_addr, this_ip_dest, + -1, port, 0, -1, -1, -1, -1, -1, 100, 100, -1); + /* not expecting server => client mismatches as only snd sne is broken */ + test_sk_restore("TCP-AO with wrong send SEQ ext number (client)", + port++, &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, + TEST_CNT_NS_BAD | TEST_CNT_GOOD); + + test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img); + ao_img.rcv_sne += 1; + /* not expecting client => server mismatches as only rcv sne is broken */ + trace_ao_event_expect(TCP_AO_MISMATCH, this_ip_dest, this_ip_addr, + port, -1, 0, -1, -1, -1, -1, -1, 100, 100, -1); + test_sk_restore("TCP-AO with wrong receive SEQ ext number (client)", + port++, &saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, + TEST_CNT_NS_GOOD | TEST_CNT_BAD); + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(21, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/rst.c b/tools/testing/selftests/net/tcp_ao/rst.c new file mode 100644 index 000000000000..883cddf377cf --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/rst.c @@ -0,0 +1,459 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * The test checks that both active and passive reset have correct TCP-AO + * signature. An "active" reset (abort) here is procured from closing + * listen() socket with non-accepted connections in the queue: + * inet_csk_listen_stop() => inet_child_forget() => + * => tcp_disconnect() => tcp_send_active_reset() + * + * The passive reset is quite hard to get on established TCP connections. + * It could be procured from non-established states, but the synchronization + * part from userspace in order to reliably get RST seems uneasy. + * So, instead it's procured by corrupting SEQ number on TIMED-WAIT state. + * + * It's important to test both passive and active RST as they go through + * different code-paths: + * - tcp_send_active_reset() makes no-data skb, sends it with tcp_transmit_skb() + * - tcp_v*_send_reset() create their reply skbs and send them with + * ip_send_unicast_reply() + * + * In both cases TCP-AO signatures have to be correct, which is verified by + * (1) checking that the TCP-AO connection was reset and (2) TCP-AO counters. + * + * Author: Dmitry Safonov <dima@arista.com> + */ +#include <inttypes.h> +#include "../../../../include/linux/kernel.h" +#include "aolib.h" + +const size_t quota = 1000; +const size_t packet_sz = 100; +/* + * Backlog == 0 means 1 connection in queue, see: + * commit 64a146513f8f ("[NET]: Revert incorrect accept queue...") + */ +const unsigned int backlog; + +static void netstats_check(struct netstat *before, struct netstat *after, + char *msg) +{ + uint64_t before_cnt, after_cnt; + + before_cnt = netstat_get(before, "TCPAORequired", NULL); + after_cnt = netstat_get(after, "TCPAORequired", NULL); + if (after_cnt > before_cnt) + test_fail("Segments without AO sign (%s): %" PRIu64 " => %" PRIu64, + msg, before_cnt, after_cnt); + else + test_ok("No segments without AO sign (%s)", msg); + + before_cnt = netstat_get(before, "TCPAOGood", NULL); + after_cnt = netstat_get(after, "TCPAOGood", NULL); + if (after_cnt <= before_cnt) + test_fail("Signed AO segments (%s): %" PRIu64 " => %" PRIu64, + msg, before_cnt, after_cnt); + else + test_ok("Signed AO segments (%s): %" PRIu64 " => %" PRIu64, + msg, before_cnt, after_cnt); + + before_cnt = netstat_get(before, "TCPAOBad", NULL); + after_cnt = netstat_get(after, "TCPAOBad", NULL); + if (after_cnt > before_cnt) + test_fail("Segments with bad AO sign (%s): %" PRIu64 " => %" PRIu64, + msg, before_cnt, after_cnt); + else + test_ok("No segments with bad AO sign (%s)", msg); +} + +/* + * Another way to send RST, but not through tcp_v{4,6}_send_reset() + * is tcp_send_active_reset(), that is not in reply to inbound segment, + * but rather active send. It uses tcp_transmit_skb(), so that should + * work, but as it also sends RST - nice that it can be covered as well. + */ +static void close_forced(int sk) +{ + struct linger sl; + + sl.l_onoff = 1; + sl.l_linger = 0; + if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &sl, sizeof(sl))) + test_error("setsockopt(SO_LINGER)"); + close(sk); +} + +static void test_server_active_rst(unsigned int port) +{ + struct tcp_counters cnt1, cnt2; + ssize_t bytes; + int sk, lsk; + + lsk = test_listen_socket(this_ip_addr, port, backlog); + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + if (test_get_tcp_counters(lsk, &cnt1)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* 1: MKT added */ + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + synchronize_threads(); /* 2: connection accept()ed, another queued */ + if (test_get_tcp_counters(lsk, &cnt2)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* 3: close listen socket */ + close(lsk); + bytes = test_server_run(sk, quota, 0); + if (bytes != quota) + test_error("servered only %zd bytes", bytes); + else + test_ok("servered %zd bytes", bytes); + + synchronize_threads(); /* 4: finishing up */ + close_forced(sk); + + synchronize_threads(); /* 5: closed active sk */ + + synchronize_threads(); /* 6: counters checks */ + if (test_assert_counters("active RST server", &cnt1, &cnt2, TEST_CNT_GOOD)) + test_fail("MKT counters (server) have not only good packets"); + else + test_ok("MKT counters are good on server"); +} + +static void test_server_passive_rst(unsigned int port) +{ + struct tcp_counters cnt1, cnt2; + int sk, lsk; + ssize_t bytes; + + lsk = test_listen_socket(this_ip_addr, port, 1); + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* 1: MKT added => connect() */ + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + synchronize_threads(); /* 2: accepted => send data */ + close(lsk); + if (test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) { + if (bytes > 0) + test_fail("server served: %zd", bytes); + else + test_fail("server returned %zd", bytes); + } + + synchronize_threads(); /* 3: checkpoint the client */ + synchronize_threads(); /* 4: close the server, creating twsk */ + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + close(sk); + + synchronize_threads(); /* 5: restore the socket, send more data */ + test_assert_counters("passive RST server", &cnt1, &cnt2, TEST_CNT_GOOD); + + synchronize_threads(); /* 6: server exits */ +} + +static void *server_fn(void *arg) +{ + struct netstat *ns_before, *ns_after; + unsigned int port = test_server_port; + + ns_before = netstat_read(); + + test_server_active_rst(port++); + test_server_passive_rst(port++); + + ns_after = netstat_read(); + netstats_check(ns_before, ns_after, "server"); + netstat_free(ns_after); + netstat_free(ns_before); + synchronize_threads(); /* exit */ + + synchronize_threads(); /* don't race to exit() - client exits */ + return NULL; +} + +static int test_wait_fds(int sk[], size_t nr, bool is_writable[], + ssize_t wait_for, time_t sec) +{ + struct timeval tv = { .tv_sec = sec }; + struct timeval *ptv = NULL; + fd_set left; + size_t i; + int ret; + + FD_ZERO(&left); + for (i = 0; i < nr; i++) { + FD_SET(sk[i], &left); + if (is_writable) + is_writable[i] = false; + } + + if (sec) + ptv = &tv; + + do { + bool is_empty = true; + fd_set fds, efds; + int nfd = 0; + + FD_ZERO(&fds); + FD_ZERO(&efds); + for (i = 0; i < nr; i++) { + if (!FD_ISSET(sk[i], &left)) + continue; + + if (sk[i] > nfd) + nfd = sk[i]; + + FD_SET(sk[i], &fds); + FD_SET(sk[i], &efds); + is_empty = false; + } + if (is_empty) + return -ENOENT; + + errno = 0; + ret = select(nfd + 1, NULL, &fds, &efds, ptv); + if (ret < 0) + return -errno; + if (!ret) + return -ETIMEDOUT; + for (i = 0; i < nr; i++) { + if (FD_ISSET(sk[i], &fds)) { + if (is_writable) + is_writable[i] = true; + FD_CLR(sk[i], &left); + wait_for--; + continue; + } + if (FD_ISSET(sk[i], &efds)) { + FD_CLR(sk[i], &left); + wait_for--; + } + } + } while (wait_for > 0); + + return 0; +} + +static void test_client_active_rst(unsigned int port) +{ + int i, sk[3], err; + bool is_writable[ARRAY_SIZE(sk)] = {false}; + unsigned int last = ARRAY_SIZE(sk) - 1; + + for (i = 0; i < ARRAY_SIZE(sk); i++) { + sk[i] = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk[i] < 0) + test_error("socket()"); + if (test_add_key(sk[i], DEFAULT_TEST_PASSWORD, + this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } + + synchronize_threads(); /* 1: MKT added */ + for (i = 0; i < last; i++) { + err = _test_connect_socket(sk[i], this_ip_dest, port, i != 0); + if (err < 0) + test_error("failed to connect()"); + } + + synchronize_threads(); /* 2: two connections: one accept()ed, another queued */ + err = test_wait_fds(sk, last, is_writable, last, TEST_TIMEOUT_SEC); + if (err < 0) + test_error("test_wait_fds(): %d", err); + + /* async connect() with third sk to get into request_sock_queue */ + err = _test_connect_socket(sk[last], this_ip_dest, port, 1); + if (err < 0) + test_error("failed to connect()"); + + synchronize_threads(); /* 3: close listen socket */ + if (test_client_verify(sk[0], packet_sz, quota / packet_sz)) + test_fail("Failed to send data on connected socket"); + else + test_ok("Verified established tcp connection"); + + synchronize_threads(); /* 4: finishing up */ + + synchronize_threads(); /* 5: closed active sk */ + /* + * Wait for 2 connections: one accepted, another in the accept queue, + * the one in request_sock_queue won't get fully established, so + * doesn't receive an active RST, see inet_csk_listen_stop(). + */ + err = test_wait_fds(sk, last, NULL, last, TEST_TIMEOUT_SEC); + if (err < 0) + test_error("select(): %d", err); + + for (i = 0; i < ARRAY_SIZE(sk); i++) { + socklen_t slen = sizeof(err); + + if (getsockopt(sk[i], SOL_SOCKET, SO_ERROR, &err, &slen)) + test_error("getsockopt()"); + if (is_writable[i] && err != ECONNRESET) { + test_fail("sk[%d] = %d, err = %d, connection wasn't reset", + i, sk[i], err); + } else { + test_ok("sk[%d] = %d%s", i, sk[i], + is_writable[i] ? ", connection was reset" : ""); + } + } + synchronize_threads(); /* 6: counters checks */ +} + +static void test_client_passive_rst(unsigned int port) +{ + struct tcp_counters cnt1, cnt2; + struct tcp_ao_repair ao_img; + struct tcp_sock_state img; + sockaddr_af saddr; + int sk, err; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* 1: MKT added => connect() */ + if (test_connect_socket(sk, this_ip_dest, port) <= 0) + test_error("failed to connect()"); + + synchronize_threads(); /* 2: accepted => send data */ + if (test_client_verify(sk, packet_sz, quota / packet_sz)) + test_fail("Failed to send data on connected socket"); + else + test_ok("Verified established tcp connection"); + + synchronize_threads(); /* 3: checkpoint the client */ + test_enable_repair(sk); + test_sock_checkpoint(sk, &img, &saddr); + test_ao_checkpoint(sk, &ao_img); + test_disable_repair(sk); + + synchronize_threads(); /* 4: close the server, creating twsk */ + + /* + * The "corruption" in SEQ has to be small enough to fit into TCP + * window, see tcp_timewait_state_process() for out-of-window + * segments. + */ + img.out.seq += 5; /* 5 is more noticeable in tcpdump than 1 */ + + /* + * FIXME: This is kind-of ugly and dirty, but it works. + * + * At this moment, the server has close'ed(sk). + * The passive RST that is being targeted here is new data after + * half-duplex close, see tcp_timewait_state_process() => TCP_TW_RST + * + * What is needed here is: + * (1) wait for FIN from the server + * (2) make sure that the ACK from the client went out + * (3) make sure that the ACK was received and processed by the server + * + * Otherwise, the data that will be sent from "repaired" socket + * post SEQ corruption may get to the server before it's in + * TCP_FIN_WAIT2. + * + * (1) is easy with select()/poll() + * (2) is possible by polling tcpi_state from TCP_INFO + * (3) is quite complex: as server's socket was already closed, + * probably the way to do it would be tcp-diag. + */ + sleep(TEST_RETRANSMIT_SEC); + + synchronize_threads(); /* 5: restore the socket, send more data */ + test_kill_sk(sk); + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + test_enable_repair(sk); + test_sock_restore(sk, &img, &saddr, this_ip_dest, port); + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + test_ao_restore(sk, &ao_img); + + if (test_get_tcp_counters(sk, &cnt1)) + test_error("test_get_tcp_counters()"); + + test_disable_repair(sk); + test_sock_state_free(&img); + + /* + * This is how "passive reset" is acquired in this test from TCP_TW_RST: + * + * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [P.], seq 901:1001, ack 1001, win 249, + * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x10217d6c36a22379086ef3b1], length 100 + * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [F.], seq 1001, ack 1001, win 249, + * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x104ffc99b98c10a5298cc268], length 0 + * IP 10.0.1.1.59772 > 10.0.254.1.7011: Flags [.], ack 1002, win 251, + * options [tcp-ao keyid 100 rnextkeyid 100 mac 0xe496dd4f7f5a8a66873c6f93,nop,nop,sack 1 {1001:1002}], length 0 + * IP 10.0.1.1.59772 > 10.0.254.1.7011: Flags [P.], seq 1006:1106, ack 1001, win 251, + * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x1b5f3330fb23fbcd0c77d0ca], length 100 + * IP 10.0.254.1.7011 > 10.0.1.1.59772: Flags [R], seq 3215596252, win 0, + * options [tcp-ao keyid 100 rnextkeyid 100 mac 0x0bcfbbf497bce844312304b2], length 0 + */ + err = test_client_verify(sk, packet_sz, quota / packet_sz); + /* Make sure that the connection was reset, not timeouted */ + if (err && err == -ECONNRESET) + test_ok("client sock was passively reset post-seq-adjust"); + else if (err) + test_fail("client sock was not reset post-seq-adjust: %d", err); + else + test_fail("client sock is yet connected post-seq-adjust"); + + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* 6: server exits */ + close(sk); + test_assert_counters("client passive RST", &cnt1, &cnt2, TEST_CNT_GOOD); +} + +static void *client_fn(void *arg) +{ + struct netstat *ns_before, *ns_after; + unsigned int port = test_server_port; + + ns_before = netstat_read(); + + test_client_active_rst(port++); + test_client_passive_rst(port++); + + ns_after = netstat_read(); + netstats_check(ns_before, ns_after, "client"); + netstat_free(ns_after); + netstat_free(ns_before); + + synchronize_threads(); /* exit */ + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(15, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/self-connect.c b/tools/testing/selftests/net/tcp_ao/self-connect.c new file mode 100644 index 000000000000..2c73bea698a6 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/self-connect.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "aolib.h" + +static union tcp_addr local_addr; + +static void __setup_lo_intf(const char *lo_intf, + const char *addr_str, uint8_t prefix) +{ + if (inet_pton(TEST_FAMILY, addr_str, &local_addr) != 1) + test_error("Can't convert local ip address"); + + if (ip_addr_add(lo_intf, TEST_FAMILY, local_addr, prefix)) + test_error("Failed to add %s ip address", lo_intf); + + if (link_set_up(lo_intf)) + test_error("Failed to bring %s up", lo_intf); + + if (ip_route_add(lo_intf, TEST_FAMILY, local_addr, local_addr)) + test_error("Failed to add a local route %s", lo_intf); +} + +static void setup_lo_intf(const char *lo_intf) +{ +#ifdef IPV6_TEST + __setup_lo_intf(lo_intf, "::1", 128); +#else + __setup_lo_intf(lo_intf, "127.0.0.1", 8); +#endif +} + +static void tcp_self_connect(const char *tst, unsigned int port, + bool different_keyids, bool check_restore) +{ + struct tcp_counters before, after; + uint64_t before_aogood, after_aogood; + struct netstat *ns_before, *ns_after; + const size_t nr_packets = 20; + struct tcp_ao_repair ao_img; + struct tcp_sock_state img; + sockaddr_af addr; + int sk; + + tcp_addr_to_sockaddr_in(&addr, &local_addr, htons(port)); + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (different_keyids) { + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 5, 7)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 7, 5)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } else { + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, local_addr, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } + + if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0) + test_error("bind()"); + + ns_before = netstat_read(); + before_aogood = netstat_get(ns_before, "TCPAOGood", NULL); + if (test_get_tcp_counters(sk, &before)) + test_error("test_get_tcp_counters()"); + + if (__test_connect_socket(sk, "lo", (struct sockaddr *)&addr, + sizeof(addr), 0) < 0) { + ns_after = netstat_read(); + netstat_print_diff(ns_before, ns_after); + test_error("failed to connect()"); + } + + if (test_client_verify(sk, 100, nr_packets)) { + test_fail("%s: tcp connection verify failed", tst); + close(sk); + return; + } + + ns_after = netstat_read(); + after_aogood = netstat_get(ns_after, "TCPAOGood", NULL); + if (test_get_tcp_counters(sk, &after)) + test_error("test_get_tcp_counters()"); + if (!check_restore) { + /* to debug: netstat_print_diff(ns_before, ns_after); */ + netstat_free(ns_before); + } + netstat_free(ns_after); + + if (after_aogood <= before_aogood) { + test_fail("%s: TCPAOGood counter mismatch: %" PRIu64 " <= %" PRIu64, + tst, after_aogood, before_aogood); + close(sk); + return; + } + + if (test_assert_counters(tst, &before, &after, TEST_CNT_GOOD)) { + close(sk); + return; + } + + if (!check_restore) { + test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64, + tst, before_aogood, after_aogood); + close(sk); + return; + } + + test_enable_repair(sk); + test_sock_checkpoint(sk, &img, &addr); +#ifdef IPV6_TEST + addr.sin6_port = htons(port + 1); +#else + addr.sin_port = htons(port + 1); +#endif + test_ao_checkpoint(sk, &ao_img); + test_kill_sk(sk); + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + test_enable_repair(sk); + __test_sock_restore(sk, "lo", &img, &addr, &addr, sizeof(addr)); + if (different_keyids) { + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, + local_addr, -1, 7, 5)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, + local_addr, -1, 5, 7)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } else { + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, + local_addr, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } + test_ao_restore(sk, &ao_img); + test_disable_repair(sk); + test_sock_state_free(&img); + if (test_client_verify(sk, 100, nr_packets)) { + test_fail("%s: tcp connection verify failed", tst); + close(sk); + return; + } + ns_after = netstat_read(); + after_aogood = netstat_get(ns_after, "TCPAOGood", NULL); + /* to debug: netstat_print_diff(ns_before, ns_after); */ + netstat_free(ns_before); + netstat_free(ns_after); + close(sk); + if (after_aogood <= before_aogood) { + test_fail("%s: TCPAOGood counter mismatch: %" PRIu64 " <= %" PRIu64, + tst, after_aogood, before_aogood); + return; + } + test_ok("%s: connect TCPAOGood %" PRIu64 " => %" PRIu64, + tst, before_aogood, after_aogood); +} + +static void *client_fn(void *arg) +{ + unsigned int port = test_server_port; + + setup_lo_intf("lo"); + + tcp_self_connect("self-connect(same keyids)", port++, false, false); + + /* expecting rnext to change based on the first segment RNext != Current */ + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, local_addr, local_addr, + port, port, 0, -1, -1, -1, -1, -1, 7, 5, -1); + tcp_self_connect("self-connect(different keyids)", port++, true, false); + tcp_self_connect("self-connect(restore)", port, false, true); + port += 2; /* restore test restores over different port */ + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, local_addr, local_addr, + port, port, 0, -1, -1, -1, -1, -1, 7, 5, -1); + /* intentionally on restore they are added to the socket in different order */ + trace_ao_event_expect(TCP_AO_RNEXT_REQUEST, local_addr, local_addr, + port + 1, port + 1, 0, -1, -1, -1, -1, -1, 5, 7, -1); + tcp_self_connect("self-connect(restore, different keyids)", port, true, true); + port += 2; /* restore test restores over different port */ + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(5, client_fn, NULL); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/seq-ext.c b/tools/testing/selftests/net/tcp_ao/seq-ext.c new file mode 100644 index 000000000000..f00245263b20 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/seq-ext.c @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Check that after SEQ number wrap-around: + * 1. SEQ-extension has upper bytes set + * 2. TCP conneciton is alive and no TCPAOBad segments + * In order to test (2), the test doesn't just adjust seq number for a queue + * on a connected socket, but migrates it to another sk+port number, so + * that there won't be any delayed packets that will fail to verify + * with the new SEQ numbers. + */ +#include <inttypes.h> +#include "aolib.h" + +const unsigned int nr_packets = 1000; +const unsigned int msg_len = 1000; +const unsigned int quota = nr_packets * msg_len; +unsigned int client_new_port; + +/* Move them closer to roll-over */ +static void test_adjust_seqs(struct tcp_sock_state *img, + struct tcp_ao_repair *ao_img, + bool server) +{ + uint32_t new_seq1, new_seq2; + + /* make them roll-over during quota, but on different segments */ + if (server) { + new_seq1 = ((uint32_t)-1) - msg_len; + new_seq2 = ((uint32_t)-1) - (quota - 2 * msg_len); + } else { + new_seq1 = ((uint32_t)-1) - (quota - 2 * msg_len); + new_seq2 = ((uint32_t)-1) - msg_len; + } + + img->in.seq = new_seq1; + img->trw.snd_wl1 = img->in.seq - msg_len; + img->out.seq = new_seq2; + img->trw.rcv_wup = img->in.seq; +} + +static int test_sk_restore(struct tcp_sock_state *img, + struct tcp_ao_repair *ao_img, sockaddr_af *saddr, + const union tcp_addr daddr, unsigned int dport, + struct tcp_counters *cnt) +{ + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + test_enable_repair(sk); + test_sock_restore(sk, img, saddr, daddr, dport); + if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, daddr, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + test_ao_restore(sk, ao_img); + + if (test_get_tcp_counters(sk, cnt)) + test_error("test_get_tcp_counters()"); + + test_disable_repair(sk); + test_sock_state_free(img); + return sk; +} + +static void *server_fn(void *arg) +{ + uint64_t before_good, after_good, after_bad; + struct tcp_counters cnt1, cnt2; + struct tcp_sock_state img; + struct tcp_ao_repair ao_img; + sockaddr_af saddr; + ssize_t bytes; + int sk, lsk; + + lsk = test_listen_socket(this_ip_addr, test_server_port, 1); + + if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* 1: MKT added => connect() */ + + if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0)) + test_error("test_wait_fd()"); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) + test_error("accept()"); + + synchronize_threads(); /* 2: accepted => send data */ + close(lsk); + + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) { + if (bytes > 0) + test_fail("server served: %zd", bytes); + else + test_fail("server returned: %zd", bytes); + goto out; + } + + before_good = netstat_get_one("TCPAOGood", NULL); + + synchronize_threads(); /* 3: restore the connection on another port */ + + test_enable_repair(sk); + test_sock_checkpoint(sk, &img, &saddr); + test_ao_checkpoint(sk, &ao_img); + test_kill_sk(sk); +#ifdef IPV6_TEST + saddr.sin6_port = htons(ntohs(saddr.sin6_port) + 1); +#else + saddr.sin_port = htons(ntohs(saddr.sin_port) + 1); +#endif + test_adjust_seqs(&img, &ao_img, true); + synchronize_threads(); /* 4: dump finished */ + sk = test_sk_restore(&img, &ao_img, &saddr, this_ip_dest, + client_new_port, &cnt1); + + trace_ao_event_sne_expect(TCP_AO_SND_SNE_UPDATE, this_ip_addr, + this_ip_dest, test_server_port + 1, client_new_port, 1); + trace_ao_event_sne_expect(TCP_AO_SND_SNE_UPDATE, this_ip_dest, + this_ip_addr, client_new_port, test_server_port + 1, 1); + trace_ao_event_sne_expect(TCP_AO_RCV_SNE_UPDATE, this_ip_addr, + this_ip_dest, test_server_port + 1, client_new_port, 1); + trace_ao_event_sne_expect(TCP_AO_RCV_SNE_UPDATE, this_ip_dest, + this_ip_addr, client_new_port, test_server_port + 1, 1); + synchronize_threads(); /* 5: verify the connection during SEQ-number rollover */ + bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC); + if (bytes != quota) { + if (bytes > 0) + test_fail("server served: %zd", bytes); + else + test_fail("server returned: %zd", bytes); + } else { + test_ok("server alive"); + } + + synchronize_threads(); /* 6: verify counters after SEQ-number rollover */ + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + after_good = netstat_get_one("TCPAOGood", NULL); + + test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD); + + if (after_good <= before_good) { + test_fail("TCPAOGood counter did not increase: %" PRIu64 " <= %" PRIu64, + after_good, before_good); + } else { + test_ok("TCPAOGood counter increased %" PRIu64 " => %" PRIu64, + before_good, after_good); + } + after_bad = netstat_get_one("TCPAOBad", NULL); + if (after_bad) + test_fail("TCPAOBad counter is non-zero: %" PRIu64, after_bad); + else + test_ok("TCPAOBad counter didn't increase"); + test_enable_repair(sk); + test_ao_checkpoint(sk, &ao_img); + if (ao_img.snd_sne && ao_img.rcv_sne) { + test_ok("SEQ extension incremented: %u/%u", + ao_img.snd_sne, ao_img.rcv_sne); + } else { + test_fail("SEQ extension was not incremented: %u/%u", + ao_img.snd_sne, ao_img.rcv_sne); + } + + synchronize_threads(); /* 6: verified => closed */ +out: + close(sk); + return NULL; +} + +static void *client_fn(void *arg) +{ + uint64_t before_good, after_good, after_bad; + struct tcp_counters cnt1, cnt2; + struct tcp_sock_state img; + struct tcp_ao_repair ao_img; + sockaddr_af saddr; + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* 1: MKT added => connect() */ + if (test_connect_socket(sk, this_ip_dest, test_server_port) <= 0) + test_error("failed to connect()"); + + synchronize_threads(); /* 2: accepted => send data */ + if (test_client_verify(sk, msg_len, nr_packets)) { + test_fail("pre-migrate verify failed"); + return NULL; + } + + before_good = netstat_get_one("TCPAOGood", NULL); + + synchronize_threads(); /* 3: restore the connection on another port */ + test_enable_repair(sk); + test_sock_checkpoint(sk, &img, &saddr); + test_ao_checkpoint(sk, &ao_img); + test_kill_sk(sk); +#ifdef IPV6_TEST + client_new_port = ntohs(saddr.sin6_port) + 1; + saddr.sin6_port = htons(ntohs(saddr.sin6_port) + 1); +#else + client_new_port = ntohs(saddr.sin_port) + 1; + saddr.sin_port = htons(ntohs(saddr.sin_port) + 1); +#endif + test_adjust_seqs(&img, &ao_img, false); + synchronize_threads(); /* 4: dump finished */ + sk = test_sk_restore(&img, &ao_img, &saddr, this_ip_dest, + test_server_port + 1, &cnt1); + + synchronize_threads(); /* 5: verify the connection during SEQ-number rollover */ + if (test_client_verify(sk, msg_len, nr_packets)) + test_fail("post-migrate verify failed"); + else + test_ok("post-migrate connection alive"); + + synchronize_threads(); /* 5: verify counters after SEQ-number rollover */ + if (test_get_tcp_counters(sk, &cnt2)) + test_error("test_get_tcp_counters()"); + after_good = netstat_get_one("TCPAOGood", NULL); + + test_assert_counters(NULL, &cnt1, &cnt2, TEST_CNT_GOOD); + + if (after_good <= before_good) { + test_fail("TCPAOGood counter did not increase: %" PRIu64 " <= %" PRIu64, + after_good, before_good); + } else { + test_ok("TCPAOGood counter increased %" PRIu64 " => %" PRIu64, + before_good, after_good); + } + after_bad = netstat_get_one("TCPAOBad", NULL); + if (after_bad) + test_fail("TCPAOBad counter is non-zero: %" PRIu64, after_bad); + else + test_ok("TCPAOBad counter didn't increase"); + + synchronize_threads(); /* 6: verified => closed */ + close(sk); + + synchronize_threads(); /* don't race to exit: let server exit() */ + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(8, server_fn, client_fn); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c new file mode 100644 index 000000000000..0abb9807d742 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/setsockopt-closed.c @@ -0,0 +1,1011 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "../../../../include/linux/kernel.h" +#include "aolib.h" + +static union tcp_addr tcp_md5_client; + +#define FILTER_TEST_NKEYS 16 + +static int test_port = 7788; +static void make_listen(int sk) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &this_ip_addr, htons(test_port++)); + if (bind(sk, (struct sockaddr *)&addr, sizeof(addr)) < 0) + test_error("bind()"); + if (listen(sk, 1)) + test_error("listen()"); +} + +static void test_vefify_ao_info(int sk, struct tcp_ao_info_opt *info, + const char *tst) +{ + struct tcp_ao_info_opt tmp = {}; + socklen_t len = sizeof(tmp); + + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, &tmp, &len)) + test_error("getsockopt(TCP_AO_INFO) failed"); + +#define __cmp_ao(member) \ +do { \ + if (info->member != tmp.member) { \ + test_fail("%s: getsockopt(): " __stringify(member) " %" PRIu64 " != %" PRIu64, \ + tst, (uint64_t)info->member, (uint64_t)tmp.member); \ + return; \ + } \ +} while(0) + if (info->set_current) + __cmp_ao(current_key); + if (info->set_rnext) + __cmp_ao(rnext); + if (info->set_counters) { + __cmp_ao(pkt_good); + __cmp_ao(pkt_bad); + __cmp_ao(pkt_key_not_found); + __cmp_ao(pkt_ao_required); + __cmp_ao(pkt_dropped_icmp); + } + __cmp_ao(ao_required); + __cmp_ao(accept_icmps); + + test_ok("AO info get: %s", tst); +#undef __cmp_ao +} + +static void __setsockopt_checked(int sk, int optname, bool get, + void *optval, socklen_t *len, + int err, const char *tst, const char *tst2) +{ + int ret; + + if (!tst) + tst = ""; + if (!tst2) + tst2 = ""; + + errno = 0; + if (get) + ret = getsockopt(sk, IPPROTO_TCP, optname, optval, len); + else + ret = setsockopt(sk, IPPROTO_TCP, optname, optval, *len); + if (ret == -1) { + if (errno == err) + test_ok("%s%s", tst ?: "", tst2 ?: ""); + else + test_fail("%s%s: %setsockopt() failed", + tst, tst2, get ? "g" : "s"); + close(sk); + return; + } + + if (err) { + test_fail("%s%s: %setsockopt() was expected to fail with %d", + tst, tst2, get ? "g" : "s", err); + } else { + test_ok("%s%s", tst ?: "", tst2 ?: ""); + if (optname == TCP_AO_ADD_KEY) { + test_verify_socket_key(sk, optval); + } else if (optname == TCP_AO_INFO && !get) { + test_vefify_ao_info(sk, optval, tst2); + } else if (optname == TCP_AO_GET_KEYS) { + if (*len != sizeof(struct tcp_ao_getsockopt)) + test_fail("%s%s: get keys returned wrong tcp_ao_getsockopt size", + tst, tst2); + } + } + close(sk); +} + +static void setsockopt_checked(int sk, int optname, void *optval, + int err, const char *tst) +{ + const char *cmd = NULL; + socklen_t len; + + switch (optname) { + case TCP_AO_ADD_KEY: + cmd = "key add: "; + len = sizeof(struct tcp_ao_add); + break; + case TCP_AO_DEL_KEY: + cmd = "key del: "; + len = sizeof(struct tcp_ao_del); + break; + case TCP_AO_INFO: + cmd = "AO info set: "; + len = sizeof(struct tcp_ao_info_opt); + break; + default: + break; + } + + __setsockopt_checked(sk, optname, false, optval, &len, err, cmd, tst); +} + +static int prepare_defs(int cmd, void *optval) +{ + int sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + + if (sk < 0) + test_error("socket()"); + + switch (cmd) { + case TCP_AO_ADD_KEY: { + struct tcp_ao_add *add = optval; + + if (test_prepare_def_key(add, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, + -1, 0, 100, 100)) + test_error("prepare default tcp_ao_add"); + break; + } + case TCP_AO_DEL_KEY: { + struct tcp_ao_del *del = optval; + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, + DEFAULT_TEST_PREFIX, 100, 100)) + test_error("add default key"); + memset(del, 0, sizeof(struct tcp_ao_del)); + del->sndid = 100; + del->rcvid = 100; + del->prefix = DEFAULT_TEST_PREFIX; + tcp_addr_to_sockaddr_in(&del->addr, &this_ip_dest, 0); + break; + } + case TCP_AO_INFO: { + struct tcp_ao_info_opt *info = optval; + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, + DEFAULT_TEST_PREFIX, 100, 100)) + test_error("add default key"); + memset(info, 0, sizeof(struct tcp_ao_info_opt)); + break; + } + case TCP_AO_GET_KEYS: { + struct tcp_ao_getsockopt *get = optval; + + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, + DEFAULT_TEST_PREFIX, 100, 100)) + test_error("add default key"); + memset(get, 0, sizeof(struct tcp_ao_getsockopt)); + get->nkeys = 1; + get->get_all = 1; + break; + } + default: + test_error("unknown cmd"); + } + + return sk; +} + +static void test_extend(int cmd, bool get, const char *tst, socklen_t under_size) +{ + struct { + union { + struct tcp_ao_add add; + struct tcp_ao_del del; + struct tcp_ao_getsockopt get; + struct tcp_ao_info_opt info; + }; + char *extend[100]; + } tmp_opt; + socklen_t extended_size = sizeof(tmp_opt); + int sk; + + memset(&tmp_opt, 0, sizeof(tmp_opt)); + sk = prepare_defs(cmd, &tmp_opt); + __setsockopt_checked(sk, cmd, get, &tmp_opt, &under_size, + EINVAL, tst, ": minimum size"); + + memset(&tmp_opt, 0, sizeof(tmp_opt)); + sk = prepare_defs(cmd, &tmp_opt); + __setsockopt_checked(sk, cmd, get, &tmp_opt, &extended_size, + 0, tst, ": extended size"); + + memset(&tmp_opt, 0, sizeof(tmp_opt)); + sk = prepare_defs(cmd, &tmp_opt); + __setsockopt_checked(sk, cmd, get, NULL, &extended_size, + EFAULT, tst, ": null optval"); + + if (get) { + memset(&tmp_opt, 0, sizeof(tmp_opt)); + sk = prepare_defs(cmd, &tmp_opt); + __setsockopt_checked(sk, cmd, get, &tmp_opt, NULL, + EFAULT, tst, ": null optlen"); + } +} + +static void extend_tests(void) +{ + test_extend(TCP_AO_ADD_KEY, false, "AO add", + offsetof(struct tcp_ao_add, key)); + test_extend(TCP_AO_DEL_KEY, false, "AO del", + offsetof(struct tcp_ao_del, keyflags)); + test_extend(TCP_AO_INFO, false, "AO set info", + offsetof(struct tcp_ao_info_opt, pkt_dropped_icmp)); + test_extend(TCP_AO_INFO, true, "AO get info", -1); + test_extend(TCP_AO_GET_KEYS, true, "AO get keys", -1); +} + +static void test_optmem_limit(void) +{ + size_t i, keys_limit, current_optmem = test_get_optmem(); + struct tcp_ao_add ao; + union tcp_addr net = {}; + int sk; + + if (inet_pton(TEST_FAMILY, TEST_NETWORK, &net) != 1) + test_error("Can't convert ip address %s", TEST_NETWORK); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + keys_limit = current_optmem / KERNEL_TCP_AO_KEY_SZ_ROUND_UP; + for (i = 0;; i++) { + union tcp_addr key_peer; + int err; + + key_peer = gen_tcp_addr(net, i + 1); + tcp_addr_to_sockaddr_in(&ao.addr, &key_peer, 0); + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, + &ao, sizeof(ao)); + if (!err) { + /* + * TCP_AO_ADD_KEY should be the same order as the real + * sizeof(struct tcp_ao_key) in kernel. + */ + if (i <= keys_limit * 10) + continue; + test_fail("optmem limit test failed: added %zu key", i); + break; + } + if (i < keys_limit) { + test_fail("optmem limit test failed: couldn't add %zu key", i); + break; + } + test_ok("optmem limit was hit on adding %zu key", i); + break; + } + close(sk); +} + +static void test_einval_add_key(void) +{ + struct tcp_ao_add ao; + int sk; + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.keylen = TCP_AO_MAXKEYLEN + 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too big keylen"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.reserved = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "using reserved padding"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.reserved2 = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "using reserved2 padding"); + + /* tcp_ao_verify_ipv{4,6}() checks */ + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.addr.ss_family = AF_UNIX; + memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "wrong address family"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + tcp_addr_to_sockaddr_in(&ao.addr, &this_ip_dest, 1234); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "port (unsupported)"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.prefix = 0; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "no prefix, addr"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.prefix = 0; + memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "no prefix, any addr"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.prefix = 32; + memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "prefix, any addr"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.prefix = 129; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too big prefix"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.prefix = 2; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "too short prefix"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.keyflags = (uint8_t)(-1); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "bad key flags"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + make_listen(sk); + ao.set_current = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add current key on a listen socket"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + make_listen(sk); + ao.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + make_listen(sk); + ao.set_current = 1; + ao.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "add current+rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.set_current = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as current"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as rnext"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.set_current = 1; + ao.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, 0, "add key and set as current+rnext"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.ifindex = 42; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, + "ifindex without TCP_AO_KEYF_IFNINDEX"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.keyflags |= TCP_AO_KEYF_IFINDEX; + ao.ifindex = 42; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EINVAL, "non-existent VRF"); + /* + * tcp_md5_do_lookup{,_any_l3index}() are checked in unsigned-md5 + * see client_vrf_tests(). + */ + + test_optmem_limit(); + + /* tcp_ao_parse_crypto() */ + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao.maclen = 100; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EMSGSIZE, "maclen bigger than TCP hdr"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + strcpy(ao.alg_name, "imaginary hash algo"); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, ENOENT, "bad algo"); +} + +static void test_einval_del_key(void) +{ + struct tcp_ao_del del; + int sk; + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.reserved = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "using reserved padding"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.reserved2 = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "using reserved2 padding"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + make_listen(sk); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_current = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set current key on a listen socket"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + make_listen(sk); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + make_listen(sk); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_current = 1; + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "del and set current+rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.keyflags = (uint8_t)(-1); + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "bad key flags"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.ifindex = 42; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, + "ifindex without TCP_AO_KEYF_IFNINDEX"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.keyflags |= TCP_AO_KEYF_IFINDEX; + del.ifindex = 42; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existent VRF"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_current = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing current key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing rnext key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_current = 1; + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set non-existing current+rnext key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_current = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set current key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set rnext key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, DEFAULT_TEST_PREFIX, 0, 0)) + test_error("add key"); + del.set_current = 1; + del.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "set current+rnext key"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_current = 1; + del.current_key = 100; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as current key to be removed"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_rnext = 1; + del.rnext = 100; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as rnext key to be removed"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.set_current = 1; + del.current_key = 100; + del.set_rnext = 1; + del.rnext = 100; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "set as current+rnext key to be removed"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.del_async = 1; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, EINVAL, "async on non-listen"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.sndid = 101; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existing sndid"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + del.rcvid = 101; + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "non-existing rcvid"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + tcp_addr_to_sockaddr_in(&del.addr, &this_ip_addr, 0); + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, ENOENT, "incorrect addr"); + + sk = prepare_defs(TCP_AO_DEL_KEY, &del); + setsockopt_checked(sk, TCP_AO_DEL_KEY, &del, 0, "correct key delete"); +} + +static void test_einval_ao_info(void) +{ + struct tcp_ao_info_opt info; + int sk; + + sk = prepare_defs(TCP_AO_INFO, &info); + make_listen(sk); + info.set_current = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set current key on a listen socket"); + + sk = prepare_defs(TCP_AO_INFO, &info); + make_listen(sk); + info.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_INFO, &info); + make_listen(sk); + info.set_current = 1; + info.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "set current+rnext key on a listen socket"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.reserved = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "using reserved padding"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.reserved2 = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, EINVAL, "using reserved2 padding"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.accept_icmps = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "accept_icmps"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.ao_required = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "ao required"); + + if (!should_skip_test("ao required with MD5 key", KCONFIG_TCP_MD5)) { + sk = prepare_defs(TCP_AO_INFO, &info); + info.ao_required = 1; + if (test_set_md5(sk, tcp_md5_client, TEST_PREFIX, -1, + "long long secret")) { + test_error("setsockopt(TCP_MD5SIG_EXT)"); + close(sk); + } else { + setsockopt_checked(sk, TCP_AO_INFO, &info, EKEYREJECTED, + "ao required with MD5 key"); + } + } + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_current = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing current key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing rnext key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_current = 1; + info.set_rnext = 1; + setsockopt_checked(sk, TCP_AO_INFO, &info, ENOENT, "set non-existing current+rnext key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_current = 1; + info.current_key = 100; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set current key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_rnext = 1; + info.rnext = 100; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set rnext key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_current = 1; + info.set_rnext = 1; + info.current_key = 100; + info.rnext = 100; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set current+rnext key"); + + sk = prepare_defs(TCP_AO_INFO, &info); + info.set_counters = 1; + info.pkt_good = 321; + info.pkt_bad = 888; + info.pkt_key_not_found = 654; + info.pkt_ao_required = 987654; + info.pkt_dropped_icmp = 10000; + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "set counters"); + + sk = prepare_defs(TCP_AO_INFO, &info); + setsockopt_checked(sk, TCP_AO_INFO, &info, 0, "no-op"); +} + +static void getsockopt_checked(int sk, struct tcp_ao_getsockopt *optval, + int err, const char *tst) +{ + socklen_t len = sizeof(struct tcp_ao_getsockopt); + + __setsockopt_checked(sk, TCP_AO_GET_KEYS, true, optval, &len, err, + "get keys: ", tst); +} + +static void test_einval_get_keys(void) +{ + struct tcp_ao_getsockopt out; + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + getsockopt_checked(sk, &out, ENOENT, "no ao_info"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + getsockopt_checked(sk, &out, 0, "proper tcp_ao_get_mkts()"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.pkt_good = 643; + getsockopt_checked(sk, &out, EINVAL, "set out-only pkt_good counter"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.pkt_bad = 94; + getsockopt_checked(sk, &out, EINVAL, "set out-only pkt_bad counter"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.keyflags = (uint8_t)(-1); + getsockopt_checked(sk, &out, EINVAL, "bad keyflags"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.ifindex = 42; + getsockopt_checked(sk, &out, EINVAL, + "ifindex without TCP_AO_KEYF_IFNINDEX"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.reserved = 1; + getsockopt_checked(sk, &out, EINVAL, "using reserved field"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = 0; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "no prefix, addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = 0; + memcpy(&out.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + getsockopt_checked(sk, &out, 0, "no prefix, any addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = 32; + memcpy(&out.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + getsockopt_checked(sk, &out, EINVAL, "prefix, any addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = 129; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "too big prefix"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = 2; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "too short prefix"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.prefix = DEFAULT_TEST_PREFIX; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, 0, "prefix + addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + out.prefix = DEFAULT_TEST_PREFIX; + getsockopt_checked(sk, &out, EINVAL, "get_all + prefix"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "get_all + addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + out.sndid = 1; + getsockopt_checked(sk, &out, EINVAL, "get_all + sndid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + out.rcvid = 1; + getsockopt_checked(sk, &out, EINVAL, "get_all + rcvid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_current = 1; + out.prefix = DEFAULT_TEST_PREFIX; + getsockopt_checked(sk, &out, EINVAL, "current + prefix"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_current = 1; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "current + addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_current = 1; + out.sndid = 1; + getsockopt_checked(sk, &out, EINVAL, "current + sndid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_current = 1; + out.rcvid = 1; + getsockopt_checked(sk, &out, EINVAL, "current + rcvid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_rnext = 1; + out.prefix = DEFAULT_TEST_PREFIX; + getsockopt_checked(sk, &out, EINVAL, "rnext + prefix"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_rnext = 1; + tcp_addr_to_sockaddr_in(&out.addr, &this_ip_dest, 0); + getsockopt_checked(sk, &out, EINVAL, "rnext + addr"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_rnext = 1; + out.sndid = 1; + getsockopt_checked(sk, &out, EINVAL, "rnext + sndid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_rnext = 1; + out.rcvid = 1; + getsockopt_checked(sk, &out, EINVAL, "rnext + rcvid"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + out.is_current = 1; + getsockopt_checked(sk, &out, EINVAL, "get_all + current"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 1; + out.is_rnext = 1; + getsockopt_checked(sk, &out, EINVAL, "get_all + rnext"); + + sk = prepare_defs(TCP_AO_GET_KEYS, &out); + out.get_all = 0; + out.is_current = 1; + out.is_rnext = 1; + getsockopt_checked(sk, &out, 0, "current + rnext"); +} + +static void einval_tests(void) +{ + test_einval_add_key(); + test_einval_del_key(); + test_einval_ao_info(); + test_einval_get_keys(); +} + +static void duplicate_tests(void) +{ + union tcp_addr network_dup; + struct tcp_ao_add ao, ao2; + int sk; + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: full copy"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + ao2 = ao; + memcpy(&ao2.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + ao2.prefix = 0; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao2, sizeof(ao))) + test_error("setsockopt()"); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: any addr key on the socket"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + memcpy(&ao.addr, &SOCKADDR_ANY, sizeof(SOCKADDR_ANY)); + ao.prefix = 0; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: add any addr key"); + + if (inet_pton(TEST_FAMILY, TEST_NETWORK, &network_dup) != 1) + test_error("Can't convert ip address %s", TEST_NETWORK); + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + if (test_prepare_def_key(&ao, "password", 0, network_dup, + 16, 0, 100, 100)) + test_error("prepare default tcp_ao_add"); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: add any addr for the same subnet"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: full copy of a key"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + ao.rcvid = 101; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: RecvID differs"); + + sk = prepare_defs(TCP_AO_ADD_KEY, &ao); + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &ao, sizeof(ao))) + test_error("setsockopt()"); + ao.sndid = 101; + setsockopt_checked(sk, TCP_AO_ADD_KEY, &ao, EEXIST, "duplicate: SendID differs"); +} + +static void fetch_all_keys(int sk, struct tcp_ao_getsockopt *keys) +{ + socklen_t optlen = sizeof(struct tcp_ao_getsockopt); + + memset(keys, 0, sizeof(struct tcp_ao_getsockopt) * FILTER_TEST_NKEYS); + keys[0].get_all = 1; + keys[0].nkeys = FILTER_TEST_NKEYS; + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &keys[0], &optlen)) + test_error("getsockopt"); +} + +static int prepare_test_keys(struct tcp_ao_getsockopt *keys) +{ + const char *test_password = "Test password number "; + struct tcp_ao_add test_ao[FILTER_TEST_NKEYS]; + char test_password_scratch[64] = {}; + u8 rcvid = 100, sndid = 100; + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + for (int i = 0; i < FILTER_TEST_NKEYS; i++) { + snprintf(test_password_scratch, 64, "%s %d", test_password, i); + test_prepare_key(&test_ao[i], DEFAULT_TEST_ALGO, this_ip_dest, + false, false, DEFAULT_TEST_PREFIX, 0, sndid++, + rcvid++, 0, 0, strlen(test_password_scratch), + test_password_scratch); + } + test_ao[0].set_current = 1; + test_ao[1].set_rnext = 1; + /* One key with a different addr and overlapping sndid, rcvid */ + tcp_addr_to_sockaddr_in(&test_ao[2].addr, &this_ip_addr, 0); + test_ao[2].sndid = 100; + test_ao[2].rcvid = 100; + + /* Add keys in a random order */ + for (int i = 0; i < FILTER_TEST_NKEYS; i++) { + int randidx = rand() % (FILTER_TEST_NKEYS - i); + + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, + &test_ao[randidx], sizeof(struct tcp_ao_add))) + test_error("setsockopt()"); + memcpy(&test_ao[randidx], &test_ao[FILTER_TEST_NKEYS - 1 - i], + sizeof(struct tcp_ao_add)); + } + + fetch_all_keys(sk, keys); + + return sk; +} + +/* Assumes passwords are unique */ +static int compare_mkts(struct tcp_ao_getsockopt *expected, int nexpected, + struct tcp_ao_getsockopt *actual, int nactual) +{ + int matches = 0; + + for (int i = 0; i < nexpected; i++) { + for (int j = 0; j < nactual; j++) { + if (memcmp(expected[i].key, actual[j].key, + TCP_AO_MAXKEYLEN) == 0) + matches++; + } + } + return nexpected - matches; +} + +static void filter_keys_checked(int sk, struct tcp_ao_getsockopt *filter, + struct tcp_ao_getsockopt *expected, + unsigned int nexpected, const char *tst) +{ + struct tcp_ao_getsockopt filtered_keys[FILTER_TEST_NKEYS] = {}; + struct tcp_ao_getsockopt all_keys[FILTER_TEST_NKEYS] = {}; + socklen_t len = sizeof(struct tcp_ao_getsockopt); + + fetch_all_keys(sk, all_keys); + memcpy(&filtered_keys[0], filter, sizeof(struct tcp_ao_getsockopt)); + filtered_keys[0].nkeys = FILTER_TEST_NKEYS; + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, filtered_keys, &len)) + test_error("getsockopt"); + if (filtered_keys[0].nkeys != nexpected) { + test_fail("wrong nr of keys, expected %u got %u", nexpected, + filtered_keys[0].nkeys); + goto out_close; + } + if (compare_mkts(expected, nexpected, filtered_keys, + filtered_keys[0].nkeys)) { + test_fail("got wrong keys back"); + goto out_close; + } + test_ok("filter keys: %s", tst); + +out_close: + close(sk); + memset(filter, 0, sizeof(struct tcp_ao_getsockopt)); +} + +static void filter_tests(void) +{ + struct tcp_ao_getsockopt original_keys[FILTER_TEST_NKEYS]; + struct tcp_ao_getsockopt expected_keys[FILTER_TEST_NKEYS]; + struct tcp_ao_getsockopt filter = {}; + int sk, f, nmatches; + socklen_t len; + + f = 2; + sk = prepare_test_keys(original_keys); + filter.rcvid = original_keys[f].rcvid; + filter.sndid = original_keys[f].sndid; + memcpy(&filter.addr, &original_keys[f].addr, + sizeof(original_keys[f].addr)); + filter.prefix = original_keys[f].prefix; + filter_keys_checked(sk, &filter, &original_keys[f], 1, + "by sndid, rcvid, address"); + + f = -1; + sk = prepare_test_keys(original_keys); + for (int i = 0; i < original_keys[0].nkeys; i++) { + if (original_keys[i].is_current) { + f = i; + break; + } + } + if (f < 0) + test_error("No current key after adding one"); + filter.is_current = 1; + filter_keys_checked(sk, &filter, &original_keys[f], 1, "by is_current"); + + f = -1; + sk = prepare_test_keys(original_keys); + for (int i = 0; i < original_keys[0].nkeys; i++) { + if (original_keys[i].is_rnext) { + f = i; + break; + } + } + if (f < 0) + test_error("No rnext key after adding one"); + filter.is_rnext = 1; + filter_keys_checked(sk, &filter, &original_keys[f], 1, "by is_rnext"); + + f = -1; + nmatches = 0; + sk = prepare_test_keys(original_keys); + for (int i = 0; i < original_keys[0].nkeys; i++) { + if (original_keys[i].sndid == 100) { + f = i; + memcpy(&expected_keys[nmatches], &original_keys[i], + sizeof(struct tcp_ao_getsockopt)); + nmatches++; + } + } + if (f < 0) + test_error("No key for sndid 100"); + if (nmatches != 2) + test_error("Should have 2 keys with sndid 100"); + filter.rcvid = original_keys[f].rcvid; + filter.sndid = original_keys[f].sndid; + filter.addr.ss_family = test_family; + filter_keys_checked(sk, &filter, expected_keys, nmatches, + "by sndid, rcvid"); + + sk = prepare_test_keys(original_keys); + filter.get_all = 1; + filter.nkeys = FILTER_TEST_NKEYS / 2; + len = sizeof(struct tcp_ao_getsockopt); + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &filter, &len)) + test_error("getsockopt"); + if (filter.nkeys == FILTER_TEST_NKEYS) + test_ok("filter keys: correct nkeys when in.nkeys < matches"); + else + test_fail("filter keys: wrong nkeys, expected %u got %u", + FILTER_TEST_NKEYS, filter.nkeys); +} + +static void *client_fn(void *arg) +{ + if (inet_pton(TEST_FAMILY, __TEST_CLIENT_IP(2), &tcp_md5_client) != 1) + test_error("Can't convert ip address"); + extend_tests(); + einval_tests(); + filter_tests(); + duplicate_tests(); + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(126, client_fn, NULL); + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/settings b/tools/testing/selftests/net/tcp_ao/settings new file mode 100644 index 000000000000..6091b45d226b --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/settings @@ -0,0 +1 @@ +timeout=120 diff --git a/tools/testing/selftests/net/tcp_ao/unsigned-md5.c b/tools/testing/selftests/net/tcp_ao/unsigned-md5.c new file mode 100644 index 000000000000..a1467b64390a --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/unsigned-md5.c @@ -0,0 +1,772 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Author: Dmitry Safonov <dima@arista.com> */ +#include <inttypes.h> +#include "aolib.h" + +#define fault(type) (inj == FAULT_ ## type) +static const char *md5_password = "Some evil genius, enemy to mankind, must have been the first contriver."; +static const char *ao_password = DEFAULT_TEST_PASSWORD; +static volatile int sk_pair; + +static union tcp_addr client2; +static union tcp_addr client3; + +static const int test_vrf_ifindex = 200; +static const uint8_t test_vrf_tabid = 42; +static void setup_vrfs(void) +{ + int err; + + if (!kernel_config_has(KCONFIG_NET_VRF)) + return; + + err = add_vrf("ksft-vrf", test_vrf_tabid, test_vrf_ifindex, -1); + if (err) + test_error("Failed to add a VRF: %d", err); + + err = link_set_up("ksft-vrf"); + if (err) + test_error("Failed to bring up a VRF"); + + err = ip_route_add_vrf(veth_name, TEST_FAMILY, + this_ip_addr, this_ip_dest, test_vrf_tabid); + if (err) + test_error("Failed to add a route to VRF: %d", err); +} + +static void try_accept(const char *tst_name, unsigned int port, + union tcp_addr *md5_addr, uint8_t md5_prefix, + union tcp_addr *ao_addr, uint8_t ao_prefix, + bool set_ao_required, + uint8_t sndid, uint8_t rcvid, uint8_t vrf, + const char *cnt_name, test_cnt cnt_expected, + int needs_tcp_md5, fault_t inj) +{ + struct tcp_counters cnt1, cnt2; + uint64_t before_cnt = 0, after_cnt = 0; /* silence GCC */ + test_cnt poll_cnt = (cnt_expected == TEST_CNT_GOOD) ? 0 : cnt_expected; + int lsk, err, sk = -1; + + if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5)) + return; + + lsk = test_listen_socket(this_ip_addr, port, 1); + + if (md5_addr && test_set_md5(lsk, *md5_addr, md5_prefix, -1, md5_password)) + test_error("setsockopt(TCP_MD5SIG_EXT)"); + + if (ao_addr && test_add_key(lsk, ao_password, + *ao_addr, ao_prefix, sndid, rcvid)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + if (set_ao_required && test_set_ao_flags(lsk, true, false)) + test_error("setsockopt(TCP_AO_INFO)"); + + if (cnt_name) + before_cnt = netstat_get_one(cnt_name, NULL); + if (ao_addr && test_get_tcp_counters(lsk, &cnt1)) + test_error("test_get_tcp_counters()"); + + synchronize_threads(); /* preparations done */ + + err = test_skpair_wait_poll(lsk, 0, poll_cnt, &sk_pair); + synchronize_threads(); /* connect()/accept() timeouts */ + if (err == -ETIMEDOUT) { + sk_pair = err; + if (!fault(TIMEOUT)) + test_fail("%s: timed out for accept()", tst_name); + } else if (err == -EKEYREJECTED) { + if (!fault(KEYREJECT)) + test_fail("%s: key was rejected", tst_name); + } else if (err < 0) { + test_error("test_skpair_wait_poll()"); + } else { + if (fault(TIMEOUT)) + test_fail("%s: ready to accept", tst_name); + + sk = accept(lsk, NULL, NULL); + if (sk < 0) { + test_error("accept()"); + } else { + if (fault(TIMEOUT)) + test_fail("%s: accepted", tst_name); + } + } + + if (ao_addr && test_get_tcp_counters(lsk, &cnt2)) + test_error("test_get_tcp_counters()"); + close(lsk); + + if (!cnt_name) { + test_ok("%s: no counter checks", tst_name); + goto out; + } + + after_cnt = netstat_get_one(cnt_name, NULL); + + if (after_cnt <= before_cnt) { + test_fail("%s: %s counter did not increase: %" PRIu64 " <= %" PRIu64, + tst_name, cnt_name, after_cnt, before_cnt); + } else { + test_ok("%s: counter %s increased %" PRIu64 " => %" PRIu64, + tst_name, cnt_name, before_cnt, after_cnt); + } + if (ao_addr) + test_assert_counters(tst_name, &cnt1, &cnt2, cnt_expected); + +out: + synchronize_threads(); /* test_kill_sk() */ + if (sk >= 0) + test_kill_sk(sk); +} + +static void server_add_routes(void) +{ + int family = TEST_FAMILY; + + synchronize_threads(); /* client_add_ips() */ + + if (ip_route_add(veth_name, family, this_ip_addr, client2)) + test_error("Failed to add route"); + if (ip_route_add(veth_name, family, this_ip_addr, client3)) + test_error("Failed to add route"); +} + +static void server_add_fail_tests(unsigned int *port) +{ + union tcp_addr addr_any = {}; + + try_accept("TCP-AO established: add TCP-MD5 key", (*port)++, NULL, 0, + &addr_any, 0, 0, 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD, + 1, 0); + try_accept("TCP-MD5 established: add TCP-AO key", (*port)++, &addr_any, + 0, NULL, 0, 0, 0, 0, 0, NULL, 0, 1, 0); + try_accept("non-signed established: add TCP-AO key", (*port)++, NULL, 0, + NULL, 0, 0, 0, 0, 0, "CurrEstab", 0, 0, 0); +} + +static void server_vrf_tests(unsigned int *port) +{ + setup_vrfs(); +} + +static void *server_fn(void *arg) +{ + unsigned int port = test_server_port; + union tcp_addr addr_any = {}; + + server_add_routes(); + + try_accept("[server] AO server (INADDR_ANY): AO client", port++, NULL, 0, + &addr_any, 0, 0, 100, 100, 0, "TCPAOGood", + TEST_CNT_GOOD, 0, 0); + try_accept("[server] AO server (INADDR_ANY): MD5 client", port++, NULL, 0, + &addr_any, 0, 0, 100, 100, 0, "TCPMD5Unexpected", + TEST_CNT_NS_MD5_UNEXPECTED, 1, FAULT_TIMEOUT); + try_accept("[server] AO server (INADDR_ANY): no sign client", port++, NULL, 0, + &addr_any, 0, 0, 100, 100, 0, "TCPAORequired", + TEST_CNT_AO_REQUIRED, 0, FAULT_TIMEOUT); + try_accept("[server] AO server (AO_REQUIRED): AO client", port++, NULL, 0, + &this_ip_dest, TEST_PREFIX, true, + 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD, 0, 0); + try_accept("[server] AO server (AO_REQUIRED): unsigned client", port++, NULL, 0, + &this_ip_dest, TEST_PREFIX, true, + 100, 100, 0, "TCPAORequired", + TEST_CNT_AO_REQUIRED, 0, FAULT_TIMEOUT); + + try_accept("[server] MD5 server (INADDR_ANY): AO client", port++, &addr_any, 0, + NULL, 0, 0, 0, 0, 0, "TCPAOKeyNotFound", + TEST_CNT_NS_KEY_NOT_FOUND, 1, FAULT_TIMEOUT); + try_accept("[server] MD5 server (INADDR_ANY): MD5 client", port++, &addr_any, 0, + NULL, 0, 0, 0, 0, 0, NULL, 0, 1, 0); + try_accept("[server] MD5 server (INADDR_ANY): no sign client", port++, &addr_any, + 0, NULL, 0, 0, 0, 0, 0, "TCPMD5NotFound", + TEST_CNT_NS_MD5_NOT_FOUND, 1, FAULT_TIMEOUT); + + try_accept("[server] no sign server: AO client", port++, NULL, 0, + NULL, 0, 0, 0, 0, 0, "TCPAOKeyNotFound", + TEST_CNT_NS_KEY_NOT_FOUND, 0, FAULT_TIMEOUT); + try_accept("[server] no sign server: MD5 client", port++, NULL, 0, + NULL, 0, 0, 0, 0, 0, "TCPMD5Unexpected", + TEST_CNT_NS_MD5_UNEXPECTED, 1, FAULT_TIMEOUT); + try_accept("[server] no sign server: no sign client", port++, NULL, 0, + NULL, 0, 0, 0, 0, 0, "CurrEstab", 0, 0, 0); + + try_accept("[server] AO+MD5 server: AO client (matching)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPAOGood", TEST_CNT_GOOD, 1, 0); + try_accept("[server] AO+MD5 server: AO client (misconfig, matching MD5)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, + 1, FAULT_TIMEOUT); + try_accept("[server] AO+MD5 server: AO client (misconfig, non-matching)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPAOKeyNotFound", TEST_CNT_AO_KEY_NOT_FOUND, + 1, FAULT_TIMEOUT); + try_accept("[server] AO+MD5 server: MD5 client (matching)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, NULL, 0, 1, 0); + try_accept("[server] AO+MD5 server: MD5 client (misconfig, matching AO)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPMD5Unexpected", + TEST_CNT_NS_MD5_UNEXPECTED, 1, FAULT_TIMEOUT); + try_accept("[server] AO+MD5 server: MD5 client (misconfig, non-matching)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPMD5Unexpected", + TEST_CNT_NS_MD5_UNEXPECTED, 1, FAULT_TIMEOUT); + try_accept("[server] AO+MD5 server: no sign client (unmatched)", port++, + &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "CurrEstab", 0, 1, 0); + try_accept("[server] AO+MD5 server: no sign client (misconfig, matching AO)", + port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPAORequired", + TEST_CNT_AO_REQUIRED, 1, FAULT_TIMEOUT); + try_accept("[server] AO+MD5 server: no sign client (misconfig, matching MD5)", + port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, "TCPMD5NotFound", + TEST_CNT_NS_MD5_NOT_FOUND, 1, FAULT_TIMEOUT); + + /* Key rejected by the other side, failing short through skpair */ + try_accept("[server] AO+MD5 server: client with both [TCP-MD5] and TCP-AO keys", + port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, NULL, 0, 1, FAULT_KEYREJECT); + try_accept("[server] AO+MD5 server: client with both TCP-MD5 and [TCP-AO] keys", + port++, &this_ip_dest, TEST_PREFIX, &client2, TEST_PREFIX, 0, + 100, 100, 0, NULL, 0, 1, FAULT_KEYREJECT); + + server_add_fail_tests(&port); + + server_vrf_tests(&port); + + /* client exits */ + synchronize_threads(); + return NULL; +} + +static int client_bind(int sk, union tcp_addr bind_addr) +{ +#ifdef IPV6_TEST + struct sockaddr_in6 addr = { + .sin6_family = AF_INET6, + .sin6_port = 0, + .sin6_addr = bind_addr.a6, + }; +#else + struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_port = 0, + .sin_addr = bind_addr.a4, + }; +#endif + return bind(sk, &addr, sizeof(addr)); +} + +static void try_connect(const char *tst_name, unsigned int port, + union tcp_addr *md5_addr, uint8_t md5_prefix, + union tcp_addr *ao_addr, uint8_t ao_prefix, + uint8_t sndid, uint8_t rcvid, uint8_t vrf, + fault_t inj, int needs_tcp_md5, union tcp_addr *bind_addr) +{ + int sk, ret; + + if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5)) + return; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (bind_addr && client_bind(sk, *bind_addr)) + test_error("bind()"); + + if (md5_addr && test_set_md5(sk, *md5_addr, md5_prefix, -1, md5_password)) + test_error("setsockopt(TCP_MD5SIG_EXT)"); + + if (ao_addr && test_add_key(sk, ao_password, *ao_addr, + ao_prefix, sndid, rcvid)) + test_error("setsockopt(TCP_AO_ADD_KEY)"); + + synchronize_threads(); /* preparations done */ + + ret = test_skpair_connect_poll(sk, this_ip_dest, port, 0, &sk_pair); + synchronize_threads(); /* connect()/accept() timeouts */ + if (ret < 0) { + sk_pair = ret; + if (fault(KEYREJECT) && ret == -EKEYREJECTED) + test_ok("%s: connect() was prevented", tst_name); + else if (ret == -ETIMEDOUT && fault(TIMEOUT)) + test_ok("%s", tst_name); + else if (ret == -ECONNREFUSED && + (fault(TIMEOUT) || fault(KEYREJECT))) + test_ok("%s: refused to connect", tst_name); + else + test_error("%s: connect() returned %d", tst_name, ret); + goto out; + } + + if (fault(TIMEOUT) || fault(KEYREJECT)) + test_fail("%s: connected", tst_name); + else + test_ok("%s: connected", tst_name); + +out: + synchronize_threads(); /* test_kill_sk() */ + if (ret > 0) /* test_skpair_connect_poll() cleans up on failure */ + test_kill_sk(sk); +} + +#define PREINSTALL_MD5_FIRST BIT(0) +#define PREINSTALL_AO BIT(1) +#define POSTINSTALL_AO BIT(2) +#define PREINSTALL_MD5 BIT(3) +#define POSTINSTALL_MD5 BIT(4) + +static int try_add_key_vrf(int sk, union tcp_addr in_addr, uint8_t prefix, + int vrf, uint8_t sndid, uint8_t rcvid, + bool set_ao_required) +{ + uint8_t keyflags = 0; + + if (vrf >= 0) + keyflags |= TCP_AO_KEYF_IFINDEX; + else + vrf = 0; + if (set_ao_required) { + int err = test_set_ao_flags(sk, true, 0); + + if (err) + return err; + } + return test_add_key_vrf(sk, ao_password, keyflags, in_addr, prefix, + (uint8_t)vrf, sndid, rcvid); +} + +static bool test_continue(const char *tst_name, int err, + fault_t inj, bool added_ao) +{ + bool expected_to_fail; + + expected_to_fail = fault(PREINSTALL_AO) && added_ao; + expected_to_fail |= fault(PREINSTALL_MD5) && !added_ao; + + if (!err) { + if (!expected_to_fail) + return true; + test_fail("%s: setsockopt()s were expected to fail", tst_name); + return false; + } + if (err != -EKEYREJECTED || !expected_to_fail) { + test_error("%s: setsockopt(%s) = %d", tst_name, + added_ao ? "TCP_AO_ADD_KEY" : "TCP_MD5SIG_EXT", err); + return false; + } + test_ok("%s: prefailed as expected: %m", tst_name); + return false; +} + +static int open_add(const char *tst_name, unsigned int port, + unsigned int strategy, + union tcp_addr md5_addr, uint8_t md5_prefix, int md5_vrf, + union tcp_addr ao_addr, uint8_t ao_prefix, + int ao_vrf, bool set_ao_required, + uint8_t sndid, uint8_t rcvid, + fault_t inj) +{ + int sk; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) + test_error("socket()"); + + if (client_bind(sk, this_ip_addr)) + test_error("bind()"); + + if (strategy & PREINSTALL_MD5_FIRST) { + if (test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password)) + test_error("setsockopt(TCP_MD5SIG_EXT)"); + } + + if (strategy & PREINSTALL_AO) { + int err = try_add_key_vrf(sk, ao_addr, ao_prefix, ao_vrf, + sndid, rcvid, set_ao_required); + + if (!test_continue(tst_name, err, inj, true)) { + close(sk); + return -1; + } + } + + if (strategy & PREINSTALL_MD5) { + errno = 0; + test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password); + if (!test_continue(tst_name, -errno, inj, false)) { + close(sk); + return -1; + } + } + + return sk; +} + +static void try_to_preadd(const char *tst_name, unsigned int port, + unsigned int strategy, + union tcp_addr md5_addr, uint8_t md5_prefix, + int md5_vrf, + union tcp_addr ao_addr, uint8_t ao_prefix, + int ao_vrf, bool set_ao_required, + uint8_t sndid, uint8_t rcvid, + int needs_tcp_md5, int needs_vrf, fault_t inj) +{ + int sk; + + if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5)) + return; + if (needs_vrf && should_skip_test(tst_name, KCONFIG_NET_VRF)) + return; + + sk = open_add(tst_name, port, strategy, md5_addr, md5_prefix, md5_vrf, + ao_addr, ao_prefix, ao_vrf, set_ao_required, + sndid, rcvid, inj); + if (sk < 0) + return; + + test_ok("%s", tst_name); + close(sk); +} + +static void try_to_add(const char *tst_name, unsigned int port, + unsigned int strategy, + union tcp_addr md5_addr, uint8_t md5_prefix, + int md5_vrf, + union tcp_addr ao_addr, uint8_t ao_prefix, + int ao_vrf, uint8_t sndid, uint8_t rcvid, + int needs_tcp_md5, fault_t inj) +{ + int sk, ret; + + if (needs_tcp_md5 && should_skip_test(tst_name, KCONFIG_TCP_MD5)) + return; + + sk = open_add(tst_name, port, strategy, md5_addr, md5_prefix, md5_vrf, + ao_addr, ao_prefix, ao_vrf, 0, sndid, rcvid, inj); + if (sk < 0) + return; + + synchronize_threads(); /* preparations done */ + + ret = test_skpair_connect_poll(sk, this_ip_dest, port, 0, &sk_pair); + + synchronize_threads(); /* connect()/accept() timeouts */ + if (ret < 0) { + test_error("%s: connect() returned %d", tst_name, ret); + goto out; + } + + if (strategy & POSTINSTALL_MD5) { + if (test_set_md5(sk, md5_addr, md5_prefix, md5_vrf, md5_password)) { + if (fault(POSTINSTALL)) { + test_ok("%s: postfailed as expected", tst_name); + goto out; + } else { + test_error("setsockopt(TCP_MD5SIG_EXT)"); + } + } else if (fault(POSTINSTALL)) { + test_fail("%s: post setsockopt() was expected to fail", tst_name); + goto out; + } + } + + if (strategy & POSTINSTALL_AO) { + if (try_add_key_vrf(sk, ao_addr, ao_prefix, ao_vrf, + sndid, rcvid, 0)) { + if (fault(POSTINSTALL)) { + test_ok("%s: postfailed as expected", tst_name); + goto out; + } else { + test_error("setsockopt(TCP_AO_ADD_KEY)"); + } + } else if (fault(POSTINSTALL)) { + test_fail("%s: post setsockopt() was expected to fail", tst_name); + goto out; + } + } + +out: + synchronize_threads(); /* test_kill_sk() */ + if (ret > 0) /* test_skpair_connect_poll() cleans up on failure */ + test_kill_sk(sk); +} + +static void client_add_ip(union tcp_addr *client, const char *ip) +{ + int err, family = TEST_FAMILY; + + if (inet_pton(family, ip, client) != 1) + test_error("Can't convert ip address %s", ip); + + err = ip_addr_add(veth_name, family, *client, TEST_PREFIX); + if (err) + test_error("Failed to add ip address: %d", err); +} + +static void client_add_ips(void) +{ + client_add_ip(&client2, __TEST_CLIENT_IP(2)); + client_add_ip(&client3, __TEST_CLIENT_IP(3)); + synchronize_threads(); /* server_add_routes() */ +} + +static void client_add_fail_tests(unsigned int *port) +{ + try_to_add("TCP-AO established: add TCP-MD5 key", + (*port)++, POSTINSTALL_MD5 | PREINSTALL_AO, + this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0, + 100, 100, 1, FAULT_POSTINSTALL); + try_to_add("TCP-MD5 established: add TCP-AO key", + (*port)++, PREINSTALL_MD5 | POSTINSTALL_AO, + this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0, + 100, 100, 1, FAULT_POSTINSTALL); + try_to_add("non-signed established: add TCP-AO key", + (*port)++, POSTINSTALL_AO, + this_ip_dest, TEST_PREFIX, -1, this_ip_dest, TEST_PREFIX, 0, + 100, 100, 0, FAULT_POSTINSTALL); + + try_to_add("TCP-AO key intersects with existing TCP-MD5 key", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, this_ip_addr, TEST_PREFIX, -1, + 100, 100, 1, FAULT_PREINSTALL_AO); + try_to_add("TCP-MD5 key intersects with existing TCP-AO key", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, this_ip_addr, TEST_PREFIX, -1, + 100, 100, 1, FAULT_PREINSTALL_MD5); + + try_to_preadd("TCP-MD5 key + TCP-AO required", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, -1, true, + 100, 100, 1, 0, FAULT_PREINSTALL_AO); + try_to_preadd("TCP-AO required on socket + TCP-MD5 key", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, -1, true, + 100, 100, 1, 0, FAULT_PREINSTALL_MD5); +} + +static void client_vrf_tests(unsigned int *port) +{ + setup_vrfs(); + + /* The following restrictions for setsockopt()s are expected: + * + * |--------------|-----------------|-------------|-------------| + * | | MD5 key without | MD5 key | MD5 key | + * | | l3index | l3index=0 | l3index=N | + * |--------------|-----------------|-------------|-------------| + * | TCP-AO key | | | | + * | without | reject | reject | reject | + * | l3index | | | | + * |--------------|-----------------|-------------|-------------| + * | TCP-AO key | | | | + * | l3index=0 | reject | reject | allow | + * |--------------|-----------------|-------------|-------------| + * | TCP-AO key | | | | + * | l3index=N | reject | allow | reject | + * |--------------|-----------------|-------------|-------------| + */ + try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (no l3index)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (no l3index)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (l3index=0)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (no l3index)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + try_to_preadd("VRF: TCP-AO key (no l3index) + TCP-MD5 key (l3index=N)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (no l3index)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + + try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (no l3index)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (l3index=0)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (l3index=0)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (l3index=0)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + try_to_preadd("VRF: TCP-AO key (l3index=0) + TCP-MD5 key (l3index=N)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, 0); + try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (l3index=0)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, 0, 0, 100, 100, + 1, 1, 0); + + try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (no l3index)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, -1, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (no l3index) + TCP-AO key (l3index=N)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, -1, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); + try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (l3index=0)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100, + 1, 1, 0); + try_to_preadd("VRF: TCP-MD5 key (l3index=0) + TCP-AO key (l3index=N)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, 0, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100, + 1, 1, 0); + try_to_preadd("VRF: TCP-AO key (l3index=N) + TCP-MD5 key (l3index=N)", + (*port)++, PREINSTALL_MD5 | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_MD5); + try_to_preadd("VRF: TCP-MD5 key (l3index=N) + TCP-AO key (l3index=N)", + (*port)++, PREINSTALL_MD5_FIRST | PREINSTALL_AO, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, + this_ip_addr, TEST_PREFIX, test_vrf_ifindex, 0, 100, 100, + 1, 1, FAULT_PREINSTALL_AO); +} + +static void *client_fn(void *arg) +{ + unsigned int port = test_server_port; + union tcp_addr addr_any = {}; + + client_add_ips(); + + try_connect("AO server (INADDR_ANY): AO client", port++, NULL, 0, + &addr_any, 0, 100, 100, 0, 0, 0, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_MD5_UNEXPECTED, this_ip_addr, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO server (INADDR_ANY): MD5 client", port++, &addr_any, 0, + NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_AO_REQUIRED, this_ip_addr, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO server (INADDR_ANY): unsigned client", port++, NULL, 0, + NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &this_ip_addr); + try_connect("AO server (AO_REQUIRED): AO client", port++, NULL, 0, + &addr_any, 0, 100, 100, 0, 0, 0, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_AO_REQUIRED, client2, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO server (AO_REQUIRED): unsigned client", port++, NULL, 0, + NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &client2); + + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("MD5 server (INADDR_ANY): AO client", port++, NULL, 0, + &addr_any, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr); + try_connect("MD5 server (INADDR_ANY): MD5 client", port++, &addr_any, 0, + NULL, 0, 100, 100, 0, 0, 1, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_MD5_REQUIRED, this_ip_addr, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("MD5 server (INADDR_ANY): no sign client", port++, NULL, 0, + NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr); + + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("no sign server: AO client", port++, NULL, 0, + &addr_any, 0, 100, 100, 0, FAULT_TIMEOUT, 0, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_MD5_UNEXPECTED, this_ip_addr, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("no sign server: MD5 client", port++, &addr_any, 0, + NULL, 0, 100, 100, 0, FAULT_TIMEOUT, 1, &this_ip_addr); + try_connect("no sign server: no sign client", port++, NULL, 0, + NULL, 0, 100, 100, 0, 0, 0, &this_ip_addr); + + try_connect("AO+MD5 server: AO client (matching)", port++, NULL, 0, + &addr_any, 0, 100, 100, 0, 0, 1, &client2); + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, this_ip_addr, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("AO+MD5 server: AO client (misconfig, matching MD5)", + port++, NULL, 0, &addr_any, 0, 100, 100, 0, + FAULT_TIMEOUT, 1, &this_ip_addr); + trace_ao_event_expect(TCP_AO_KEY_NOT_FOUND, client3, this_ip_dest, + -1, port, 0, 0, 1, 0, 0, 0, 100, 100, -1); + try_connect("AO+MD5 server: AO client (misconfig, non-matching)", + port++, NULL, 0, &addr_any, 0, 100, 100, 0, + FAULT_TIMEOUT, 1, &client3); + try_connect("AO+MD5 server: MD5 client (matching)", port++, &addr_any, 0, + NULL, 0, 100, 100, 0, 0, 1, &this_ip_addr); + trace_hash_event_expect(TCP_HASH_MD5_UNEXPECTED, client2, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO+MD5 server: MD5 client (misconfig, matching AO)", + port++, &addr_any, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT, + 1, &client2); + trace_hash_event_expect(TCP_HASH_MD5_UNEXPECTED, client3, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO+MD5 server: MD5 client (misconfig, non-matching)", + port++, &addr_any, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT, + 1, &client3); + try_connect("AO+MD5 server: no sign client (unmatched)", + port++, NULL, 0, NULL, 0, 100, 100, 0, 0, 1, &client3); + trace_hash_event_expect(TCP_HASH_AO_REQUIRED, client2, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO+MD5 server: no sign client (misconfig, matching AO)", + port++, NULL, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT, + 1, &client2); + trace_hash_event_expect(TCP_HASH_MD5_REQUIRED, this_ip_addr, + this_ip_dest, -1, port, 0, 0, 1, 0, 0, 0); + try_connect("AO+MD5 server: no sign client (misconfig, matching MD5)", + port++, NULL, 0, NULL, 0, 100, 100, 0, FAULT_TIMEOUT, + 1, &this_ip_addr); + + try_connect("AO+MD5 server: client with both [TCP-MD5] and TCP-AO keys", + port++, &this_ip_addr, TEST_PREFIX, + &client2, TEST_PREFIX, 100, 100, 0, FAULT_KEYREJECT, + 1, &this_ip_addr); + try_connect("AO+MD5 server: client with both TCP-MD5 and [TCP-AO] keys", + port++, &this_ip_addr, TEST_PREFIX, + &client2, TEST_PREFIX, 100, 100, 0, FAULT_KEYREJECT, + 1, &client2); + + client_add_fail_tests(&port); + client_vrf_tests(&port); + + return NULL; +} + +int main(int argc, char *argv[]) +{ + test_init(73, server_fn, client_fn); + return 0; +} |