diff options
Diffstat (limited to 'tools/testing/selftests/net/tcp_ao/lib')
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/aolib.h | 832 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c | 556 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/ftrace.c | 543 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/kconfig.c | 157 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/netlink.c | 413 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/proc.c | 273 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/repair.c | 254 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/setup.c | 368 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/sock.c | 730 | ||||
-rw-r--r-- | tools/testing/selftests/net/tcp_ao/lib/utils.c | 56 |
10 files changed, 4182 insertions, 0 deletions
diff --git a/tools/testing/selftests/net/tcp_ao/lib/aolib.h b/tools/testing/selftests/net/tcp_ao/lib/aolib.h new file mode 100644 index 000000000000..ebb2899c12fe --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/aolib.h @@ -0,0 +1,832 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * TCP-AO selftest library. Provides helpers to unshare network + * namespaces, create veth, assign ip addresses, set routes, + * manipulate socket options, read network counter and etc. + * Author: Dmitry Safonov <dima@arista.com> + */ +#ifndef _AOLIB_H_ +#define _AOLIB_H_ + +#include <arpa/inet.h> +#include <errno.h> +#include <linux/snmp.h> +#include <linux/tcp.h> +#include <netinet/in.h> +#include <stdarg.h> +#include <stdbool.h> +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include <sys/syscall.h> +#include <unistd.h> + +#include "../../../../../include/linux/stringify.h" +#include "../../../../../include/linux/bits.h" + +#ifndef SOL_TCP +/* can't include <netinet/tcp.h> as including <linux/tcp.h> */ +# define SOL_TCP 6 /* TCP level */ +#endif + +/* Working around ksft, see the comment in lib/setup.c */ +extern void __test_msg(const char *buf); +extern void __test_ok(const char *buf); +extern void __test_fail(const char *buf); +extern void __test_xfail(const char *buf); +extern void __test_error(const char *buf); +extern void __test_skip(const char *buf); + +static inline char *test_snprintf(const char *fmt, va_list vargs) +{ + char *ret = NULL; + size_t size = 0; + va_list tmp; + int n = 0; + + va_copy(tmp, vargs); + n = vsnprintf(ret, size, fmt, tmp); + va_end(tmp); + if (n < 0) + return NULL; + + size = n + 1; + ret = malloc(size); + if (!ret) + return NULL; + + n = vsnprintf(ret, size, fmt, vargs); + if (n < 0 || n > size - 1) { + free(ret); + return NULL; + } + return ret; +} + +static __printf(1, 2) inline char *test_sprintf(const char *fmt, ...) +{ + va_list vargs; + char *ret; + + va_start(vargs, fmt); + ret = test_snprintf(fmt, vargs); + va_end(vargs); + + return ret; +} + +static __printf(2, 3) inline void __test_print(void (*fn)(const char *), + const char *fmt, ...) +{ + va_list vargs; + char *msg; + + va_start(vargs, fmt); + msg = test_snprintf(fmt, vargs); + va_end(vargs); + + if (!msg) + return; + + fn(msg); + free(msg); +} + +#define test_print(fmt, ...) \ + __test_print(__test_msg, "%ld[%s:%u] " fmt "\n", \ + syscall(SYS_gettid), \ + __FILE__, __LINE__, ##__VA_ARGS__) + +#define test_ok(fmt, ...) \ + __test_print(__test_ok, fmt "\n", ##__VA_ARGS__) +#define test_skip(fmt, ...) \ + __test_print(__test_skip, fmt "\n", ##__VA_ARGS__) +#define test_xfail(fmt, ...) \ + __test_print(__test_xfail, fmt "\n", ##__VA_ARGS__) + +#define test_fail(fmt, ...) \ +do { \ + if (errno) \ + __test_print(__test_fail, fmt ": %m\n", ##__VA_ARGS__); \ + else \ + __test_print(__test_fail, fmt "\n", ##__VA_ARGS__); \ + test_failed(); \ +} while (0) + +#define KSFT_FAIL 1 +#define test_error(fmt, ...) \ +do { \ + if (errno) \ + __test_print(__test_error, "%ld[%s:%u] " fmt ": %m\n", \ + syscall(SYS_gettid), __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + else \ + __test_print(__test_error, "%ld[%s:%u] " fmt "\n", \ + syscall(SYS_gettid), __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + exit(KSFT_FAIL); \ +} while (0) + +enum test_fault { + FAULT_TIMEOUT = 1, + FAULT_KEYREJECT, + FAULT_PREINSTALL_AO, + FAULT_PREINSTALL_MD5, + FAULT_POSTINSTALL, + FAULT_BUSY, + FAULT_CURRNEXT, + FAULT_FIXME, +}; +typedef enum test_fault fault_t; + +enum test_needs_kconfig { + KCONFIG_NET_NS = 0, /* required */ + KCONFIG_VETH, /* required */ + KCONFIG_TCP_AO, /* required */ + KCONFIG_TCP_MD5, /* optional, for TCP-MD5 features */ + KCONFIG_NET_VRF, /* optional, for L3/VRF testing */ + KCONFIG_FTRACE, /* optional, for tracepoints checks */ + __KCONFIG_LAST__ +}; +extern bool kernel_config_has(enum test_needs_kconfig k); +extern const char *tests_skip_reason[__KCONFIG_LAST__]; +static inline bool should_skip_test(const char *tst_name, + enum test_needs_kconfig k) +{ + if (kernel_config_has(k)) + return false; + test_skip("%s: %s", tst_name, tests_skip_reason[k]); + return true; +} + +union tcp_addr { + struct in_addr a4; + struct in6_addr a6; +}; + +typedef void *(*thread_fn)(void *); +extern void test_failed(void); +extern void __test_init(unsigned int ntests, int family, unsigned int prefix, + union tcp_addr addr1, union tcp_addr addr2, + thread_fn peer1, thread_fn peer2); + +static inline void test_init2(unsigned int ntests, + thread_fn peer1, thread_fn peer2, + int family, unsigned int prefix, + const char *addr1, const char *addr2) +{ + union tcp_addr taddr1, taddr2; + + if (inet_pton(family, addr1, &taddr1) != 1) + test_error("Can't convert ip address %s", addr1); + if (inet_pton(family, addr2, &taddr2) != 1) + test_error("Can't convert ip address %s", addr2); + + __test_init(ntests, family, prefix, taddr1, taddr2, peer1, peer2); +} +extern void test_add_destructor(void (*d)(void)); +extern void test_init_ftrace(int nsfd1, int nsfd2); +extern int test_setup_tracing(void); + +/* To adjust optmem socket limit, approximately estimate a number, + * that is bigger than sizeof(struct tcp_ao_key). + */ +#define KERNEL_TCP_AO_KEY_SZ_ROUND_UP 300 + +extern void test_set_optmem(size_t value); +extern size_t test_get_optmem(void); + +extern const struct sockaddr_in6 addr_any6; +extern const struct sockaddr_in addr_any4; + +#ifdef IPV6_TEST +# define __TEST_CLIENT_IP(n) ("2001:db8:" __stringify(n) "::1") +# define TEST_CLIENT_IP __TEST_CLIENT_IP(1) +# define TEST_WRONG_IP "2001:db8:253::1" +# define TEST_SERVER_IP "2001:db8:254::1" +# define TEST_NETWORK "2001::" +# define TEST_PREFIX 128 +# define TEST_FAMILY AF_INET6 +# define SOCKADDR_ANY addr_any6 +# define sockaddr_af struct sockaddr_in6 +#else +# define __TEST_CLIENT_IP(n) ("10.0." __stringify(n) ".1") +# define TEST_CLIENT_IP __TEST_CLIENT_IP(1) +# define TEST_WRONG_IP "10.0.253.1" +# define TEST_SERVER_IP "10.0.254.1" +# define TEST_NETWORK "10.0.0.0" +# define TEST_PREFIX 32 +# define TEST_FAMILY AF_INET +# define SOCKADDR_ANY addr_any4 +# define sockaddr_af struct sockaddr_in +#endif + +static inline union tcp_addr gen_tcp_addr(union tcp_addr net, size_t n) +{ + union tcp_addr ret = net; + +#ifdef IPV6_TEST + ret.a6.s6_addr32[3] = htonl(n & (BIT(32) - 1)); + ret.a6.s6_addr32[2] = htonl((n >> 32) & (BIT(32) - 1)); +#else + ret.a4.s_addr = htonl(ntohl(net.a4.s_addr) + n); +#endif + + return ret; +} + +static inline void tcp_addr_to_sockaddr_in(void *dest, + const union tcp_addr *src, + unsigned int port) +{ + sockaddr_af *out = dest; + + memset(out, 0, sizeof(*out)); +#ifdef IPV6_TEST + out->sin6_family = AF_INET6; + out->sin6_port = port; + out->sin6_addr = src->a6; +#else + out->sin_family = AF_INET; + out->sin_port = port; + out->sin_addr = src->a4; +#endif +} + +static inline void test_init(unsigned int ntests, + thread_fn peer1, thread_fn peer2) +{ + test_init2(ntests, peer1, peer2, TEST_FAMILY, TEST_PREFIX, + TEST_SERVER_IP, TEST_CLIENT_IP); +} +extern void synchronize_threads(void); +extern void switch_ns(int fd); +extern int switch_save_ns(int fd); +extern void switch_close_ns(int fd); + +extern __thread union tcp_addr this_ip_addr; +extern __thread union tcp_addr this_ip_dest; +extern int test_family; + +extern void randomize_buffer(void *buf, size_t buflen); +extern __printf(3, 4) int test_echo(const char *fname, bool append, + const char *fmt, ...); + +extern int open_netns(void); +extern int unshare_open_netns(void); +extern const char veth_name[]; +extern int add_veth(const char *name, int nsfda, int nsfdb); +extern int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd); +extern int ip_addr_add(const char *intf, int family, + union tcp_addr addr, uint8_t prefix); +extern int ip_route_add(const char *intf, int family, + union tcp_addr src, union tcp_addr dst); +extern int ip_route_add_vrf(const char *intf, int family, + union tcp_addr src, union tcp_addr dst, + uint8_t vrf); +extern int link_set_up(const char *intf); + +extern const unsigned int test_server_port; +extern int test_wait_fd(int sk, time_t sec, bool write); +extern int __test_connect_socket(int sk, const char *device, + void *addr, size_t addr_sz, bool async); +extern int __test_listen_socket(int backlog, void *addr, size_t addr_sz); + +static inline int test_listen_socket(const union tcp_addr taddr, + unsigned int port, int backlog) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return __test_listen_socket(backlog, (void *)&addr, sizeof(addr)); +} + +/* + * In order for selftests to work under CONFIG_CRYPTO_FIPS=y, + * the password should be loger than 14 bytes, see hmac_setkey() + */ +#define TEST_TCP_AO_MINKEYLEN 14 +#define DEFAULT_TEST_PASSWORD "In this hour, I do not believe that any darkness will endure." + +#ifndef DEFAULT_TEST_ALGO +#define DEFAULT_TEST_ALGO "cmac(aes128)" +#endif + +#ifdef IPV6_TEST +#define DEFAULT_TEST_PREFIX 128 +#else +#define DEFAULT_TEST_PREFIX 32 +#endif + +/* + * Timeout on syscalls where failure is not expected. + * You may want to rise it if the test machine is very busy. + */ +#ifndef TEST_TIMEOUT_SEC +#define TEST_TIMEOUT_SEC 5 +#endif + +/* + * Timeout on connect() where a failure is expected. + * If set to 0 - kernel will try to retransmit SYN number of times, set in + * /proc/sys/net/ipv4/tcp_syn_retries + * By default set to 1 to make tests pass faster on non-busy machine. + * [in process of removal, don't use in new tests] + */ +#ifndef TEST_RETRANSMIT_SEC +#define TEST_RETRANSMIT_SEC 1 +#endif + +static inline int _test_connect_socket(int sk, const union tcp_addr taddr, + unsigned int port, bool async) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return __test_connect_socket(sk, veth_name, + (void *)&addr, sizeof(addr), async); +} + +static inline int test_connect_socket(int sk, const union tcp_addr taddr, + unsigned int port) +{ + return _test_connect_socket(sk, taddr, port, false); +} + +extern int __test_set_md5(int sk, void *addr, size_t addr_sz, + uint8_t prefix, int vrf, const char *password); +static inline int test_set_md5(int sk, const union tcp_addr in_addr, + uint8_t prefix, int vrf, const char *password) +{ + sockaddr_af addr; + + if (prefix > DEFAULT_TEST_PREFIX) + prefix = DEFAULT_TEST_PREFIX; + + tcp_addr_to_sockaddr_in(&addr, &in_addr, 0); + return __test_set_md5(sk, (void *)&addr, sizeof(addr), + prefix, vrf, password); +} + +extern int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, + void *addr, size_t addr_sz, bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid, uint8_t maclen, + uint8_t keyflags, uint8_t keylen, const char *key); + +static inline int test_prepare_key(struct tcp_ao_add *ao, + const char *alg, union tcp_addr taddr, + bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid, uint8_t maclen, + uint8_t keyflags, uint8_t keylen, const char *key) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, 0); + return test_prepare_key_sockaddr(ao, alg, (void *)&addr, sizeof(addr), + set_current, set_rnext, prefix, vrf, sndid, rcvid, + maclen, keyflags, keylen, key); +} + +static inline int test_prepare_def_key(struct tcp_ao_add *ao, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, uint8_t vrf, + uint8_t sndid, uint8_t rcvid) +{ + if (prefix > DEFAULT_TEST_PREFIX) + prefix = DEFAULT_TEST_PREFIX; + + return test_prepare_key(ao, DEFAULT_TEST_ALGO, in_addr, false, false, + prefix, vrf, sndid, rcvid, 0, keyflags, + strlen(key), key); +} + +extern int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, + void *addr, size_t addr_sz, + uint8_t prefix, uint8_t sndid, uint8_t rcvid); +extern int test_get_ao_info(int sk, struct tcp_ao_info_opt *out); +extern int test_set_ao_info(int sk, struct tcp_ao_info_opt *in); +extern int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, + const struct tcp_ao_getsockopt *b); +extern int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, + const struct tcp_ao_info_opt *b); + +static inline int test_verify_socket_key(int sk, struct tcp_ao_add *key) +{ + struct tcp_ao_getsockopt key2 = {}; + int err; + + err = test_get_one_ao(sk, &key2, &key->addr, sizeof(key->addr), + key->prefix, key->sndid, key->rcvid); + if (err) + return err; + + return test_cmp_getsockopt_setsockopt(key, &key2); +} + +static inline int test_add_key_vrf(int sk, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + uint8_t vrf, uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix, + vrf, sndid, rcvid); + if (err) + return err; + + err = setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)); + if (err < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +static inline int test_add_key(int sk, const char *key, + union tcp_addr in_addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + return test_add_key_vrf(sk, key, 0, in_addr, prefix, 0, sndid, rcvid); +} + +static inline int test_verify_socket_ao(int sk, struct tcp_ao_info_opt *ao) +{ + struct tcp_ao_info_opt ao2 = {}; + int err; + + err = test_get_ao_info(sk, &ao2); + if (err) + return err; + + return test_cmp_getsockopt_setsockopt_ao(ao, &ao2); +} + +static inline int test_set_ao_flags(int sk, bool ao_required, bool accept_icmps) +{ + struct tcp_ao_info_opt ao = {}; + int err; + + err = test_get_ao_info(sk, &ao); + /* Maybe ao_info wasn't allocated yet */ + if (err && err != -ENOENT) + return err; + + ao.ao_required = !!ao_required; + ao.accept_icmps = !!accept_icmps; + err = test_set_ao_info(sk, &ao); + if (err) + return err; + + return test_verify_socket_ao(sk, &ao); +} + +extern ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec); +extern int test_client_verify(int sk, const size_t msg_len, const size_t nr); + +struct tcp_ao_key_counters { + uint8_t sndid; + uint8_t rcvid; + uint64_t pkt_good; + uint64_t pkt_bad; +}; + +struct tcp_ao_counters { + /* per-netns */ + uint64_t netns_ao_good; + uint64_t netns_ao_bad; + uint64_t netns_ao_key_not_found; + uint64_t netns_ao_required; + uint64_t netns_ao_dropped_icmp; + /* per-socket */ + uint64_t ao_info_pkt_good; + uint64_t ao_info_pkt_bad; + uint64_t ao_info_pkt_key_not_found; + uint64_t ao_info_pkt_ao_required; + uint64_t ao_info_pkt_dropped_icmp; + /* per-key */ + size_t nr_keys; + struct tcp_ao_key_counters *key_cnts; +}; + +struct tcp_counters { + struct tcp_ao_counters ao; + uint64_t netns_md5_notfound; + uint64_t netns_md5_unexpected; + uint64_t netns_md5_failure; +}; + +extern int test_get_tcp_counters(int sk, struct tcp_counters *out); + +#define TEST_CNT_KEY_GOOD BIT(0) +#define TEST_CNT_KEY_BAD BIT(1) +#define TEST_CNT_SOCK_GOOD BIT(2) +#define TEST_CNT_SOCK_BAD BIT(3) +#define TEST_CNT_SOCK_KEY_NOT_FOUND BIT(4) +#define TEST_CNT_SOCK_AO_REQUIRED BIT(5) +#define TEST_CNT_SOCK_DROPPED_ICMP BIT(6) +#define TEST_CNT_NS_GOOD BIT(7) +#define TEST_CNT_NS_BAD BIT(8) +#define TEST_CNT_NS_KEY_NOT_FOUND BIT(9) +#define TEST_CNT_NS_AO_REQUIRED BIT(10) +#define TEST_CNT_NS_DROPPED_ICMP BIT(11) +#define TEST_CNT_NS_MD5_NOT_FOUND BIT(12) +#define TEST_CNT_NS_MD5_UNEXPECTED BIT(13) +#define TEST_CNT_NS_MD5_FAILURE BIT(14) +typedef uint16_t test_cnt; + +#define _for_each_counter(f) \ +do { \ + /* per-netns */ \ + f(ao.netns_ao_good, TEST_CNT_NS_GOOD); \ + f(ao.netns_ao_bad, TEST_CNT_NS_BAD); \ + f(ao.netns_ao_key_not_found, TEST_CNT_NS_KEY_NOT_FOUND); \ + f(ao.netns_ao_required, TEST_CNT_NS_AO_REQUIRED); \ + f(ao.netns_ao_dropped_icmp, TEST_CNT_NS_DROPPED_ICMP); \ + /* per-socket */ \ + f(ao.ao_info_pkt_good, TEST_CNT_SOCK_GOOD); \ + f(ao.ao_info_pkt_bad, TEST_CNT_SOCK_BAD); \ + f(ao.ao_info_pkt_key_not_found, TEST_CNT_SOCK_KEY_NOT_FOUND); \ + f(ao.ao_info_pkt_ao_required, TEST_CNT_SOCK_AO_REQUIRED); \ + f(ao.ao_info_pkt_dropped_icmp, TEST_CNT_SOCK_DROPPED_ICMP); \ + /* non-AO */ \ + f(netns_md5_notfound, TEST_CNT_NS_MD5_NOT_FOUND); \ + f(netns_md5_unexpected, TEST_CNT_NS_MD5_UNEXPECTED); \ + f(netns_md5_failure, TEST_CNT_NS_MD5_FAILURE); \ +} while (0) + +#define TEST_CNT_AO_GOOD (TEST_CNT_SOCK_GOOD | TEST_CNT_NS_GOOD) +#define TEST_CNT_AO_BAD (TEST_CNT_SOCK_BAD | TEST_CNT_NS_BAD) +#define TEST_CNT_AO_KEY_NOT_FOUND (TEST_CNT_SOCK_KEY_NOT_FOUND | \ + TEST_CNT_NS_KEY_NOT_FOUND) +#define TEST_CNT_AO_REQUIRED (TEST_CNT_SOCK_AO_REQUIRED | \ + TEST_CNT_NS_AO_REQUIRED) +#define TEST_CNT_AO_DROPPED_ICMP (TEST_CNT_SOCK_DROPPED_ICMP | \ + TEST_CNT_NS_DROPPED_ICMP) +#define TEST_CNT_GOOD (TEST_CNT_KEY_GOOD | TEST_CNT_AO_GOOD) +#define TEST_CNT_BAD (TEST_CNT_KEY_BAD | TEST_CNT_AO_BAD) + +extern test_cnt test_cmp_counters(struct tcp_counters *before, + struct tcp_counters *after); +extern int test_assert_counters_sk(const char *tst_name, + struct tcp_counters *before, struct tcp_counters *after, + test_cnt expected); +extern int test_assert_counters_key(const char *tst_name, + struct tcp_ao_counters *before, struct tcp_ao_counters *after, + test_cnt expected, int sndid, int rcvid); +extern void test_tcp_counters_free(struct tcp_counters *cnts); + +/* + * Polling for netns and socket counters during select()/connect() and also + * client/server messaging. Instead of constant timeout on underlying select(), + * check the counters and return early. This allows to pass the tests where + * timeout is expected without waiting for that fixing timeout (tests speed-up). + * Previously shorter timeouts were used for tests expecting to time out, + * but that leaded to sporadic false positives on counter checks failures, + * as one second timeouts aren't enough for TCP retransmit. + * + * Two sides of the socketpair (client/server) should synchronize failures + * using a shared variable *err, so that they can detect the other side's + * failure. + */ +extern int test_skpair_wait_poll(int sk, bool write, test_cnt cond, + volatile int *err); +extern int _test_skpair_connect_poll(int sk, const char *device, + void *addr, size_t addr_sz, + test_cnt cond, volatile int *err); +static inline int test_skpair_connect_poll(int sk, const union tcp_addr taddr, + unsigned int port, + test_cnt cond, volatile int *err) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &taddr, htons(port)); + return _test_skpair_connect_poll(sk, veth_name, + (void *)&addr, sizeof(addr), cond, err); +} + +extern int test_skpair_client(int sk, const size_t msg_len, const size_t nr, + test_cnt cond, volatile int *err); +extern int test_skpair_server(int sk, ssize_t quota, + test_cnt cond, volatile int *err); + +/* + * Frees buffers allocated in test_get_tcp_counters(). + * The function doesn't expect new keys or keys removed between calls + * to test_get_tcp_counters(). Check key counters manually if they + * may change. + */ +static inline int test_assert_counters(const char *tst_name, + struct tcp_counters *before, + struct tcp_counters *after, + test_cnt expected) +{ + int ret; + + ret = test_assert_counters_sk(tst_name, before, after, expected); + if (ret) + goto out; + ret = test_assert_counters_key(tst_name, &before->ao, &after->ao, + expected, -1, -1); +out: + test_tcp_counters_free(before); + test_tcp_counters_free(after); + return ret; +} + +struct netstat; +extern struct netstat *netstat_read(void); +extern void netstat_free(struct netstat *ns); +extern void netstat_print_diff(struct netstat *nsa, struct netstat *nsb); +extern uint64_t netstat_get(struct netstat *ns, + const char *name, bool *not_found); + +static inline uint64_t netstat_get_one(const char *name, bool *not_found) +{ + struct netstat *ns = netstat_read(); + uint64_t ret; + + ret = netstat_get(ns, name, not_found); + + netstat_free(ns); + return ret; +} + +struct tcp_sock_queue { + uint32_t seq; + void *buf; +}; + +struct tcp_sock_state { + struct tcp_info info; + struct tcp_repair_window trw; + struct tcp_sock_queue out; + int outq_len; /* output queue size (not sent + not acked) */ + int outq_nsd_len; /* output queue size (not sent only) */ + struct tcp_sock_queue in; + int inq_len; + int mss; + int timestamp; +}; + +extern void __test_sock_checkpoint(int sk, struct tcp_sock_state *state, + void *addr, size_t addr_size); +static inline void test_sock_checkpoint(int sk, struct tcp_sock_state *state, + sockaddr_af *saddr) +{ + __test_sock_checkpoint(sk, state, saddr, sizeof(*saddr)); +} +extern void test_ao_checkpoint(int sk, struct tcp_ao_repair *state); +extern void __test_sock_restore(int sk, const char *device, + struct tcp_sock_state *state, + void *saddr, void *daddr, size_t addr_size); +static inline void test_sock_restore(int sk, struct tcp_sock_state *state, + sockaddr_af *saddr, + const union tcp_addr daddr, + unsigned int dport) +{ + sockaddr_af addr; + + tcp_addr_to_sockaddr_in(&addr, &daddr, htons(dport)); + __test_sock_restore(sk, veth_name, state, saddr, &addr, sizeof(addr)); +} +extern void test_ao_restore(int sk, struct tcp_ao_repair *state); +extern void test_sock_state_free(struct tcp_sock_state *state); +extern void test_enable_repair(int sk); +extern void test_disable_repair(int sk); +extern void test_kill_sk(int sk); +static inline int test_add_repaired_key(int sk, + const char *key, uint8_t keyflags, + union tcp_addr in_addr, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_add tmp = {}; + int err; + + err = test_prepare_def_key(&tmp, key, keyflags, in_addr, prefix, + 0, sndid, rcvid); + if (err) + return err; + + tmp.set_current = 1; + tmp.set_rnext = 1; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) + return -errno; + + return test_verify_socket_key(sk, &tmp); +} + +#define DEFAULT_FTRACE_BUFFER_KB 10000 +#define DEFAULT_TRACER_LINES_ARR 200 +struct test_ftracer; +extern uint64_t ns_cookie1, ns_cookie2; + +enum ftracer_op { + FTRACER_LINE_DISCARD = 0, + FTRACER_LINE_PRESERVE, + FTRACER_EXIT, +}; + +extern struct test_ftracer *create_ftracer(const char *name, + enum ftracer_op (*process_line)(const char *line), + void (*destructor)(struct test_ftracer *tracer), + bool (*expecting_more)(void), + size_t lines_buf_sz, size_t buffer_size_kb); +extern int setup_trace_event(struct test_ftracer *tracer, + const char *event, const char *filter); +extern void destroy_ftracer(struct test_ftracer *tracer); +extern const size_t tracer_get_savedlines_nr(struct test_ftracer *tracer); +extern const char **tracer_get_savedlines(struct test_ftracer *tracer); + +enum trace_events { + /* TCP_HASH_EVENT */ + TCP_HASH_BAD_HEADER = 0, + TCP_HASH_MD5_REQUIRED, + TCP_HASH_MD5_UNEXPECTED, + TCP_HASH_MD5_MISMATCH, + TCP_HASH_AO_REQUIRED, + /* TCP_AO_EVENT */ + TCP_AO_HANDSHAKE_FAILURE, + TCP_AO_WRONG_MACLEN, + TCP_AO_MISMATCH, + TCP_AO_KEY_NOT_FOUND, + TCP_AO_RNEXT_REQUEST, + /* TCP_AO_EVENT_SK */ + TCP_AO_SYNACK_NO_KEY, + /* TCP_AO_EVENT_SNE */ + TCP_AO_SND_SNE_UPDATE, + TCP_AO_RCV_SNE_UPDATE, + __MAX_TRACE_EVENTS +}; + +extern int __trace_event_expect(enum trace_events type, int family, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen, int sne); + +static inline void trace_hash_event_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, L3index, + fin, syn, rst, psh, ack, + -1, -1, -1, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, L3index, + fin, syn, rst, psh, ack, + keyid, rnext, maclen, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_sk_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, + int keyid, int rnext) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, -1, + -1, -1, -1, -1, -1, + keyid, rnext, -1, -1); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +static inline void trace_ao_event_sne_expect(enum trace_events type, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int sne) +{ + int err; + + err = __trace_event_expect(type, TEST_FAMILY, src, dst, + src_port, dst_port, -1, + -1, -1, -1, -1, -1, + -1, -1, -1, sne); + if (err) + test_error("Couldn't add a trace event: %d", err); +} + +extern int setup_aolib_ftracer(void); + +#endif /* _AOLIB_H_ */ diff --git a/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c b/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c new file mode 100644 index 000000000000..27403f875054 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/ftrace-tcp.c @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include "aolib.h" + +static const char *trace_event_names[__MAX_TRACE_EVENTS] = { + /* TCP_HASH_EVENT */ + "tcp_hash_bad_header", + "tcp_hash_md5_required", + "tcp_hash_md5_unexpected", + "tcp_hash_md5_mismatch", + "tcp_hash_ao_required", + /* TCP_AO_EVENT */ + "tcp_ao_handshake_failure", + "tcp_ao_wrong_maclen", + "tcp_ao_mismatch", + "tcp_ao_key_not_found", + "tcp_ao_rnext_request", + /* TCP_AO_EVENT_SK */ + "tcp_ao_synack_no_key", + /* TCP_AO_EVENT_SNE */ + "tcp_ao_snd_sne_update", + "tcp_ao_rcv_sne_update" +}; + +struct expected_trace_point { + /* required */ + enum trace_events type; + int family; + union tcp_addr src; + union tcp_addr dst; + + /* optional */ + int src_port; + int dst_port; + int L3index; + + int fin; + int syn; + int rst; + int psh; + int ack; + + int keyid; + int rnext; + int maclen; + int sne; + + size_t matched; +}; + +static struct expected_trace_point *exp_tps; +static size_t exp_tps_nr; +static size_t exp_tps_size; +static pthread_mutex_t exp_tps_mutex = PTHREAD_MUTEX_INITIALIZER; + +int __trace_event_expect(enum trace_events type, int family, + union tcp_addr src, union tcp_addr dst, + int src_port, int dst_port, int L3index, + int fin, int syn, int rst, int psh, int ack, + int keyid, int rnext, int maclen, int sne) +{ + struct expected_trace_point new_tp = { + .type = type, + .family = family, + .src = src, + .dst = dst, + .src_port = src_port, + .dst_port = dst_port, + .L3index = L3index, + .fin = fin, + .syn = syn, + .rst = rst, + .psh = psh, + .ack = ack, + .keyid = keyid, + .rnext = rnext, + .maclen = maclen, + .sne = sne, + .matched = 0, + }; + int ret = 0; + + if (!kernel_config_has(KCONFIG_FTRACE)) + return 0; + + pthread_mutex_lock(&exp_tps_mutex); + if (exp_tps_nr == exp_tps_size) { + struct expected_trace_point *tmp; + + if (exp_tps_size == 0) + exp_tps_size = 10; + else + exp_tps_size = exp_tps_size * 1.6; + + tmp = reallocarray(exp_tps, exp_tps_size, sizeof(exp_tps[0])); + if (!tmp) { + ret = -ENOMEM; + goto out; + } + exp_tps = tmp; + } + exp_tps[exp_tps_nr] = new_tp; + exp_tps_nr++; +out: + pthread_mutex_unlock(&exp_tps_mutex); + return ret; +} + +static void free_expected_events(void) +{ + /* We're from the process destructor - not taking the mutex */ + exp_tps_size = 0; + exp_tps = NULL; + free(exp_tps); +} + +struct trace_point { + int family; + union tcp_addr src; + union tcp_addr dst; + unsigned int src_port; + unsigned int dst_port; + int L3index; + unsigned int fin:1, + syn:1, + rst:1, + psh:1, + ack:1; + + unsigned int keyid; + unsigned int rnext; + unsigned int maclen; + + unsigned int sne; +}; + +static bool lookup_expected_event(int event_type, struct trace_point *e) +{ + size_t i; + + pthread_mutex_lock(&exp_tps_mutex); + for (i = 0; i < exp_tps_nr; i++) { + struct expected_trace_point *p = &exp_tps[i]; + size_t sk_size; + + if (p->type != event_type) + continue; + if (p->family != e->family) + continue; + if (p->family == AF_INET) + sk_size = sizeof(p->src.a4); + else + sk_size = sizeof(p->src.a6); + if (memcmp(&p->src, &e->src, sk_size)) + continue; + if (memcmp(&p->dst, &e->dst, sk_size)) + continue; + if (p->src_port >= 0 && p->src_port != e->src_port) + continue; + if (p->dst_port >= 0 && p->dst_port != e->dst_port) + continue; + if (p->L3index >= 0 && p->L3index != e->L3index) + continue; + + if (p->fin >= 0 && p->fin != e->fin) + continue; + if (p->syn >= 0 && p->syn != e->syn) + continue; + if (p->rst >= 0 && p->rst != e->rst) + continue; + if (p->psh >= 0 && p->psh != e->psh) + continue; + if (p->ack >= 0 && p->ack != e->ack) + continue; + + if (p->keyid >= 0 && p->keyid != e->keyid) + continue; + if (p->rnext >= 0 && p->rnext != e->rnext) + continue; + if (p->maclen >= 0 && p->maclen != e->maclen) + continue; + if (p->sne >= 0 && p->sne != e->sne) + continue; + p->matched++; + pthread_mutex_unlock(&exp_tps_mutex); + return true; + } + pthread_mutex_unlock(&exp_tps_mutex); + return false; +} + +static int check_event_type(const char *line) +{ + size_t i; + + /* + * This should have been a set or hashmap, but it's a selftest, + * so... KISS. + */ + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + if (!strncmp(trace_event_names[i], line, strlen(trace_event_names[i]))) + return i; + } + return -1; +} + +static bool event_has_flags(enum trace_events event) +{ + switch (event) { + case TCP_HASH_BAD_HEADER: + case TCP_HASH_MD5_REQUIRED: + case TCP_HASH_MD5_UNEXPECTED: + case TCP_HASH_MD5_MISMATCH: + case TCP_HASH_AO_REQUIRED: + case TCP_AO_HANDSHAKE_FAILURE: + case TCP_AO_WRONG_MACLEN: + case TCP_AO_MISMATCH: + case TCP_AO_KEY_NOT_FOUND: + case TCP_AO_RNEXT_REQUEST: + return true; + default: + return false; + } +} + +static int tracer_ip_split(int family, char *src, char **addr, char **port) +{ + char *p; + + if (family == AF_INET) { + /* fomat is <addr>:port, i.e.: 10.0.254.1:7015 */ + *addr = src; + p = strchr(src, ':'); + if (!p) { + test_print("Couldn't parse trace event addr:port %s", src); + return -EINVAL; + } + *p++ = '\0'; + *port = p; + return 0; + } + if (family != AF_INET6) + return -EAFNOSUPPORT; + + /* format is [<addr>]:port, i.e.: [2001:db8:254::1]:7013 */ + *addr = strchr(src, '['); + p = strchr(src, ']'); + + if (!p || !*addr) { + test_print("Couldn't parse trace event [addr]:port %s", src); + return -EINVAL; + } + + *addr = *addr + 1; /* '[' */ + *p++ = '\0'; /* ']' */ + if (*p != ':') { + test_print("Couldn't parse trace event :port %s", p); + return -EINVAL; + } + *p++ = '\0'; /* ':' */ + *port = p; + return 0; +} + +static int tracer_scan_address(int family, char *src, + union tcp_addr *dst, unsigned int *port) +{ + char *addr, *port_str; + int ret; + + ret = tracer_ip_split(family, src, &addr, &port_str); + if (ret) + return ret; + + if (inet_pton(family, addr, dst) != 1) { + test_print("Couldn't parse trace event addr %s", addr); + return -EINVAL; + } + errno = 0; + *port = (unsigned int)strtoul(port_str, NULL, 10); + if (errno != 0) { + test_print("Couldn't parse trace event port %s", port_str); + return -errno; + } + return 0; +} + +static int tracer_scan_event(const char *line, enum trace_events event, + struct trace_point *out) +{ + char *src = NULL, *dst = NULL, *family = NULL; + char fin, syn, rst, psh, ack; + int nr_matched, ret = 0; + uint64_t netns_cookie; + + switch (event) { + case TCP_HASH_BAD_HEADER: + case TCP_HASH_MD5_REQUIRED: + case TCP_HASH_MD5_UNEXPECTED: + case TCP_HASH_MD5_MISMATCH: + case TCP_HASH_AO_REQUIRED: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms L3index=%d [%c%c%c%c%c]", + &netns_cookie, &family, + &src, &dst, &out->L3index, + &fin, &syn, &rst, &psh, &ack); + if (nr_matched != 10) + test_print("Couldn't parse trace event, matched = %d/10", + nr_matched); + break; + } + case TCP_AO_HANDSHAKE_FAILURE: + case TCP_AO_WRONG_MACLEN: + case TCP_AO_MISMATCH: + case TCP_AO_KEY_NOT_FOUND: + case TCP_AO_RNEXT_REQUEST: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms L3index=%d [%c%c%c%c%c] keyid=%u rnext=%u maclen=%u", + &netns_cookie, &family, + &src, &dst, &out->L3index, + &fin, &syn, &rst, &psh, &ack, + &out->keyid, &out->rnext, &out->maclen); + if (nr_matched != 13) + test_print("Couldn't parse trace event, matched = %d/13", + nr_matched); + break; + } + case TCP_AO_SYNACK_NO_KEY: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms keyid=%u rnext=%u", + &netns_cookie, &family, + &src, &dst, &out->keyid, &out->rnext); + if (nr_matched != 6) + test_print("Couldn't parse trace event, matched = %d/6", + nr_matched); + break; + } + case TCP_AO_SND_SNE_UPDATE: + case TCP_AO_RCV_SNE_UPDATE: { + nr_matched = sscanf(line, "%*s net=%" PRIu64 " state%*s family=%ms src=%ms dest=%ms sne=%u", + &netns_cookie, &family, + &src, &dst, &out->sne); + if (nr_matched != 5) + test_print("Couldn't parse trace event, matched = %d/5", + nr_matched); + break; + } + default: + return -1; + } + + if (family) { + if (!strcmp(family, "AF_INET")) { + out->family = AF_INET; + } else if (!strcmp(family, "AF_INET6")) { + out->family = AF_INET6; + } else { + test_print("Couldn't parse trace event family %s", family); + ret = -EINVAL; + goto out_free; + } + } + + if (event_has_flags(event)) { + out->fin = (fin == 'F'); + out->syn = (syn == 'S'); + out->rst = (rst == 'R'); + out->psh = (psh == 'P'); + out->ack = (ack == '.'); + + if ((fin != 'F' && fin != ' ') || + (syn != 'S' && syn != ' ') || + (rst != 'R' && rst != ' ') || + (psh != 'P' && psh != ' ') || + (ack != '.' && ack != ' ')) { + test_print("Couldn't parse trace event flags %c%c%c%c%c", + fin, syn, rst, psh, ack); + ret = -EINVAL; + goto out_free; + } + } + + if (src && tracer_scan_address(out->family, src, &out->src, &out->src_port)) { + ret = -EINVAL; + goto out_free; + } + + if (dst && tracer_scan_address(out->family, dst, &out->dst, &out->dst_port)) { + ret = -EINVAL; + goto out_free; + } + + if (netns_cookie != ns_cookie1 && netns_cookie != ns_cookie2) { + test_print("Net namespace filter for trace event didn't work: %" PRIu64 " != %" PRIu64 " OR %" PRIu64, + netns_cookie, ns_cookie1, ns_cookie2); + ret = -EINVAL; + } + +out_free: + free(src); + free(dst); + free(family); + return ret; +} + +static enum ftracer_op aolib_tracer_process_event(const char *line) +{ + int event_type = check_event_type(line); + struct trace_point tmp = {}; + + if (event_type < 0) + return FTRACER_LINE_PRESERVE; + + if (tracer_scan_event(line, event_type, &tmp)) + return FTRACER_LINE_PRESERVE; + + return lookup_expected_event(event_type, &tmp) ? + FTRACER_LINE_DISCARD : FTRACER_LINE_PRESERVE; +} + +static void dump_trace_event(struct expected_trace_point *e) +{ + char src[INET6_ADDRSTRLEN], dst[INET6_ADDRSTRLEN]; + + if (!inet_ntop(e->family, &e->src, src, INET6_ADDRSTRLEN)) + test_error("inet_ntop()"); + if (!inet_ntop(e->family, &e->dst, dst, INET6_ADDRSTRLEN)) + test_error("inet_ntop()"); + test_print("trace event filter %s [%s:%d => %s:%d, L3index %d, flags: %s%s%s%s%s, keyid: %d, rnext: %d, maclen: %d, sne: %d] = %zu", + trace_event_names[e->type], + src, e->src_port, dst, e->dst_port, e->L3index, + e->fin ? "F" : "", e->syn ? "S" : "", e->rst ? "R" : "", + e->psh ? "P" : "", e->ack ? "." : "", + e->keyid, e->rnext, e->maclen, e->sne, e->matched); +} + +static void print_match_stats(bool unexpected_events) +{ + size_t matches_per_type[__MAX_TRACE_EVENTS] = {}; + bool expected_but_none = false; + size_t i, total_matched = 0; + char *stat_line = NULL; + + for (i = 0; i < exp_tps_nr; i++) { + struct expected_trace_point *e = &exp_tps[i]; + + total_matched += e->matched; + matches_per_type[e->type] += e->matched; + if (!e->matched) + expected_but_none = true; + } + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + if (!matches_per_type[i]) + continue; + stat_line = test_sprintf("%s%s[%zu] ", stat_line ?: "", + trace_event_names[i], + matches_per_type[i]); + if (!stat_line) + test_error("test_sprintf()"); + } + + if (unexpected_events || expected_but_none) { + for (i = 0; i < exp_tps_nr; i++) + dump_trace_event(&exp_tps[i]); + } + + if (unexpected_events) + return; + + if (expected_but_none) + test_fail("Some trace events were expected, but didn't occur"); + else if (total_matched) + test_ok("Trace events matched expectations: %zu %s", + total_matched, stat_line); + else + test_ok("No unexpected trace events during the test run"); +} + +#define dump_events(fmt, ...) \ + __test_print(__test_msg, fmt, ##__VA_ARGS__) +static void check_free_events(struct test_ftracer *tracer) +{ + const char **lines; + size_t nr; + + if (!kernel_config_has(KCONFIG_FTRACE)) { + test_skip("kernel config doesn't have ftrace - no checks"); + return; + } + + nr = tracer_get_savedlines_nr(tracer); + lines = tracer_get_savedlines(tracer); + print_match_stats(!!nr); + if (!nr) + return; + + errno = 0; + test_xfail("Trace events [%zu] were not expected:", nr); + while (nr) + dump_events("\t%s", lines[--nr]); +} + +static int setup_tcp_trace_events(struct test_ftracer *tracer) +{ + char *filter; + size_t i; + int ret; + + filter = test_sprintf("net_cookie == %zu || net_cookie == %zu", + ns_cookie1, ns_cookie2); + if (!filter) + return -ENOMEM; + + for (i = 0; i < __MAX_TRACE_EVENTS; i++) { + char *event_name = test_sprintf("tcp/%s", trace_event_names[i]); + + if (!event_name) { + ret = -ENOMEM; + break; + } + ret = setup_trace_event(tracer, event_name, filter); + free(event_name); + if (ret) + break; + } + + free(filter); + return ret; +} + +static void aolib_tracer_destroy(struct test_ftracer *tracer) +{ + check_free_events(tracer); + free_expected_events(); +} + +static bool aolib_tracer_expecting_more(void) +{ + size_t i; + + for (i = 0; i < exp_tps_nr; i++) + if (!exp_tps[i].matched) + return true; + return false; +} + +int setup_aolib_ftracer(void) +{ + struct test_ftracer *f; + + f = create_ftracer("aolib", aolib_tracer_process_event, + aolib_tracer_destroy, aolib_tracer_expecting_more, + DEFAULT_FTRACE_BUFFER_KB, DEFAULT_TRACER_LINES_ARR); + if (!f) + return -1; + + return setup_tcp_trace_events(f); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/ftrace.c b/tools/testing/selftests/net/tcp_ao/lib/ftrace.c new file mode 100644 index 000000000000..e4d0b173bc94 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/ftrace.c @@ -0,0 +1,543 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <sys/time.h> +#include <unistd.h> +#include "../../../../../include/linux/kernel.h" +#include "aolib.h" + +static char ftrace_path[] = "ksft-ftrace-XXXXXX"; +static bool ftrace_mounted; +uint64_t ns_cookie1, ns_cookie2; + +struct test_ftracer { + pthread_t tracer_thread; + int error; + char *instance_path; + FILE *trace_pipe; + + enum ftracer_op (*process_line)(const char *line); + void (*destructor)(struct test_ftracer *tracer); + bool (*expecting_more)(void); + + char **saved_lines; + size_t saved_lines_size; + size_t next_line_ind; + + pthread_cond_t met_all_expected; + pthread_mutex_t met_all_expected_lock; + + struct test_ftracer *next; +}; + +static struct test_ftracer *ftracers; +static pthread_mutex_t ftracers_lock = PTHREAD_MUTEX_INITIALIZER; + +static int mount_ftrace(void) +{ + if (!mkdtemp(ftrace_path)) + test_error("Can't create temp dir"); + + if (mount("tracefs", ftrace_path, "tracefs", 0, "rw")) + return -errno; + + ftrace_mounted = true; + + return 0; +} + +static void unmount_ftrace(void) +{ + if (ftrace_mounted && umount(ftrace_path)) + test_print("Failed on cleanup: can't unmount tracefs: %m"); + + if (rmdir(ftrace_path)) + test_error("Failed on cleanup: can't remove ftrace dir %s", + ftrace_path); +} + +struct opts_list_t { + char *opt_name; + struct opts_list_t *next; +}; + +static int disable_trace_options(const char *ftrace_path) +{ + struct opts_list_t *opts_list = NULL; + char *fopts, *line = NULL; + size_t buf_len = 0; + ssize_t line_len; + int ret = 0; + FILE *opts; + + fopts = test_sprintf("%s/%s", ftrace_path, "trace_options"); + if (!fopts) + return -ENOMEM; + + opts = fopen(fopts, "r+"); + if (!opts) { + ret = -errno; + goto out_free; + } + + while ((line_len = getline(&line, &buf_len, opts)) != -1) { + struct opts_list_t *tmp; + + if (!strncmp(line, "no", 2)) + continue; + + tmp = malloc(sizeof(*tmp)); + if (!tmp) { + ret = -ENOMEM; + goto out_free_opts_list; + } + tmp->next = opts_list; + tmp->opt_name = test_sprintf("no%s", line); + if (!tmp->opt_name) { + ret = -ENOMEM; + free(tmp); + goto out_free_opts_list; + } + opts_list = tmp; + } + + while (opts_list) { + struct opts_list_t *tmp = opts_list; + + fseek(opts, 0, SEEK_SET); + fwrite(tmp->opt_name, 1, strlen(tmp->opt_name), opts); + + opts_list = opts_list->next; + free(tmp->opt_name); + free(tmp); + } + +out_free_opts_list: + while (opts_list) { + struct opts_list_t *tmp = opts_list; + + opts_list = opts_list->next; + free(tmp->opt_name); + free(tmp); + } + free(line); + fclose(opts); +out_free: + free(fopts); + return ret; +} + +static int setup_buffer_size(const char *ftrace_path, size_t sz) +{ + char *fbuf_size = test_sprintf("%s/buffer_size_kb", ftrace_path); + int ret; + + if (!fbuf_size) + return -1; + + ret = test_echo(fbuf_size, 0, "%zu", sz); + free(fbuf_size); + return ret; +} + +static int setup_ftrace_instance(struct test_ftracer *tracer, const char *name) +{ + char *tmp; + + tmp = test_sprintf("%s/instances/ksft-%s-XXXXXX", ftrace_path, name); + if (!tmp) + return -ENOMEM; + + tracer->instance_path = mkdtemp(tmp); + if (!tracer->instance_path) { + free(tmp); + return -errno; + } + + return 0; +} + +static void remove_ftrace_instance(struct test_ftracer *tracer) +{ + if (rmdir(tracer->instance_path)) + test_print("Failed on cleanup: can't remove ftrace instance %s", + tracer->instance_path); + free(tracer->instance_path); +} + +static void tracer_cleanup(void *arg) +{ + struct test_ftracer *tracer = arg; + + fclose(tracer->trace_pipe); +} + +static void tracer_set_error(struct test_ftracer *tracer, int error) +{ + if (!tracer->error) + tracer->error = error; +} + +const size_t tracer_get_savedlines_nr(struct test_ftracer *tracer) +{ + return tracer->next_line_ind; +} + +const char **tracer_get_savedlines(struct test_ftracer *tracer) +{ + return (const char **)tracer->saved_lines; +} + +static void *tracer_thread_func(void *arg) +{ + struct test_ftracer *tracer = arg; + + pthread_cleanup_push(tracer_cleanup, arg); + + while (tracer->next_line_ind < tracer->saved_lines_size) { + char **lp = &tracer->saved_lines[tracer->next_line_ind]; + enum ftracer_op op; + size_t buf_len = 0; + ssize_t line_len; + + line_len = getline(lp, &buf_len, tracer->trace_pipe); + if (line_len == -1) + break; + + pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, NULL); + op = tracer->process_line(*lp); + pthread_setcancelstate(PTHREAD_CANCEL_ENABLE, NULL); + + if (tracer->expecting_more) { + pthread_mutex_lock(&tracer->met_all_expected_lock); + if (!tracer->expecting_more()) + pthread_cond_signal(&tracer->met_all_expected); + pthread_mutex_unlock(&tracer->met_all_expected_lock); + } + + if (op == FTRACER_LINE_DISCARD) + continue; + if (op == FTRACER_EXIT) + break; + if (op != FTRACER_LINE_PRESERVE) + test_error("unexpected tracer command %d", op); + + tracer->next_line_ind++; + buf_len = 0; + } + test_print("too many lines in ftracer buffer %zu, exiting tracer", + tracer->next_line_ind); + + pthread_cleanup_pop(1); + return NULL; +} + +static int setup_trace_thread(struct test_ftracer *tracer) +{ + int ret = 0; + char *path; + + path = test_sprintf("%s/trace_pipe", tracer->instance_path); + if (!path) + return -ENOMEM; + + tracer->trace_pipe = fopen(path, "r"); + if (!tracer->trace_pipe) { + ret = -errno; + goto out_free; + } + + if (pthread_create(&tracer->tracer_thread, NULL, + tracer_thread_func, (void *)tracer)) { + ret = -errno; + fclose(tracer->trace_pipe); + } + +out_free: + free(path); + return ret; +} + +static void stop_trace_thread(struct test_ftracer *tracer) +{ + void *res; + + if (pthread_cancel(tracer->tracer_thread)) { + test_print("Can't stop tracer pthread: %m"); + tracer_set_error(tracer, -errno); + } + if (pthread_join(tracer->tracer_thread, &res)) { + test_print("Can't join tracer pthread: %m"); + tracer_set_error(tracer, -errno); + } + if (res != PTHREAD_CANCELED) { + test_print("Tracer thread wasn't canceled"); + tracer_set_error(tracer, -errno); + } + if (tracer->error) + test_fail("tracer errored by %s", strerror(tracer->error)); +} + +static void final_wait_for_events(struct test_ftracer *tracer, + unsigned timeout_sec) +{ + struct timespec timeout; + struct timeval now; + int ret = 0; + + if (!tracer->expecting_more) + return; + + pthread_mutex_lock(&tracer->met_all_expected_lock); + gettimeofday(&now, NULL); + timeout.tv_sec = now.tv_sec + timeout_sec; + timeout.tv_nsec = now.tv_usec * 1000; + + while (tracer->expecting_more() && ret != ETIMEDOUT) + ret = pthread_cond_timedwait(&tracer->met_all_expected, + &tracer->met_all_expected_lock, &timeout); + pthread_mutex_unlock(&tracer->met_all_expected_lock); +} + +int setup_trace_event(struct test_ftracer *tracer, + const char *event, const char *filter) +{ + char *enable_path, *filter_path, *instance = tracer->instance_path; + int ret; + + enable_path = test_sprintf("%s/events/%s/enable", instance, event); + if (!enable_path) + return -ENOMEM; + + filter_path = test_sprintf("%s/events/%s/filter", instance, event); + if (!filter_path) { + ret = -ENOMEM; + goto out_free; + } + + ret = test_echo(filter_path, 0, "%s", filter); + if (!ret) + ret = test_echo(enable_path, 0, "1"); + +out_free: + free(filter_path); + free(enable_path); + return ret; +} + +struct test_ftracer *create_ftracer(const char *name, + enum ftracer_op (*process_line)(const char *line), + void (*destructor)(struct test_ftracer *tracer), + bool (*expecting_more)(void), + size_t lines_buf_sz, size_t buffer_size_kb) +{ + struct test_ftracer *tracer; + int err; + + /* XXX: separate __create_ftracer() helper and do here + * if (!kernel_config_has(KCONFIG_FTRACE)) + * return NULL; + */ + + tracer = malloc(sizeof(*tracer)); + if (!tracer) { + test_print("malloc()"); + return NULL; + } + + memset(tracer, 0, sizeof(*tracer)); + + err = setup_ftrace_instance(tracer, name); + if (err) { + test_print("setup_ftrace_instance(): %d", err); + goto err_free; + } + + err = disable_trace_options(tracer->instance_path); + if (err) { + test_print("disable_trace_options(): %d", err); + goto err_remove; + } + + err = setup_buffer_size(tracer->instance_path, buffer_size_kb); + if (err) { + test_print("disable_trace_options(): %d", err); + goto err_remove; + } + + tracer->saved_lines = calloc(lines_buf_sz, sizeof(tracer->saved_lines[0])); + if (!tracer->saved_lines) { + test_print("calloc()"); + goto err_remove; + } + tracer->saved_lines_size = lines_buf_sz; + + tracer->process_line = process_line; + tracer->destructor = destructor; + tracer->expecting_more = expecting_more; + + err = pthread_cond_init(&tracer->met_all_expected, NULL); + if (err) { + test_print("pthread_cond_init(): %d", err); + goto err_free_lines; + } + + err = pthread_mutex_init(&tracer->met_all_expected_lock, NULL); + if (err) { + test_print("pthread_mutex_init(): %d", err); + goto err_cond_destroy; + } + + err = setup_trace_thread(tracer); + if (err) { + test_print("setup_trace_thread(): %d", err); + goto err_mutex_destroy; + } + + pthread_mutex_lock(&ftracers_lock); + tracer->next = ftracers; + ftracers = tracer; + pthread_mutex_unlock(&ftracers_lock); + + return tracer; + +err_mutex_destroy: + pthread_mutex_destroy(&tracer->met_all_expected_lock); +err_cond_destroy: + pthread_cond_destroy(&tracer->met_all_expected); +err_free_lines: + free(tracer->saved_lines); +err_remove: + remove_ftrace_instance(tracer); +err_free: + free(tracer); + return NULL; +} + +static void __destroy_ftracer(struct test_ftracer *tracer) +{ + size_t i; + + final_wait_for_events(tracer, TEST_TIMEOUT_SEC); + stop_trace_thread(tracer); + remove_ftrace_instance(tracer); + if (tracer->destructor) + tracer->destructor(tracer); + for (i = 0; i < tracer->saved_lines_size; i++) + free(tracer->saved_lines[i]); + pthread_cond_destroy(&tracer->met_all_expected); + pthread_mutex_destroy(&tracer->met_all_expected_lock); + free(tracer); +} + +void destroy_ftracer(struct test_ftracer *tracer) +{ + pthread_mutex_lock(&ftracers_lock); + if (tracer == ftracers) { + ftracers = tracer->next; + } else { + struct test_ftracer *f = ftracers; + + while (f->next != tracer) { + if (!f->next) + test_error("tracers list corruption or double free %p", tracer); + f = f->next; + } + f->next = tracer->next; + } + tracer->next = NULL; + pthread_mutex_unlock(&ftracers_lock); + __destroy_ftracer(tracer); +} + +static void destroy_all_ftracers(void) +{ + struct test_ftracer *f; + + pthread_mutex_lock(&ftracers_lock); + f = ftracers; + ftracers = NULL; + pthread_mutex_unlock(&ftracers_lock); + + while (f) { + struct test_ftracer *n = f->next; + + f->next = NULL; + __destroy_ftracer(f); + f = n; + } +} + +static void test_unset_tracing(void) +{ + destroy_all_ftracers(); + unmount_ftrace(); +} + +int test_setup_tracing(void) +{ + /* + * Just a basic protection - this should be called only once from + * lib/kconfig. Not thread safe, which is fine as it's early, before + * threads are created. + */ + static int already_set; + int err; + + if (already_set) + return -1; + + /* Needs net-namespace cookies for filters */ + if (ns_cookie1 == ns_cookie2) { + test_print("net-namespace cookies: %" PRIu64 " == %" PRIu64 ", can't set up tracing", + ns_cookie1, ns_cookie2); + return -1; + } + + already_set = 1; + + test_add_destructor(test_unset_tracing); + + err = mount_ftrace(); + if (err) { + test_print("failed to mount_ftrace(): %d", err); + return err; + } + + return setup_aolib_ftracer(); +} + +static int get_ns_cookie(int nsfd, uint64_t *out) +{ + int old_ns = switch_save_ns(nsfd); + socklen_t size = sizeof(*out); + int sk; + + sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + if (getsockopt(sk, SOL_SOCKET, SO_NETNS_COOKIE, out, &size)) { + test_print("getsockopt(SO_NETNS_COOKIE): %m"); + close(sk); + return -errno; + } + + close(sk); + switch_close_ns(old_ns); + return 0; +} + +void test_init_ftrace(int nsfd1, int nsfd2) +{ + get_ns_cookie(nsfd1, &ns_cookie1); + get_ns_cookie(nsfd2, &ns_cookie2); + /* Populate kernel config state */ + kernel_config_has(KCONFIG_FTRACE); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/kconfig.c b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c new file mode 100644 index 000000000000..9f1c175846f8 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/kconfig.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Check what features does the kernel support (where the selftest is running). + * Somewhat inspired by CRIU kerndat/kdat kernel features detector. + */ +#include <pthread.h> +#include "aolib.h" + +struct kconfig_t { + int _error; /* negative errno if not supported */ + int (*check_kconfig)(int *error); +}; + +static int has_net_ns(int *err) +{ + if (access("/proc/self/ns/net", F_OK) < 0) { + *err = errno; + if (errno == ENOENT) + return 0; + test_print("Unable to access /proc/self/ns/net: %m"); + return -errno; + } + return *err = errno = 0; +} + +static int has_veth(int *err) +{ + int orig_netns, ns_a, ns_b; + + orig_netns = open_netns(); + ns_a = unshare_open_netns(); + ns_b = unshare_open_netns(); + + *err = add_veth("check_veth", ns_a, ns_b); + + switch_ns(orig_netns); + close(orig_netns); + close(ns_a); + close(ns_b); + return 0; +} + +static int has_tcp_ao(int *err) +{ + struct sockaddr_in addr = { + .sin_family = test_family, + }; + struct tcp_ao_add tmp = {}; + const char *password = DEFAULT_TEST_PASSWORD; + int sk, ret = 0; + + sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + tmp.sndid = 100; + tmp.rcvid = 100; + tmp.keylen = strlen(password); + memcpy(tmp.key, password, strlen(password)); + strcpy(tmp.alg_name, "hmac(sha1)"); + memcpy(&tmp.addr, &addr, sizeof(addr)); + *err = 0; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_ADD_KEY, &tmp, sizeof(tmp)) < 0) { + *err = -errno; + if (errno != ENOPROTOOPT) + ret = -errno; + } + close(sk); + return ret; +} + +static int has_tcp_md5(int *err) +{ + union tcp_addr addr_any = {}; + int sk, ret = 0; + + sk = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (sk < 0) { + test_print("socket(): %m"); + return -errno; + } + + /* + * Under CONFIG_CRYPTO_FIPS=y it fails with ENOMEM, rather with + * anything more descriptive. Oh well. + */ + *err = 0; + if (test_set_md5(sk, addr_any, 0, -1, DEFAULT_TEST_PASSWORD)) { + *err = -errno; + if (errno != ENOPROTOOPT && errno == ENOMEM) { + test_print("setsockopt(TCP_MD5SIG_EXT): %m"); + ret = -errno; + } + } + close(sk); + return ret; +} + +static int has_vrfs(int *err) +{ + int orig_netns, ns_test, ret = 0; + + orig_netns = open_netns(); + ns_test = unshare_open_netns(); + + *err = add_vrf("ksft-check", 55, 101, ns_test); + if (*err && *err != -EOPNOTSUPP) { + test_print("Failed to add a VRF: %d", *err); + ret = *err; + } + + switch_ns(orig_netns); + close(orig_netns); + close(ns_test); + return ret; +} + +static int has_ftrace(int *err) +{ + *err = test_setup_tracing(); + return 0; +} + +#define KCONFIG_UNKNOWN 1 +static pthread_mutex_t kconfig_lock = PTHREAD_MUTEX_INITIALIZER; +static struct kconfig_t kconfig[__KCONFIG_LAST__] = { + { KCONFIG_UNKNOWN, has_net_ns }, + { KCONFIG_UNKNOWN, has_veth }, + { KCONFIG_UNKNOWN, has_tcp_ao }, + { KCONFIG_UNKNOWN, has_tcp_md5 }, + { KCONFIG_UNKNOWN, has_vrfs }, + { KCONFIG_UNKNOWN, has_ftrace }, +}; + +const char *tests_skip_reason[__KCONFIG_LAST__] = { + "Tests require network namespaces support (CONFIG_NET_NS)", + "Tests require veth support (CONFIG_VETH)", + "Tests require TCP-AO support (CONFIG_TCP_AO)", + "setsockopt(TCP_MD5SIG_EXT) is not supported (CONFIG_TCP_MD5)", + "VRFs are not supported (CONFIG_NET_VRF)", + "Ftrace points are not supported (CONFIG_TRACEPOINTS)", +}; + +bool kernel_config_has(enum test_needs_kconfig k) +{ + bool ret; + + pthread_mutex_lock(&kconfig_lock); + if (kconfig[k]._error == KCONFIG_UNKNOWN) { + if (kconfig[k].check_kconfig(&kconfig[k]._error)) + test_error("Failed to initialize kconfig %u", k); + } + ret = kconfig[k]._error == 0; + pthread_mutex_unlock(&kconfig_lock); + return ret; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/netlink.c b/tools/testing/selftests/net/tcp_ao/lib/netlink.c new file mode 100644 index 000000000000..7f108493a29a --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/netlink.c @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Original from tools/testing/selftests/net/ipsec.c */ +#include <linux/netlink.h> +#include <linux/random.h> +#include <linux/rtnetlink.h> +#include <linux/veth.h> +#include <net/if.h> +#include <stdint.h> +#include <string.h> +#include <sys/socket.h> + +#include "aolib.h" + +#define MAX_PAYLOAD 2048 + +static int netlink_sock(int *sock, uint32_t *seq_nr, int proto) +{ + if (*sock > 0) { + seq_nr++; + return 0; + } + + *sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto); + if (*sock < 0) { + test_print("socket(AF_NETLINK)"); + return -1; + } + + randomize_buffer(seq_nr, sizeof(*seq_nr)); + + return 0; +} + +static int netlink_check_answer(int sock, bool quite) +{ + struct nlmsgerror { + struct nlmsghdr hdr; + int error; + struct nlmsghdr orig_msg; + } answer; + + if (recv(sock, &answer, sizeof(answer), 0) < 0) { + test_print("recv()"); + return -1; + } else if (answer.hdr.nlmsg_type != NLMSG_ERROR) { + test_print("expected NLMSG_ERROR, got %d", + (int)answer.hdr.nlmsg_type); + return -1; + } else if (answer.error) { + if (!quite) { + test_print("NLMSG_ERROR: %d: %s", + answer.error, strerror(-answer.error)); + } + return answer.error; + } + + return 0; +} + +static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh) +{ + return (struct rtattr *)((char *)(nh) + RTA_ALIGN((nh)->nlmsg_len)); +} + +static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type, const void *payload, size_t size) +{ + /* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */ + struct rtattr *attr = rtattr_hdr(nh); + size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size); + + if (req_sz < nl_size) { + test_print("req buf is too small: %zu < %zu", req_sz, nl_size); + return -1; + } + nh->nlmsg_len = nl_size; + + attr->rta_len = RTA_LENGTH(size); + attr->rta_type = rta_type; + memcpy(RTA_DATA(attr), payload, size); + + return 0; +} + +static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type, const void *payload, size_t size) +{ + struct rtattr *ret = rtattr_hdr(nh); + + if (rtattr_pack(nh, req_sz, rta_type, payload, size)) + return 0; + + return ret; +} + +static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz, + unsigned short rta_type) +{ + return _rtattr_begin(nh, req_sz, rta_type, 0, 0); +} + +static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr) +{ + char *nlmsg_end = (char *)nh + nh->nlmsg_len; + + attr->rta_len = nlmsg_end - (char *)attr; +} + +static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz, + const char *peer, int ns) +{ + struct ifinfomsg pi; + struct rtattr *peer_attr; + + memset(&pi, 0, sizeof(pi)); + pi.ifi_family = AF_UNSPEC; + pi.ifi_change = 0xFFFFFFFF; + + peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi)); + if (!peer_attr) + return -1; + + if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer))) + return -1; + + if (rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns))) + return -1; + + rtattr_end(nh, peer_attr); + + return 0; +} + +static int __add_veth(int sock, uint32_t seq, const char *name, + int ns_a, int ns_b) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + static const char veth_type[] = "veth"; + struct rtattr *link_info, *info_data; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name))) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a))) + return -1; + + link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO); + if (!link_info) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type))) + return -1; + + info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA); + if (!info_data) + return -1; + + if (veth_pack_peerb(&req.nh, sizeof(req), name, ns_b)) + return -1; + + rtattr_end(&req.nh, info_data); + rtattr_end(&req.nh, link_info); + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, false); +} + +int add_veth(const char *name, int nsfda, int nsfdb) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __add_veth(route_sock, route_seq++, name, nsfda, nsfdb); + close(route_sock); + return ret; +} + +static int __ip_addr_add(int sock, uint32_t seq, const char *intf, + int family, union tcp_addr addr, uint8_t prefix) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifaddrmsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) : + sizeof(struct in6_addr); + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWADDR; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifa_family = family; + req.info.ifa_prefixlen = prefix; + req.info.ifa_index = if_nametoindex(intf); + req.info.ifa_flags = IFA_F_NODAD; + + if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, addr_len)) + return -1; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, true); +} + +int ip_addr_add(const char *intf, int family, + union tcp_addr addr, uint8_t prefix) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __ip_addr_add(route_sock, route_seq++, intf, + family, addr, prefix); + + close(route_sock); + return ret; +} + +static int __ip_route_add(int sock, uint32_t seq, const char *intf, int family, + union tcp_addr src, union tcp_addr dst, uint8_t vrf) +{ + struct { + struct nlmsghdr nh; + struct rtmsg rt; + char attrbuf[MAX_PAYLOAD]; + } req; + unsigned int index = if_nametoindex(intf); + size_t addr_len = (family == AF_INET) ? sizeof(struct in_addr) : + sizeof(struct in6_addr); + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.rt)); + req.nh.nlmsg_type = RTM_NEWROUTE; + req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE; + req.nh.nlmsg_seq = seq; + req.rt.rtm_family = family; + req.rt.rtm_dst_len = (family == AF_INET) ? 32 : 128; + req.rt.rtm_table = vrf; + req.rt.rtm_protocol = RTPROT_BOOT; + req.rt.rtm_scope = RT_SCOPE_UNIVERSE; + req.rt.rtm_type = RTN_UNICAST; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, addr_len)) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, addr_len)) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index))) + return -1; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + + return netlink_check_answer(sock, true); +} + +int ip_route_add_vrf(const char *intf, int family, + union tcp_addr src, union tcp_addr dst, uint8_t vrf) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __ip_route_add(route_sock, route_seq++, intf, + family, src, dst, vrf); + + close(route_sock); + return ret; +} + +int ip_route_add(const char *intf, int family, + union tcp_addr src, union tcp_addr dst) +{ + return ip_route_add_vrf(intf, family, src, dst, RT_TABLE_MAIN); +} + +static int __link_set_up(int sock, uint32_t seq, const char *intf) +{ + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + req.info.ifi_index = if_nametoindex(intf); + req.info.ifi_flags = IFF_UP; + req.info.ifi_change = IFF_UP; + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, false); +} + +int link_set_up(const char *intf) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __link_set_up(route_sock, route_seq++, intf); + + close(route_sock); + return ret; +} + +static int __add_vrf(int sock, uint32_t seq, const char *name, + uint32_t tabid, int ifindex, int nsfd) +{ + uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE; + struct { + struct nlmsghdr nh; + struct ifinfomsg info; + char attrbuf[MAX_PAYLOAD]; + } req; + static const char vrf_type[] = "vrf"; + struct rtattr *link_info, *info_data; + + memset(&req, 0, sizeof(req)); + req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.info)); + req.nh.nlmsg_type = RTM_NEWLINK; + req.nh.nlmsg_flags = flags; + req.nh.nlmsg_seq = seq; + req.info.ifi_family = AF_UNSPEC; + req.info.ifi_change = 0xFFFFFFFF; + req.info.ifi_index = ifindex; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, name, strlen(name))) + return -1; + + if (nsfd >= 0) + if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, + &nsfd, sizeof(nsfd))) + return -1; + + link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO); + if (!link_info) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, vrf_type, sizeof(vrf_type))) + return -1; + + info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA); + if (!info_data) + return -1; + + if (rtattr_pack(&req.nh, sizeof(req), IFLA_VRF_TABLE, + &tabid, sizeof(tabid))) + return -1; + + rtattr_end(&req.nh, info_data); + rtattr_end(&req.nh, link_info); + + if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) { + test_print("send()"); + return -1; + } + return netlink_check_answer(sock, true); +} + +int add_vrf(const char *name, uint32_t tabid, int ifindex, int nsfd) +{ + int route_sock = -1, ret; + uint32_t route_seq; + + if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) + test_error("Failed to open netlink route socket\n"); + + ret = __add_vrf(route_sock, route_seq++, name, tabid, ifindex, nsfd); + close(route_sock); + return ret; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/proc.c b/tools/testing/selftests/net/tcp_ao/lib/proc.c new file mode 100644 index 000000000000..8b984fa04286 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/proc.c @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <inttypes.h> +#include <pthread.h> +#include <stdio.h> +#include "../../../../../include/linux/compiler.h" +#include "../../../../../include/linux/kernel.h" +#include "aolib.h" + +struct netstat_counter { + uint64_t val; + char *name; +}; + +struct netstat { + char *header_name; + struct netstat *next; + size_t counters_nr; + struct netstat_counter *counters; +}; + +static struct netstat *lookup_type(struct netstat *ns, + const char *type, size_t len) +{ + while (ns != NULL) { + size_t cmp = max(len, strlen(ns->header_name)); + + if (!strncmp(ns->header_name, type, cmp)) + return ns; + ns = ns->next; + } + return NULL; +} + +static struct netstat *lookup_get(struct netstat *ns, + const char *type, const size_t len) +{ + struct netstat *ret; + + ret = lookup_type(ns, type, len); + if (ret != NULL) + return ret; + + ret = malloc(sizeof(struct netstat)); + if (!ret) + test_error("malloc()"); + + ret->header_name = strndup(type, len); + if (ret->header_name == NULL) + test_error("strndup()"); + ret->next = ns; + ret->counters_nr = 0; + ret->counters = NULL; + + return ret; +} + +static struct netstat *lookup_get_column(struct netstat *ns, const char *line) +{ + char *column; + + column = strchr(line, ':'); + if (!column) + test_error("can't parse netstat file"); + + return lookup_get(ns, line, column - line); +} + +static void netstat_read_type(FILE *fnetstat, struct netstat **dest, char *line) +{ + struct netstat *type = lookup_get_column(*dest, line); + const char *pos = line; + size_t i, nr_elems = 0; + char tmp; + + while ((pos = strchr(pos, ' '))) { + nr_elems++; + pos++; + } + + *dest = type; + type->counters = reallocarray(type->counters, + type->counters_nr + nr_elems, + sizeof(struct netstat_counter)); + if (!type->counters) + test_error("reallocarray()"); + + pos = strchr(line, ' ') + 1; + + if (fscanf(fnetstat, "%[^ :]", type->header_name) == EOF) + test_error("fscanf(%s)", type->header_name); + if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != ':') + test_error("Unexpected netstat format (%c)", tmp); + + for (i = type->counters_nr; i < type->counters_nr + nr_elems; i++) { + struct netstat_counter *nc = &type->counters[i]; + const char *new_pos = strchr(pos, ' '); + const char *fmt = " %" PRIu64; + + if (new_pos == NULL) + new_pos = strchr(pos, '\n'); + + nc->name = strndup(pos, new_pos - pos); + if (nc->name == NULL) + test_error("strndup()"); + + if (unlikely(!strcmp(nc->name, "MaxConn"))) + fmt = " %" PRId64; /* MaxConn is signed, RFC 2012 */ + if (fscanf(fnetstat, fmt, &nc->val) != 1) + test_error("fscanf(%s)", nc->name); + pos = new_pos + 1; + } + type->counters_nr += nr_elems; + + if (fread(&tmp, 1, 1, fnetstat) != 1 || tmp != '\n') + test_error("Unexpected netstat format"); +} + +static const char *snmp6_name = "Snmp6"; +static void snmp6_read(FILE *fnetstat, struct netstat **dest) +{ + struct netstat *type = lookup_get(*dest, snmp6_name, strlen(snmp6_name)); + char *counter_name; + size_t i; + + for (i = type->counters_nr;; i++) { + struct netstat_counter *nc; + uint64_t counter; + + if (fscanf(fnetstat, "%ms", &counter_name) == EOF) + break; + if (fscanf(fnetstat, "%" PRIu64, &counter) == EOF) + test_error("Unexpected snmp6 format"); + type->counters = reallocarray(type->counters, i + 1, + sizeof(struct netstat_counter)); + if (!type->counters) + test_error("reallocarray()"); + nc = &type->counters[i]; + nc->name = counter_name; + nc->val = counter; + } + type->counters_nr = i; + *dest = type; +} + +struct netstat *netstat_read(void) +{ + struct netstat *ret = 0; + size_t line_sz = 0; + char *line = NULL; + FILE *fnetstat; + + /* + * Opening thread-self instead of /proc/net/... as the latter + * points to /proc/self/net/ which instantiates thread-leader's + * net-ns, see: + * commit 155134fef2b6 ("Revert "proc: Point /proc/{mounts,net} at..") + */ + errno = 0; + fnetstat = fopen("/proc/thread-self/net/netstat", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/netstat"); + + while (getline(&line, &line_sz, fnetstat) != -1) + netstat_read_type(fnetstat, &ret, line); + fclose(fnetstat); + + errno = 0; + fnetstat = fopen("/proc/thread-self/net/snmp", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/snmp"); + + while (getline(&line, &line_sz, fnetstat) != -1) + netstat_read_type(fnetstat, &ret, line); + fclose(fnetstat); + + errno = 0; + fnetstat = fopen("/proc/thread-self/net/snmp6", "r"); + if (fnetstat == NULL) + test_error("failed to open /proc/net/snmp6"); + + snmp6_read(fnetstat, &ret); + fclose(fnetstat); + + free(line); + return ret; +} + +void netstat_free(struct netstat *ns) +{ + while (ns != NULL) { + struct netstat *prev = ns; + size_t i; + + free(ns->header_name); + for (i = 0; i < ns->counters_nr; i++) + free(ns->counters[i].name); + free(ns->counters); + ns = ns->next; + free(prev); + } +} + +static inline void +__netstat_print_diff(uint64_t a, struct netstat *nsb, size_t i) +{ + if (unlikely(!strcmp(nsb->header_name, "MaxConn"))) { + test_print("%8s %25s: %" PRId64 " => %" PRId64, + nsb->header_name, nsb->counters[i].name, + a, nsb->counters[i].val); + return; + } + + test_print("%8s %25s: %" PRIu64 " => %" PRIu64, nsb->header_name, + nsb->counters[i].name, a, nsb->counters[i].val); +} + +void netstat_print_diff(struct netstat *nsa, struct netstat *nsb) +{ + size_t i, j; + + while (nsb != NULL) { + if (unlikely(strcmp(nsb->header_name, nsa->header_name))) { + for (i = 0; i < nsb->counters_nr; i++) + __netstat_print_diff(0, nsb, i); + nsb = nsb->next; + continue; + } + + if (nsb->counters_nr < nsa->counters_nr) + test_error("Unexpected: some counters disappeared!"); + + for (j = 0, i = 0; i < nsb->counters_nr; i++) { + if (strcmp(nsb->counters[i].name, nsa->counters[j].name)) { + __netstat_print_diff(0, nsb, i); + continue; + } + + if (nsa->counters[j].val == nsb->counters[i].val) { + j++; + continue; + } + + __netstat_print_diff(nsa->counters[j].val, nsb, i); + j++; + } + if (j != nsa->counters_nr) + test_error("Unexpected: some counters disappeared!"); + + nsb = nsb->next; + nsa = nsa->next; + } +} + +uint64_t netstat_get(struct netstat *ns, const char *name, bool *not_found) +{ + if (not_found) + *not_found = false; + + while (ns != NULL) { + size_t i; + + for (i = 0; i < ns->counters_nr; i++) { + if (!strcmp(name, ns->counters[i].name)) + return ns->counters[i].val; + } + + ns = ns->next; + } + + if (not_found) + *not_found = true; + return 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/repair.c b/tools/testing/selftests/net/tcp_ao/lib/repair.c new file mode 100644 index 000000000000..9893b3ba69f5 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/repair.c @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: GPL-2.0 +/* This is over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets + * It tests that TCP-AO enabled connection can be restored. + * For the proper socket repair see: + * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h + */ +#include <fcntl.h> +#include <linux/sockios.h> +#include <sys/ioctl.h> +#include "aolib.h" + +#ifndef TCPOPT_MAXSEG +# define TCPOPT_MAXSEG 2 +#endif +#ifndef TCPOPT_WINDOW +# define TCPOPT_WINDOW 3 +#endif +#ifndef TCPOPT_SACK_PERMITTED +# define TCPOPT_SACK_PERMITTED 4 +#endif +#ifndef TCPOPT_TIMESTAMP +# define TCPOPT_TIMESTAMP 8 +#endif + +enum { + TCP_ESTABLISHED = 1, + TCP_SYN_SENT, + TCP_SYN_RECV, + TCP_FIN_WAIT1, + TCP_FIN_WAIT2, + TCP_TIME_WAIT, + TCP_CLOSE, + TCP_CLOSE_WAIT, + TCP_LAST_ACK, + TCP_LISTEN, + TCP_CLOSING, /* Now a valid state */ + TCP_NEW_SYN_RECV, + + TCP_MAX_STATES /* Leave at the end! */ +}; + +static void test_sock_checkpoint_queue(int sk, int queue, int qlen, + struct tcp_sock_queue *q) +{ + socklen_t len; + int ret; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + len = sizeof(q->seq); + ret = getsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &q->seq, &len); + if (ret || len != sizeof(q->seq)) + test_error("getsockopt(TCP_QUEUE_SEQ): %d", (int)len); + + if (!qlen) { + q->buf = NULL; + return; + } + + q->buf = malloc(qlen); + if (q->buf == NULL) + test_error("malloc()"); + ret = recv(sk, q->buf, qlen, MSG_PEEK | MSG_DONTWAIT); + if (ret != qlen) + test_error("recv(%d): %d", qlen, ret); +} + +void __test_sock_checkpoint(int sk, struct tcp_sock_state *state, + void *addr, size_t addr_size) +{ + socklen_t len = sizeof(state->info); + int ret; + + memset(state, 0, sizeof(*state)); + + ret = getsockopt(sk, SOL_TCP, TCP_INFO, &state->info, &len); + if (ret || len != sizeof(state->info)) + test_error("getsockopt(TCP_INFO): %d", (int)len); + + len = addr_size; + if (getsockname(sk, addr, &len) || len != addr_size) + test_error("getsockname(): %d", (int)len); + + len = sizeof(state->trw); + ret = getsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, &len); + if (ret || len != sizeof(state->trw)) + test_error("getsockopt(TCP_REPAIR_WINDOW): %d", (int)len); + + if (ioctl(sk, SIOCOUTQ, &state->outq_len)) + test_error("ioctl(SIOCOUTQ)"); + + if (ioctl(sk, SIOCOUTQNSD, &state->outq_nsd_len)) + test_error("ioctl(SIOCOUTQNSD)"); + test_sock_checkpoint_queue(sk, TCP_SEND_QUEUE, state->outq_len, &state->out); + + if (ioctl(sk, SIOCINQ, &state->inq_len)) + test_error("ioctl(SIOCINQ)"); + test_sock_checkpoint_queue(sk, TCP_RECV_QUEUE, state->inq_len, &state->in); + + if (state->info.tcpi_state == TCP_CLOSE) + state->outq_len = state->outq_nsd_len = 0; + + len = sizeof(state->mss); + ret = getsockopt(sk, SOL_TCP, TCP_MAXSEG, &state->mss, &len); + if (ret || len != sizeof(state->mss)) + test_error("getsockopt(TCP_MAXSEG): %d", (int)len); + + len = sizeof(state->timestamp); + ret = getsockopt(sk, SOL_TCP, TCP_TIMESTAMP, &state->timestamp, &len); + if (ret || len != sizeof(state->timestamp)) + test_error("getsockopt(TCP_TIMESTAMP): %d", (int)len); +} + +void test_ao_checkpoint(int sk, struct tcp_ao_repair *state) +{ + socklen_t len = sizeof(*state); + int ret; + + memset(state, 0, sizeof(*state)); + + ret = getsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, &len); + if (ret || len != sizeof(*state)) + test_error("getsockopt(TCP_AO_REPAIR): %d", (int)len); +} + +static void test_sock_restore_seq(int sk, int queue, uint32_t seq) +{ + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + if (setsockopt(sk, SOL_TCP, TCP_QUEUE_SEQ, &seq, sizeof(seq))) + test_error("setsockopt(TCP_QUEUE_SEQ)"); +} + +static void test_sock_restore_queue(int sk, int queue, void *buf, int len) +{ + int chunk = len; + size_t off = 0; + + if (len == 0) + return; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_QUEUE, &queue, sizeof(queue))) + test_error("setsockopt(TCP_REPAIR_QUEUE)"); + + do { + int ret; + + ret = send(sk, buf + off, chunk, 0); + if (ret <= 0) { + if (chunk > 1024) { + chunk >>= 1; + continue; + } + test_error("send()"); + } + off += ret; + len -= ret; + } while (len > 0); +} + +void __test_sock_restore(int sk, const char *device, + struct tcp_sock_state *state, + void *saddr, void *daddr, size_t addr_size) +{ + struct tcp_repair_opt opts[4]; + unsigned int opt_nr = 0; + long flags; + + if (bind(sk, saddr, addr_size)) + test_error("bind()"); + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + test_sock_restore_seq(sk, TCP_RECV_QUEUE, state->in.seq - state->inq_len); + test_sock_restore_seq(sk, TCP_SEND_QUEUE, state->out.seq - state->outq_len); + + if (device != NULL && setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, + device, strlen(device) + 1)) + test_error("setsockopt(SO_BINDTODEVICE, %s)", device); + + if (connect(sk, daddr, addr_size)) + test_error("connect()"); + + if (state->info.tcpi_options & TCPI_OPT_SACK) { + opts[opt_nr].opt_code = TCPOPT_SACK_PERMITTED; + opts[opt_nr].opt_val = 0; + opt_nr++; + } + if (state->info.tcpi_options & TCPI_OPT_WSCALE) { + opts[opt_nr].opt_code = TCPOPT_WINDOW; + opts[opt_nr].opt_val = state->info.tcpi_snd_wscale + + (state->info.tcpi_rcv_wscale << 16); + opt_nr++; + } + if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) { + opts[opt_nr].opt_code = TCPOPT_TIMESTAMP; + opts[opt_nr].opt_val = 0; + opt_nr++; + } + opts[opt_nr].opt_code = TCPOPT_MAXSEG; + opts[opt_nr].opt_val = state->mss; + opt_nr++; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_OPTIONS, opts, opt_nr * sizeof(opts[0]))) + test_error("setsockopt(TCP_REPAIR_OPTIONS)"); + + if (state->info.tcpi_options & TCPI_OPT_TIMESTAMPS) { + if (setsockopt(sk, SOL_TCP, TCP_TIMESTAMP, + &state->timestamp, opt_nr * sizeof(opts[0]))) + test_error("setsockopt(TCP_TIMESTAMP)"); + } + test_sock_restore_queue(sk, TCP_RECV_QUEUE, state->in.buf, state->inq_len); + test_sock_restore_queue(sk, TCP_SEND_QUEUE, state->out.buf, state->outq_len); + if (setsockopt(sk, SOL_TCP, TCP_REPAIR_WINDOW, &state->trw, sizeof(state->trw))) + test_error("setsockopt(TCP_REPAIR_WINDOW)"); +} + +void test_ao_restore(int sk, struct tcp_ao_repair *state) +{ + if (setsockopt(sk, SOL_TCP, TCP_AO_REPAIR, state, sizeof(*state))) + test_error("setsockopt(TCP_AO_REPAIR)"); +} + +void test_sock_state_free(struct tcp_sock_state *state) +{ + free(state->out.buf); + free(state->in.buf); +} + +void test_enable_repair(int sk) +{ + int val = TCP_REPAIR_ON; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); +} + +void test_disable_repair(int sk) +{ + int val = TCP_REPAIR_OFF_NO_WP; + + if (setsockopt(sk, SOL_TCP, TCP_REPAIR, &val, sizeof(val))) + test_error("setsockopt(TCP_REPAIR)"); +} + +void test_kill_sk(int sk) +{ + test_enable_repair(sk); + close(sk); +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/setup.c b/tools/testing/selftests/net/tcp_ao/lib/setup.c new file mode 100644 index 000000000000..a27cc03c9fbd --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/setup.c @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <fcntl.h> +#include <pthread.h> +#include <sched.h> +#include <signal.h> +#include "aolib.h" + +/* + * Can't be included in the header: it defines static variables which + * will be unique to every object. Let's include it only once here. + */ +#include "../../../kselftest.h" + +/* Prevent overriding of one thread's output by another */ +static pthread_mutex_t ksft_print_lock = PTHREAD_MUTEX_INITIALIZER; + +void __test_msg(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_print_msg("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_ok(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_pass("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_fail(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_fail("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_xfail(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_xfail("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_error(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_error("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} +void __test_skip(const char *buf) +{ + pthread_mutex_lock(&ksft_print_lock); + ksft_test_result_skip("%s", buf); + pthread_mutex_unlock(&ksft_print_lock); +} + +static volatile int failed; +static volatile int skipped; + +void test_failed(void) +{ + failed = 1; +} + +static void test_exit(void) +{ + if (failed) { + ksft_exit_fail(); + } else if (skipped) { + /* ksft_exit_skip() is different from ksft_exit_*() */ + ksft_print_cnts(); + exit(KSFT_SKIP); + } else { + ksft_exit_pass(); + } +} + +struct dlist_t { + void (*destruct)(void); + struct dlist_t *next; +}; +static struct dlist_t *destructors_list; + +void test_add_destructor(void (*d)(void)) +{ + struct dlist_t *p; + + p = malloc(sizeof(struct dlist_t)); + if (p == NULL) + test_error("malloc() failed"); + + p->next = destructors_list; + p->destruct = d; + destructors_list = p; +} + +static void test_destructor(void) __attribute__((destructor)); +static void test_destructor(void) +{ + while (destructors_list) { + struct dlist_t *p = destructors_list->next; + + destructors_list->destruct(); + free(destructors_list); + destructors_list = p; + } + test_exit(); +} + +static void sig_int(int signo) +{ + test_error("Caught SIGINT - exiting"); +} + +int open_netns(void) +{ + const char *netns_path = "/proc/thread-self/ns/net"; + int fd; + + fd = open(netns_path, O_RDONLY); + if (fd < 0) + test_error("open(%s)", netns_path); + return fd; +} + +int unshare_open_netns(void) +{ + if (unshare(CLONE_NEWNET) != 0) + test_error("unshare()"); + + return open_netns(); +} + +void switch_ns(int fd) +{ + if (setns(fd, CLONE_NEWNET)) + test_error("setns()"); +} + +int switch_save_ns(int new_ns) +{ + int ret = open_netns(); + + switch_ns(new_ns); + return ret; +} + +void switch_close_ns(int fd) +{ + if (setns(fd, CLONE_NEWNET)) + test_error("setns()"); + close(fd); +} + +static int nsfd_outside = -1; +static int nsfd_parent = -1; +static int nsfd_child = -1; +const char veth_name[] = "ktst-veth"; + +static void init_namespaces(void) +{ + nsfd_outside = open_netns(); + nsfd_parent = unshare_open_netns(); + nsfd_child = unshare_open_netns(); +} + +static void link_init(const char *veth, int family, uint8_t prefix, + union tcp_addr addr, union tcp_addr dest) +{ + if (link_set_up(veth)) + test_error("Failed to set link up"); + if (ip_addr_add(veth, family, addr, prefix)) + test_error("Failed to add ip address"); + if (ip_route_add(veth, family, addr, dest)) + test_error("Failed to add route"); +} + +static unsigned int nr_threads = 1; + +static pthread_mutex_t sync_lock = PTHREAD_MUTEX_INITIALIZER; +static pthread_cond_t sync_cond = PTHREAD_COND_INITIALIZER; +static volatile unsigned int stage_threads[2]; +static volatile unsigned int stage_nr; + +/* synchronize all threads in the same stage */ +void synchronize_threads(void) +{ + unsigned int q = stage_nr; + + pthread_mutex_lock(&sync_lock); + stage_threads[q]++; + if (stage_threads[q] == nr_threads) { + stage_nr ^= 1; + stage_threads[stage_nr] = 0; + pthread_cond_signal(&sync_cond); + } + while (stage_threads[q] < nr_threads) + pthread_cond_wait(&sync_cond, &sync_lock); + pthread_mutex_unlock(&sync_lock); +} + +__thread union tcp_addr this_ip_addr; +__thread union tcp_addr this_ip_dest; +int test_family; + +struct new_pthread_arg { + thread_fn func; + union tcp_addr my_ip; + union tcp_addr dest_ip; +}; +static void *new_pthread_entry(void *arg) +{ + struct new_pthread_arg *p = arg; + + this_ip_addr = p->my_ip; + this_ip_dest = p->dest_ip; + p->func(NULL); /* shouldn't return */ + exit(KSFT_FAIL); +} + +static void __test_skip_all(const char *msg) +{ + ksft_set_plan(1); + ksft_print_header(); + skipped = 1; + test_skip("%s", msg); + exit(KSFT_SKIP); +} + +void __test_init(unsigned int ntests, int family, unsigned int prefix, + union tcp_addr addr1, union tcp_addr addr2, + thread_fn peer1, thread_fn peer2) +{ + struct sigaction sa = { + .sa_handler = sig_int, + .sa_flags = SA_RESTART, + }; + time_t seed = time(NULL); + + sigemptyset(&sa.sa_mask); + if (sigaction(SIGINT, &sa, NULL)) + test_error("Can't set SIGINT handler"); + + test_family = family; + if (!kernel_config_has(KCONFIG_NET_NS)) + __test_skip_all(tests_skip_reason[KCONFIG_NET_NS]); + if (!kernel_config_has(KCONFIG_VETH)) + __test_skip_all(tests_skip_reason[KCONFIG_VETH]); + if (!kernel_config_has(KCONFIG_TCP_AO)) + __test_skip_all(tests_skip_reason[KCONFIG_TCP_AO]); + + ksft_set_plan(ntests); + test_print("rand seed %u", (unsigned int)seed); + srand(seed); + + ksft_print_header(); + init_namespaces(); + test_init_ftrace(nsfd_parent, nsfd_child); + + if (add_veth(veth_name, nsfd_parent, nsfd_child)) + test_error("Failed to add veth"); + + switch_ns(nsfd_child); + link_init(veth_name, family, prefix, addr2, addr1); + if (peer2) { + struct new_pthread_arg targ; + pthread_t t; + + targ.my_ip = addr2; + targ.dest_ip = addr1; + targ.func = peer2; + nr_threads++; + if (pthread_create(&t, NULL, new_pthread_entry, &targ)) + test_error("Failed to create pthread"); + } + switch_ns(nsfd_parent); + link_init(veth_name, family, prefix, addr1, addr2); + + this_ip_addr = addr1; + this_ip_dest = addr2; + peer1(NULL); + if (failed) + exit(KSFT_FAIL); + else + exit(KSFT_PASS); +} + +/* /proc/sys/net/core/optmem_max artifically limits the amount of memory + * that can be allocated with sock_kmalloc() on each socket in the system. + * It is not virtualized in v6.7, so it has to written outside test + * namespaces. To be nice a test will revert optmem back to the old value. + * Keeping it simple without any file lock, which means the tests that + * need to set/increase optmem value shouldn't run in parallel. + * Also, not re-entrant. + * Since commit f5769faeec36 ("net: Namespace-ify sysctl_optmem_max") + * it is per-namespace, keeping logic for non-virtualized optmem_max + * for v6.7, which supports TCP-AO. + */ +static const char *optmem_file = "/proc/sys/net/core/optmem_max"; +static size_t saved_optmem; +static int optmem_ns = -1; + +static bool is_optmem_namespaced(void) +{ + if (optmem_ns == -1) { + int old_ns = switch_save_ns(nsfd_child); + + optmem_ns = !access(optmem_file, F_OK); + switch_close_ns(old_ns); + } + return !!optmem_ns; +} + +size_t test_get_optmem(void) +{ + int old_ns = 0; + FILE *foptmem; + size_t ret; + + if (!is_optmem_namespaced()) + old_ns = switch_save_ns(nsfd_outside); + foptmem = fopen(optmem_file, "r"); + if (!foptmem) + test_error("failed to open %s", optmem_file); + + if (fscanf(foptmem, "%zu", &ret) != 1) + test_error("can't read from %s", optmem_file); + fclose(foptmem); + if (!is_optmem_namespaced()) + switch_close_ns(old_ns); + return ret; +} + +static void __test_set_optmem(size_t new, size_t *old) +{ + int old_ns = 0; + FILE *foptmem; + + if (old != NULL) + *old = test_get_optmem(); + + if (!is_optmem_namespaced()) + old_ns = switch_save_ns(nsfd_outside); + foptmem = fopen(optmem_file, "w"); + if (!foptmem) + test_error("failed to open %s", optmem_file); + + if (fprintf(foptmem, "%zu", new) <= 0) + test_error("can't write %zu to %s", new, optmem_file); + fclose(foptmem); + if (!is_optmem_namespaced()) + switch_close_ns(old_ns); +} + +static void test_revert_optmem(void) +{ + if (saved_optmem == 0) + return; + + __test_set_optmem(saved_optmem, NULL); +} + +void test_set_optmem(size_t value) +{ + if (saved_optmem == 0) { + __test_set_optmem(value, &saved_optmem); + test_add_destructor(test_revert_optmem); + } else { + __test_set_optmem(value, NULL); + } +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/sock.c b/tools/testing/selftests/net/tcp_ao/lib/sock.c new file mode 100644 index 000000000000..ef8e9031d47a --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/sock.c @@ -0,0 +1,730 @@ +// SPDX-License-Identifier: GPL-2.0 +#include <alloca.h> +#include <fcntl.h> +#include <inttypes.h> +#include <string.h> +#include "../../../../../include/linux/kernel.h" +#include "../../../../../include/linux/stringify.h" +#include "aolib.h" + +const unsigned int test_server_port = 7010; +int __test_listen_socket(int backlog, void *addr, size_t addr_sz) +{ + int err, sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP); + long flags; + + if (sk < 0) + test_error("socket()"); + + err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, veth_name, + strlen(veth_name) + 1); + if (err < 0) + test_error("setsockopt(SO_BINDTODEVICE)"); + + if (bind(sk, (struct sockaddr *)addr, addr_sz) < 0) + test_error("bind()"); + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + if (listen(sk, backlog)) + test_error("listen()"); + + return sk; +} + +static int __test_wait_fd(int sk, struct timeval *tv, bool write) +{ + fd_set fds, efds; + int ret; + socklen_t slen = sizeof(ret); + + FD_ZERO(&fds); + FD_SET(sk, &fds); + FD_ZERO(&efds); + FD_SET(sk, &efds); + + errno = 0; + if (write) + ret = select(sk + 1, NULL, &fds, &efds, tv); + else + ret = select(sk + 1, &fds, NULL, &efds, tv); + if (ret < 0) + return -errno; + if (ret == 0) { + errno = ETIMEDOUT; + return -ETIMEDOUT; + } + + if (getsockopt(sk, SOL_SOCKET, SO_ERROR, &ret, &slen)) + return -errno; + if (ret) + return -ret; + return 0; +} + +int test_wait_fd(int sk, time_t sec, bool write) +{ + struct timeval tv = { .tv_sec = sec, }; + + return __test_wait_fd(sk, sec ? &tv : NULL, write); +} + +static bool __skpair_poll_should_stop(int sk, struct tcp_counters *c, + test_cnt condition) +{ + struct tcp_counters c2; + test_cnt diff; + + if (test_get_tcp_counters(sk, &c2)) + test_error("test_get_tcp_counters()"); + + diff = test_cmp_counters(c, &c2); + test_tcp_counters_free(&c2); + return (diff & condition) == condition; +} + +/* How often wake up and check netns counters & paired (*err) */ +#define POLL_USEC 150 +static int __test_skpair_poll(int sk, bool write, uint64_t timeout, + struct tcp_counters *c, test_cnt cond, + volatile int *err) +{ + uint64_t t; + + for (t = 0; t <= timeout * 1000000; t += POLL_USEC) { + struct timeval tv = { .tv_usec = POLL_USEC, }; + int ret; + + ret = __test_wait_fd(sk, &tv, write); + if (ret != -ETIMEDOUT) + return ret; + if (c && cond && __skpair_poll_should_stop(sk, c, cond)) + break; + if (err && *err) + return *err; + } + if (err) + *err = -ETIMEDOUT; + return -ETIMEDOUT; +} + +int __test_connect_socket(int sk, const char *device, + void *addr, size_t addr_sz, bool async) +{ + long flags; + int err; + + if (device != NULL) { + err = setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE, device, + strlen(device) + 1); + if (err < 0) + test_error("setsockopt(SO_BINDTODEVICE, %s)", device); + } + + flags = fcntl(sk, F_GETFL); + if ((flags < 0) || (fcntl(sk, F_SETFL, flags | O_NONBLOCK) < 0)) + test_error("fcntl()"); + + if (connect(sk, addr, addr_sz) < 0) { + if (errno != EINPROGRESS) { + err = -errno; + goto out; + } + if (async) + return sk; + err = test_wait_fd(sk, TEST_TIMEOUT_SEC, 1); + if (err) + goto out; + } + return sk; + +out: + close(sk); + return err; +} + +int test_skpair_wait_poll(int sk, bool write, + test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + int ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = __test_skpair_poll(sk, write, TEST_TIMEOUT_SEC, &c, cond, err); + test_tcp_counters_free(&c); + return ret; +} + +int _test_skpair_connect_poll(int sk, const char *device, + void *addr, size_t addr_sz, + test_cnt condition, volatile int *err) +{ + struct tcp_counters c; + int ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + ret = __test_connect_socket(sk, device, addr, addr_sz, true); + if (ret < 0) { + test_tcp_counters_free(&c); + return (*err = ret); + } + ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, &c, condition, err); + if (ret < 0) + close(sk); + test_tcp_counters_free(&c); + return ret; +} + +int __test_set_md5(int sk, void *addr, size_t addr_sz, uint8_t prefix, + int vrf, const char *password) +{ + size_t pwd_len = strlen(password); + struct tcp_md5sig md5sig = {}; + + md5sig.tcpm_keylen = pwd_len; + memcpy(md5sig.tcpm_key, password, pwd_len); + md5sig.tcpm_flags = TCP_MD5SIG_FLAG_PREFIX; + md5sig.tcpm_prefixlen = prefix; + if (vrf >= 0) { + md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX; + md5sig.tcpm_ifindex = (uint8_t)vrf; + } + memcpy(&md5sig.tcpm_addr, addr, addr_sz); + + errno = 0; + return setsockopt(sk, IPPROTO_TCP, TCP_MD5SIG_EXT, + &md5sig, sizeof(md5sig)); +} + + +int test_prepare_key_sockaddr(struct tcp_ao_add *ao, const char *alg, + void *addr, size_t addr_sz, bool set_current, bool set_rnext, + uint8_t prefix, uint8_t vrf, uint8_t sndid, uint8_t rcvid, + uint8_t maclen, uint8_t keyflags, + uint8_t keylen, const char *key) +{ + memset(ao, 0, sizeof(struct tcp_ao_add)); + + ao->set_current = !!set_current; + ao->set_rnext = !!set_rnext; + ao->prefix = prefix; + ao->sndid = sndid; + ao->rcvid = rcvid; + ao->maclen = maclen; + ao->keyflags = keyflags; + ao->keylen = keylen; + ao->ifindex = vrf; + + memcpy(&ao->addr, addr, addr_sz); + + if (strlen(alg) > 64) + return -ENOBUFS; + strncpy(ao->alg_name, alg, 64); + + memcpy(ao->key, key, + (keylen > TCP_AO_MAXKEYLEN) ? TCP_AO_MAXKEYLEN : keylen); + return 0; +} + +static int test_get_ao_keys_nr(int sk) +{ + struct tcp_ao_getsockopt tmp = {}; + socklen_t tmp_sz = sizeof(tmp); + int ret; + + tmp.nkeys = 1; + tmp.get_all = 1; + + ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); + if (ret) + return -errno; + return (int)tmp.nkeys; +} + +int test_get_one_ao(int sk, struct tcp_ao_getsockopt *out, + void *addr, size_t addr_sz, uint8_t prefix, + uint8_t sndid, uint8_t rcvid) +{ + struct tcp_ao_getsockopt tmp = {}; + socklen_t tmp_sz = sizeof(tmp); + int ret; + + memcpy(&tmp.addr, addr, addr_sz); + tmp.prefix = prefix; + tmp.sndid = sndid; + tmp.rcvid = rcvid; + tmp.nkeys = 1; + + ret = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, &tmp, &tmp_sz); + if (ret) + return ret; + if (tmp.nkeys != 1) + return -E2BIG; + *out = tmp; + return 0; +} + +int test_get_ao_info(int sk, struct tcp_ao_info_opt *out) +{ + socklen_t sz = sizeof(*out); + + out->reserved = 0; + out->reserved2 = 0; + if (getsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, out, &sz)) + return -errno; + if (sz != sizeof(*out)) + return -EMSGSIZE; + return 0; +} + +int test_set_ao_info(int sk, struct tcp_ao_info_opt *in) +{ + socklen_t sz = sizeof(*in); + + in->reserved = 0; + in->reserved2 = 0; + if (setsockopt(sk, IPPROTO_TCP, TCP_AO_INFO, in, sz)) + return -errno; + return 0; +} + +int test_cmp_getsockopt_setsockopt(const struct tcp_ao_add *a, + const struct tcp_ao_getsockopt *b) +{ + bool is_kdf_aes_128_cmac = false; + bool is_cmac_aes = false; + + if (!strcmp("cmac(aes128)", a->alg_name)) { + is_kdf_aes_128_cmac = (a->keylen != 16); + is_cmac_aes = true; + } + +#define __cmp_ao(member) \ +do { \ + if (b->member != a->member) { \ + test_fail("getsockopt(): " __stringify(member) " %u != %u", \ + b->member, a->member); \ + return -1; \ + } \ +} while(0) + __cmp_ao(sndid); + __cmp_ao(rcvid); + __cmp_ao(prefix); + __cmp_ao(keyflags); + __cmp_ao(ifindex); + if (a->maclen) { + __cmp_ao(maclen); + } else if (b->maclen != 12) { + test_fail("getsockopt(): expected default maclen 12, but it's %u", + b->maclen); + return -1; + } + if (!is_kdf_aes_128_cmac) { + __cmp_ao(keylen); + } else if (b->keylen != 16) { + test_fail("getsockopt(): expected keylen 16 for cmac(aes128), but it's %u", + b->keylen); + return -1; + } +#undef __cmp_ao + if (!is_kdf_aes_128_cmac && memcmp(b->key, a->key, a->keylen)) { + test_fail("getsockopt(): returned key is different `%s' != `%s'", + b->key, a->key); + return -1; + } + if (memcmp(&b->addr, &a->addr, sizeof(b->addr))) { + test_fail("getsockopt(): returned address is different"); + return -1; + } + if (!is_cmac_aes && strcmp(b->alg_name, a->alg_name)) { + test_fail("getsockopt(): returned algorithm %s is different than %s", b->alg_name, a->alg_name); + return -1; + } + if (is_cmac_aes && strcmp(b->alg_name, "cmac(aes)")) { + test_fail("getsockopt(): returned algorithm %s is different than cmac(aes)", b->alg_name); + return -1; + } + /* For a established key rotation test don't add a key with + * set_current = 1, as it's likely to change by peer's request; + * rather use setsockopt(TCP_AO_INFO) + */ + if (a->set_current != b->is_current) { + test_fail("getsockopt(): returned key is not Current_key"); + return -1; + } + if (a->set_rnext != b->is_rnext) { + test_fail("getsockopt(): returned key is not RNext_key"); + return -1; + } + + return 0; +} + +int test_cmp_getsockopt_setsockopt_ao(const struct tcp_ao_info_opt *a, + const struct tcp_ao_info_opt *b) +{ + /* No check for ::current_key, as it may change by the peer */ + if (a->ao_required != b->ao_required) { + test_fail("getsockopt(): returned ao doesn't have ao_required"); + return -1; + } + if (a->accept_icmps != b->accept_icmps) { + test_fail("getsockopt(): returned ao doesn't accept ICMPs"); + return -1; + } + if (a->set_rnext && a->rnext != b->rnext) { + test_fail("getsockopt(): RNext KeyID has changed"); + return -1; + } +#define __cmp_cnt(member) \ +do { \ + if (b->member != a->member) { \ + test_fail("getsockopt(): " __stringify(member) " %llu != %llu", \ + b->member, a->member); \ + return -1; \ + } \ +} while(0) + if (a->set_counters) { + __cmp_cnt(pkt_good); + __cmp_cnt(pkt_bad); + __cmp_cnt(pkt_key_not_found); + __cmp_cnt(pkt_ao_required); + __cmp_cnt(pkt_dropped_icmp); + } +#undef __cmp_cnt + return 0; +} + +int test_get_tcp_counters(int sk, struct tcp_counters *out) +{ + struct tcp_ao_getsockopt *key_dump; + socklen_t key_dump_sz = sizeof(*key_dump); + struct tcp_ao_info_opt info = {}; + bool c1, c2, c3, c4, c5, c6, c7, c8; + struct netstat *ns; + int err, nr_keys; + + memset(out, 0, sizeof(*out)); + + /* per-netns */ + ns = netstat_read(); + out->ao.netns_ao_good = netstat_get(ns, "TCPAOGood", &c1); + out->ao.netns_ao_bad = netstat_get(ns, "TCPAOBad", &c2); + out->ao.netns_ao_key_not_found = netstat_get(ns, "TCPAOKeyNotFound", &c3); + out->ao.netns_ao_required = netstat_get(ns, "TCPAORequired", &c4); + out->ao.netns_ao_dropped_icmp = netstat_get(ns, "TCPAODroppedIcmps", &c5); + out->netns_md5_notfound = netstat_get(ns, "TCPMD5NotFound", &c6); + out->netns_md5_unexpected = netstat_get(ns, "TCPMD5Unexpected", &c7); + out->netns_md5_failure = netstat_get(ns, "TCPMD5Failure", &c8); + netstat_free(ns); + if (c1 || c2 || c3 || c4 || c5 || c6 || c7 || c8) + return -EOPNOTSUPP; + + err = test_get_ao_info(sk, &info); + if (err == -ENOENT) + return 0; + if (err) + return err; + + /* per-socket */ + out->ao.ao_info_pkt_good = info.pkt_good; + out->ao.ao_info_pkt_bad = info.pkt_bad; + out->ao.ao_info_pkt_key_not_found = info.pkt_key_not_found; + out->ao.ao_info_pkt_ao_required = info.pkt_ao_required; + out->ao.ao_info_pkt_dropped_icmp = info.pkt_dropped_icmp; + + /* per-key */ + nr_keys = test_get_ao_keys_nr(sk); + if (nr_keys < 0) + return nr_keys; + if (nr_keys == 0) + test_error("test_get_ao_keys_nr() == 0"); + out->ao.nr_keys = (size_t)nr_keys; + key_dump = calloc(nr_keys, key_dump_sz); + if (!key_dump) + return -errno; + + key_dump[0].nkeys = nr_keys; + key_dump[0].get_all = 1; + err = getsockopt(sk, IPPROTO_TCP, TCP_AO_GET_KEYS, + key_dump, &key_dump_sz); + if (err) { + free(key_dump); + return -errno; + } + + out->ao.key_cnts = calloc(nr_keys, sizeof(out->ao.key_cnts[0])); + if (!out->ao.key_cnts) { + free(key_dump); + return -errno; + } + + while (nr_keys--) { + out->ao.key_cnts[nr_keys].sndid = key_dump[nr_keys].sndid; + out->ao.key_cnts[nr_keys].rcvid = key_dump[nr_keys].rcvid; + out->ao.key_cnts[nr_keys].pkt_good = key_dump[nr_keys].pkt_good; + out->ao.key_cnts[nr_keys].pkt_bad = key_dump[nr_keys].pkt_bad; + } + free(key_dump); + + return 0; +} + +test_cnt test_cmp_counters(struct tcp_counters *before, + struct tcp_counters *after) +{ +#define __cmp(cnt, e_cnt) \ +do { \ + if (before->cnt > after->cnt) \ + test_error("counter " __stringify(cnt) " decreased"); \ + if (before->cnt != after->cnt) \ + ret |= e_cnt; \ +} while (0) + + test_cnt ret = 0; + size_t i; + + if (before->ao.nr_keys != after->ao.nr_keys) + test_error("the number of keys has changed"); + + _for_each_counter(__cmp); + + i = before->ao.nr_keys; + while (i--) { + __cmp(ao.key_cnts[i].pkt_good, TEST_CNT_KEY_GOOD); + __cmp(ao.key_cnts[i].pkt_bad, TEST_CNT_KEY_BAD); + } +#undef __cmp + return ret; +} + +int test_assert_counters_sk(const char *tst_name, + struct tcp_counters *before, + struct tcp_counters *after, + test_cnt expected) +{ +#define __cmp_ao(cnt, e_cnt) \ +do { \ + if (before->cnt > after->cnt) { \ + test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64, \ + tst_name ?: "", before->cnt, after->cnt); \ + return -1; \ + } \ + if ((before->cnt != after->cnt) != !!(expected & e_cnt)) { \ + test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64, \ + tst_name ?: "", (expected & e_cnt) ? "" : "not ", \ + before->cnt, after->cnt); \ + return -1; \ + } \ +} while (0) + + errno = 0; + _for_each_counter(__cmp_ao); + return 0; +#undef __cmp_ao +} + +int test_assert_counters_key(const char *tst_name, + struct tcp_ao_counters *before, + struct tcp_ao_counters *after, + test_cnt expected, int sndid, int rcvid) +{ + size_t i; +#define __cmp_ao(i, cnt, e_cnt) \ +do { \ + if (before->key_cnts[i].cnt > after->key_cnts[i].cnt) { \ + test_fail("%s: Decreased counter " __stringify(cnt) " %" PRIu64 " > %" PRIu64 " for key %u:%u", \ + tst_name ?: "", before->key_cnts[i].cnt, \ + after->key_cnts[i].cnt, \ + before->key_cnts[i].sndid, \ + before->key_cnts[i].rcvid); \ + return -1; \ + } \ + if ((before->key_cnts[i].cnt != after->key_cnts[i].cnt) != !!(expected & e_cnt)) { \ + test_fail("%s: Counter " __stringify(cnt) " was %sexpected to increase %" PRIu64 " => %" PRIu64 " for key %u:%u", \ + tst_name ?: "", (expected & e_cnt) ? "" : "not ",\ + before->key_cnts[i].cnt, \ + after->key_cnts[i].cnt, \ + before->key_cnts[i].sndid, \ + before->key_cnts[i].rcvid); \ + return -1; \ + } \ +} while (0) + + if (before->nr_keys != after->nr_keys) { + test_fail("%s: Keys changed on the socket %zu != %zu", + tst_name, before->nr_keys, after->nr_keys); + return -1; + } + + /* per-key */ + i = before->nr_keys; + while (i--) { + if (sndid >= 0 && before->key_cnts[i].sndid != sndid) + continue; + if (rcvid >= 0 && before->key_cnts[i].rcvid != rcvid) + continue; + __cmp_ao(i, pkt_good, TEST_CNT_KEY_GOOD); + __cmp_ao(i, pkt_bad, TEST_CNT_KEY_BAD); + } + return 0; +#undef __cmp_ao +} + +void test_tcp_counters_free(struct tcp_counters *cnts) +{ + free(cnts->ao.key_cnts); +} + +#define TEST_BUF_SIZE 4096 +static ssize_t _test_server_run(int sk, ssize_t quota, struct tcp_counters *c, + test_cnt cond, volatile int *err, + time_t timeout_sec) +{ + ssize_t total = 0; + + do { + char buf[TEST_BUF_SIZE]; + ssize_t bytes, sent; + int ret; + + ret = __test_skpair_poll(sk, 0, timeout_sec, c, cond, err); + if (ret) + return ret; + + bytes = recv(sk, buf, sizeof(buf), 0); + + if (bytes < 0) + test_error("recv(): %zd", bytes); + if (bytes == 0) + break; + + ret = __test_skpair_poll(sk, 1, timeout_sec, c, cond, err); + if (ret) + return ret; + + sent = send(sk, buf, bytes, 0); + if (sent == 0) + break; + if (sent != bytes) + test_error("send()"); + total += bytes; + } while (!quota || total < quota); + + return total; +} + +ssize_t test_server_run(int sk, ssize_t quota, time_t timeout_sec) +{ + return _test_server_run(sk, quota, NULL, 0, NULL, + timeout_sec ?: TEST_TIMEOUT_SEC); +} + +int test_skpair_server(int sk, ssize_t quota, test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + ssize_t ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = _test_server_run(sk, quota, &c, cond, err, TEST_TIMEOUT_SEC); + test_tcp_counters_free(&c); + return ret; +} + +static ssize_t test_client_loop(int sk, size_t buf_sz, const size_t msg_len, + struct tcp_counters *c, test_cnt cond, + volatile int *err) +{ + char msg[msg_len]; + int nodelay = 1; + char *buf; + size_t i; + + buf = alloca(buf_sz); + if (!buf) + return -ENOMEM; + randomize_buffer(buf, buf_sz); + + if (setsockopt(sk, IPPROTO_TCP, TCP_NODELAY, &nodelay, sizeof(nodelay))) + test_error("setsockopt(TCP_NODELAY)"); + + for (i = 0; i < buf_sz; i += min(msg_len, buf_sz - i)) { + size_t sent, bytes = min(msg_len, buf_sz - i); + int ret; + + ret = __test_skpair_poll(sk, 1, TEST_TIMEOUT_SEC, c, cond, err); + if (ret) + return ret; + + sent = send(sk, buf + i, bytes, 0); + if (sent == 0) + break; + if (sent != bytes) + test_error("send()"); + + bytes = 0; + do { + ssize_t got; + + ret = __test_skpair_poll(sk, 0, TEST_TIMEOUT_SEC, + c, cond, err); + if (ret) + return ret; + + got = recv(sk, msg + bytes, sizeof(msg) - bytes, 0); + if (got <= 0) + return i; + bytes += got; + } while (bytes < sent); + if (bytes > sent) + test_error("recv(): %zd > %zd", bytes, sent); + if (memcmp(buf + i, msg, bytes) != 0) { + test_fail("received message differs"); + return -1; + } + } + return i; +} + +int test_client_verify(int sk, const size_t msg_len, const size_t nr) +{ + size_t buf_sz = msg_len * nr; + ssize_t ret; + + ret = test_client_loop(sk, buf_sz, msg_len, NULL, 0, NULL); + if (ret < 0) + return (int)ret; + return ret != buf_sz ? -1 : 0; +} + +int test_skpair_client(int sk, const size_t msg_len, const size_t nr, + test_cnt cond, volatile int *err) +{ + struct tcp_counters c; + size_t buf_sz = msg_len * nr; + ssize_t ret; + + *err = 0; + if (test_get_tcp_counters(sk, &c)) + test_error("test_get_tcp_counters()"); + synchronize_threads(); /* 1: init skpair & read nscounters */ + + ret = test_client_loop(sk, buf_sz, msg_len, &c, cond, err); + test_tcp_counters_free(&c); + if (ret < 0) + return (int)ret; + return ret != buf_sz ? -1 : 0; +} diff --git a/tools/testing/selftests/net/tcp_ao/lib/utils.c b/tools/testing/selftests/net/tcp_ao/lib/utils.c new file mode 100644 index 000000000000..bdf5522c9213 --- /dev/null +++ b/tools/testing/selftests/net/tcp_ao/lib/utils.c @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: GPL-2.0 +#include "aolib.h" +#include <string.h> + +void randomize_buffer(void *buf, size_t buflen) +{ + int *p = (int *)buf; + size_t words = buflen / sizeof(int); + size_t leftover = buflen % sizeof(int); + + if (!buflen) + return; + + while (words--) + *p++ = rand(); + + if (leftover) { + int tmp = rand(); + + memcpy(buf + buflen - leftover, &tmp, leftover); + } +} + +__printf(3, 4) int test_echo(const char *fname, bool append, + const char *fmt, ...) +{ + size_t len, written; + va_list vargs; + char *msg; + FILE *f; + + f = fopen(fname, append ? "a" : "w"); + if (!f) + return -errno; + + va_start(vargs, fmt); + msg = test_snprintf(fmt, vargs); + va_end(vargs); + if (!msg) { + fclose(f); + return -1; + } + len = strlen(msg); + written = fwrite(msg, 1, len, f); + fclose(f); + free(msg); + return written == len ? 0 : -1; +} + +const struct sockaddr_in6 addr_any6 = { + .sin6_family = AF_INET6, +}; + +const struct sockaddr_in addr_any4 = { + .sin_family = AF_INET, +}; |