diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-04-21 15:21:52 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-04-22 00:02:28 -0600 |
commit | 2018285962935293a0da19f16378169d9521342f (patch) | |
tree | 59baa87dd4640aab1d18fb8fd0adc1d9036be823 | |
parent | wg_cookie: ensure gc is called regularly (diff) | |
download | wireguard-freebsd-2018285962935293a0da19f16378169d9521342f.tar.xz wireguard-freebsd-2018285962935293a0da19f16378169d9521342f.zip |
if_wg: port allowedips selftest from Linux code and fix bugs
And then fix broken allowedips implementation for the static unit tests
to pass.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | TODO.md | 3 | ||||
-rw-r--r-- | src/if_wg.c | 129 | ||||
-rw-r--r-- | src/selftest/allowedips.c | 608 |
3 files changed, 674 insertions, 66 deletions
@@ -7,8 +7,7 @@ - Work out `priv_check` from vnet perspective. (There's no `ns_capable()` on FreeBSD, just `capable()`, which makes it a bit weird for one jail to have permissions in another.) -- Audit allowedips / radix tree checks, and make sure it's actually behaving as - expected. (It might be useful to port [this selftest](https://git.zx2c4.com/wireguard-linux/tree/drivers/net/wireguard/selftest/allowedips.c).) +- Port ratelimiter and counter [selftests](https://git.zx2c4.com/wireguard-linux/tree/drivers/net/wireguard/selftest). - Make code style consistent with one FreeBSD way, rather than a mix of styles. - Run ratelimiter gc in a properly scheduled manner. - Make sure noise state machine is correct. diff --git a/src/if_wg.c b/src/if_wg.c index fee26ae..27e94f0 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -151,12 +151,24 @@ struct wg_endpoint { } e_local; }; +struct aip_addr { + uint8_t length; + union { + uint8_t bytes[16]; + uint32_t ip; + uint32_t ip6[4]; + struct in_addr in; + struct in6_addr in6; + }; +}; + struct wg_aip { struct radix_node a_nodes[2]; LIST_ENTRY(wg_aip) a_entry; - struct sockaddr_storage a_addr; - struct sockaddr_storage a_mask; + struct aip_addr a_addr; + struct aip_addr a_mask; struct wg_peer *a_peer; + sa_family_t a_af; }; struct wg_packet { @@ -518,64 +530,49 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void struct radix_node_head *root; struct radix_node *node; struct wg_aip *aip; - struct sockaddr_in *sin_addr, *sin_mask; - struct sockaddr_in6 *sin6_addr, *sin6_mask; - bool need_free = false; int i, ret = 0; if ((aip = malloc(sizeof(*aip), M_WG, M_NOWAIT | M_ZERO)) == NULL) return (ENOBUFS); + aip->a_peer = peer; + aip->a_af = af; switch (af) { case AF_INET: if (cidr > 32) cidr = 32; root = sc->sc_aip4; - - sin_addr = (struct sockaddr_in *)&aip->a_addr; - sin_mask = (struct sockaddr_in *)&aip->a_mask; - - sin_addr->sin_len = sizeof(struct sockaddr_in); - sin_addr->sin_family = AF_INET; - sin_addr->sin_addr = *(const struct in_addr *)addr; - - sin_mask->sin_len = sizeof(struct sockaddr_in); - sin_mask->sin_addr.s_addr = - htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff); - sin_addr->sin_addr.s_addr &= sin_mask->sin_addr.s_addr; + aip->a_addr.in = *(const struct in_addr *)addr; + aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff); + aip->a_addr.ip &= aip->a_mask.ip; + aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr); break; case AF_INET6: if (cidr > 128) cidr = 128; root = sc->sc_aip6; - - sin6_addr = (struct sockaddr_in6 *)&aip->a_addr; - sin6_mask = (struct sockaddr_in6 *)&aip->a_mask; - - sin6_addr->sin6_len = sizeof(struct sockaddr_in6); - sin6_addr->sin6_family = AF_INET6; - sin6_addr->sin6_addr = *(const struct in6_addr *)addr; - - sin6_mask->sin6_len = sizeof(struct sockaddr_in6); - in6_prefixlen2mask(&sin6_mask->sin6_addr, cidr); + aip->a_addr.in6 = *(const struct in6_addr *)addr; + in6_prefixlen2mask(&aip->a_mask.in6, cidr); for (i = 0; i < 4; i++) - sin6_addr->sin6_addr.__u6_addr.__u6_addr32[i] &= - sin6_mask->sin6_addr.__u6_addr.__u6_addr32[i]; + aip->a_addr.ip6[i] &= aip->a_mask.ip6[i]; + aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr); break; default: free(aip, M_WG); return (EAFNOSUPPORT); } - aip->a_peer = peer; - RADIX_NODE_HEAD_LOCK(root); node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes); - if (node == aip->a_nodes) { LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry); peer->p_aips_num++; - } else { - need_free = true; - aip = (struct wg_aip *) node; + } else if (!node) + node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh); + if (!node) { + free(aip, M_WG); + return (ENOMEM); + } else if (node != aip->a_nodes) { + free(aip, M_WG); + aip = (struct wg_aip *)node; if (aip->a_peer != peer) { LIST_REMOVE(aip, a_entry); aip->a_peer->p_aips_num--; @@ -585,40 +582,36 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void } } RADIX_NODE_HEAD_UNLOCK(root); - if (need_free) - free(aip, M_WG); return (ret); } static struct wg_peer * -wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *addr) +wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a) { struct radix_node_head *root; struct radix_node *node; struct wg_peer *peer; - struct sockaddr_storage ss; - struct sockaddr_in *sin = (struct sockaddr_in *)&ss; - struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&ss; + struct aip_addr addr; RADIX_NODE_HEAD_RLOCK_TRACKER; switch (af) { case AF_INET: root = sc->sc_aip4; - sin->sin_len = sizeof(struct sockaddr_in); - sin->sin_addr = *(struct in_addr *)addr; + memcpy(&addr.in, a, sizeof(addr.in)); + addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr); break; case AF_INET6: root = sc->sc_aip6; - sin6->sin6_len = sizeof(struct sockaddr_in6); - sin6->sin6_addr = *(struct in6_addr *)addr; + memcpy(&addr.in6, a, sizeof(addr.in6)); + addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr); break; default: - panic("invalid wg_aip_lookup af"); + return NULL; } RADIX_NODE_HEAD_RLOCK(root); - node = root->rnh_matchaddr((struct sockaddr *)&ss, &root->rh); - peer = node != NULL ? ((struct wg_aip *) node)->a_peer : NULL; + node = root->rnh_matchaddr(&addr, &root->rh); + peer = node != NULL ? ((struct wg_aip *)node)->a_peer : NULL; RADIX_NODE_HEAD_RUNLOCK(root); return (peer); @@ -631,9 +624,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer) RADIX_NODE_HEAD_LOCK(sc->sc_aip4); LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) { - if (aip->a_addr.ss_family == AF_INET) { + if (aip->a_af == AF_INET) { if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL) - panic("art_delete failed to delete aip %p", aip); + panic("failed to delete aip %p", aip); LIST_REMOVE(aip, a_entry); peer->p_aips_num--; free(aip, M_WG); @@ -643,9 +636,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer) RADIX_NODE_HEAD_LOCK(sc->sc_aip6); LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) { - if (aip->a_addr.ss_family == AF_INET6) { + if (aip->a_af == AF_INET6) { if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL) - panic("art_delete failed to delete aip %p", aip); + panic("failed to delete aip %p", aip); LIST_REMOVE(aip, a_entry); peer->p_aips_num--; free(aip, M_WG); @@ -2357,16 +2350,12 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) err = ENOMEM; goto err_aip; } - if (aip->a_addr.ss_family == AF_INET) { - struct sockaddr_in *sin = (struct sockaddr_in *)&aip->a_addr; - nvlist_add_binary(nvl_aip, "ipv4", &sin->sin_addr, sizeof(sin->sin_addr)); - sin = (struct sockaddr_in *)&aip->a_mask; - nvlist_add_number(nvl_aip, "cidr", __builtin_popcount(sin->sin_addr.s_addr)); - } else if (aip->a_addr.ss_family == AF_INET6) { - struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&aip->a_addr; - nvlist_add_binary(nvl_aip, "ipv6", &sin6->sin6_addr, sizeof(sin6->sin6_addr)); - sin6 = (struct sockaddr_in6 *)&aip->a_mask; - nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&sin6->sin6_addr, NULL)); + if (aip->a_af == AF_INET) { + nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in)); + nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip)); + } else if (aip->a_af == AF_INET6) { + nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6)); + nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL)); } } nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count); @@ -2613,8 +2602,8 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) wg_queue_init(&sc->sc_decrypt_parallel, "decp"); /* TODO check rn_inithead return value */ - rn_inithead((void **)&sc->sc_aip4, offsetof(struct sockaddr_in, sin_addr)*NBBY); - rn_inithead((void **)&sc->sc_aip6, offsetof(struct sockaddr_in6, sin6_addr)*NBBY); + rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY); + rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY); RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4); RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6); @@ -2792,6 +2781,16 @@ wg_prison_remove(void *obj, void *data __unused) return (0); } +#ifdef SELFTESTS +#include "selftest/allowedips.c" +static void wg_run_selftests(void) +{ + wg_allowedips_selftest(); +} +#else +static inline void wg_run_selftests(void) { } +#endif + static void wg_module_init(void) { @@ -2804,6 +2803,8 @@ wg_module_init(void) ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit), NULL, NULL, NULL, NULL, 0, 0); wg_osd_jail_slot = osd_jail_register(NULL, methods); + + wg_run_selftests(); } static void diff --git a/src/selftest/allowedips.c b/src/selftest/allowedips.c new file mode 100644 index 0000000..9678790 --- /dev/null +++ b/src/selftest/allowedips.c @@ -0,0 +1,608 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + */ + +#include <sys/cdefs.h> +#include <sys/param.h> +#include <sys/types.h> +#include <sys/systm.h> +#include <sys/queue.h> +#include <sys/endian.h> +#include <netinet/in.h> +#include <net/radix.h> + +#ifdef WG_ALLOWEDIPS_RANDOMIZED_TEST +enum { + NUM_PEERS = 2000, + NUM_RAND_ROUTES = 400, + NUM_MUTATED_ROUTES = 100, + NUM_QUERIES = NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30 +}; + +struct horrible_allowedips { + LIST_HEAD(, horrible_allowedips_node) head; +}; + +struct horrible_allowedips_node { + LIST_ENTRY(horrible_allowedips_node) table; + struct aip_addr ip; + struct aip_addr mask; + uint8_t ip_version; + void *value; +}; + +static void horrible_allowedips_init(struct horrible_allowedips *table) +{ + LIST_INIT(&table->head); +} + +static void horrible_allowedips_free(struct horrible_allowedips *table) +{ + struct horrible_allowedips_node *node, *temp_node; + + LIST_FOREACH_SAFE(node, &table->head, table, temp_node) { + LIST_REMOVE(node, table); + free(node, M_WG); + } +} + +static inline struct aip_addr horrible_cidr_to_mask(uint8_t cidr) +{ + struct aip_addr mask; + + memset(&mask.in6, 0x00, 128 / 8); + memset(&mask.in6, 0xff, cidr / 8); + if (cidr % 32) + mask.ip6[cidr / 32] = (uint32_t)htonl( + (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); + return mask; +} + +static inline uint8_t horrible_mask_to_cidr(struct aip_addr subnet) +{ + return bitcount32(subnet.ip6[0]) + bitcount32(subnet.ip6[1]) + + bitcount32(subnet.ip6[2]) + bitcount32(subnet.ip6[3]); +} + +static inline void +horrible_mask_self(struct horrible_allowedips_node *node) +{ + if (node->ip_version == 4) { + node->ip.ip &= node->mask.ip; + } else if (node->ip_version == 6) { + node->ip.ip6[0] &= node->mask.ip6[0]; + node->ip.ip6[1] &= node->mask.ip6[1]; + node->ip.ip6[2] &= node->mask.ip6[2]; + node->ip.ip6[3] &= node->mask.ip6[3]; + } +} + +static inline bool +horrible_match_v4(const struct horrible_allowedips_node *node, + struct in_addr *ip) +{ + return (ip->s_addr & node->mask.ip) == node->ip.ip; +} + +static inline bool +horrible_match_v6(const struct horrible_allowedips_node *node, + struct in6_addr *ip) +{ + return (ip->__u6_addr.__u6_addr32[0] & node->mask.ip6[0]) == + node->ip.ip6[0] && + (ip->__u6_addr.__u6_addr32[1] & node->mask.ip6[1]) == + node->ip.ip6[1] && + (ip->__u6_addr.__u6_addr32[2] & node->mask.ip6[2]) == + node->ip.ip6[2] && + (ip->__u6_addr.__u6_addr32[3] & node->mask.ip6[3]) == + node->ip.ip6[3]; +} + +static void +horrible_insert_ordered(struct horrible_allowedips *table, + struct horrible_allowedips_node *node) +{ + struct horrible_allowedips_node *other = NULL, *where = NULL; + uint8_t my_cidr = horrible_mask_to_cidr(node->mask); + + LIST_FOREACH(other, &table->head, table) { + if (!memcmp(&other->mask, &node->mask, + sizeof(struct aip_addr)) && + !memcmp(&other->ip, &node->ip, + sizeof(struct aip_addr)) && + other->ip_version == node->ip_version) { + other->value = node->value; + free(node, M_WG); + return; + } + where = other; + if (horrible_mask_to_cidr(other->mask) <= my_cidr) + break; + } + if (!other && !where) + LIST_INSERT_HEAD(&table->head, node, table); + else if (!other) + LIST_INSERT_AFTER(where, node, table); + else + LIST_INSERT_BEFORE(where, node, table); +} + +static int +horrible_allowedips_insert_v4(struct horrible_allowedips *table, + struct in_addr *ip, uint8_t cidr, void *value) +{ + struct horrible_allowedips_node *node = malloc(sizeof(*node), M_WG, M_NOWAIT | M_ZERO); + + if (!node) + return ENOMEM; + node->ip.in = *ip; + node->mask = horrible_cidr_to_mask(cidr); + node->ip_version = 4; + node->value = value; + horrible_mask_self(node); + horrible_insert_ordered(table, node); + return 0; +} + +static int +horrible_allowedips_insert_v6(struct horrible_allowedips *table, + struct in6_addr *ip, uint8_t cidr, void *value) +{ + struct horrible_allowedips_node *node = malloc(sizeof(*node), M_WG, M_NOWAIT | M_ZERO); + + if (!node) + return ENOMEM; + node->ip.in6 = *ip; + node->mask = horrible_cidr_to_mask(cidr); + node->ip_version = 6; + node->value = value; + horrible_mask_self(node); + horrible_insert_ordered(table, node); + return 0; +} + +static void * +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, + struct in_addr *ip) +{ + struct horrible_allowedips_node *node; + void *ret = NULL; + + LIST_FOREACH(node, &table->head, table) { + if (node->ip_version != 4) + continue; + if (horrible_match_v4(node, ip)) { + ret = node->value; + break; + } + } + return ret; +} + +static void * +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, + struct in6_addr *ip) +{ + struct horrible_allowedips_node *node; + void *ret = NULL; + + LIST_FOREACH(node, &table->head, table) { + if (node->ip_version != 6) + continue; + if (horrible_match_v6(node, ip)) { + ret = node->value; + break; + } + } + return ret; +} + +static bool randomized_test(void) +{ + unsigned int i, j, k, mutate_amount, cidr; + uint8_t ip[16], mutate_mask[16], mutated[16]; + struct wg_peer **peers, *peer; + struct horrible_allowedips h; + struct wg_softc sc = { 0 }; + bool ret = false; + + rn_inithead((void **)&sc.sc_aip4, offsetof(struct aip_addr, in) * NBBY); + rn_inithead((void **)&sc.sc_aip6, offsetof(struct aip_addr, in6) * NBBY); + RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip4); + RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip6); + horrible_allowedips_init(&h); + + peers = mallocarray(NUM_PEERS, sizeof(*peers), M_WG, M_NOWAIT | M_ZERO); + if (!peers) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + for (i = 0; i < NUM_PEERS; ++i) { + peers[i] = malloc(sizeof(*peers[i]), M_WG, M_NOWAIT | M_ZERO); + if (!peers[i]) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + LIST_INIT(&peers[i]->p_aips); + peers[i]->p_aips_num = 0; + } + + for (i = 0; i < NUM_RAND_ROUTES; ++i) { + arc4random_buf(ip, 4); + cidr = arc4random_uniform(32) + 1; + peer = peers[arc4random_uniform(NUM_PEERS)]; + if (wg_aip_add(&sc, peer, AF_INET, ip, cidr)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, + cidr, peer)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + for (j = 0; j < NUM_MUTATED_ROUTES; ++j) { + memcpy(mutated, ip, 4); + arc4random_buf(mutate_mask, 4); + mutate_amount = arc4random_uniform(32); + for (k = 0; k < mutate_amount / 8; ++k) + mutate_mask[k] = 0xff; + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); + for (; k < 4; ++k) + mutate_mask[k] = 0; + for (k = 0; k < 4; ++k) + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + arc4random_uniform(256)); + cidr = arc4random_uniform(32) + 1; + peer = peers[arc4random_uniform(NUM_PEERS)]; + if (wg_aip_add(&sc, peer, AF_INET, mutated, cidr)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + if (horrible_allowedips_insert_v4(&h, + (struct in_addr *)mutated, cidr, peer)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + } + } + + for (i = 0; i < NUM_RAND_ROUTES; ++i) { + arc4random_buf(ip, 16); + cidr = arc4random_uniform(128) + 1; + peer = peers[arc4random_uniform(NUM_PEERS)]; + if (wg_aip_add(&sc, peer, AF_INET6, ip, cidr)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, + cidr, peer)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + for (j = 0; j < NUM_MUTATED_ROUTES; ++j) { + memcpy(mutated, ip, 16); + arc4random_buf(mutate_mask, 16); + mutate_amount = arc4random_uniform(128); + for (k = 0; k < mutate_amount / 8; ++k) + mutate_mask[k] = 0xff; + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); + for (; k < 4; ++k) + mutate_mask[k] = 0; + for (k = 0; k < 4; ++k) + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + arc4random_uniform(256)); + cidr = arc4random_uniform(128) + 1; + peer = peers[arc4random_uniform(NUM_PEERS)]; + if (wg_aip_add(&sc, peer, AF_INET6, mutated, cidr)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + if (horrible_allowedips_insert_v6( + &h, (struct in6_addr *)mutated, cidr, + peer)) { + printf("allowedips random self-test malloc: FAIL\n"); + goto free; + } + } + } + + for (i = 0; i < NUM_QUERIES; ++i) { + arc4random_buf(ip, 4); + if (wg_aip_lookup(&sc, AF_INET, ip) != + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { + printf("allowedips random self-test: FAIL\n"); + goto free; + } + } + + for (i = 0; i < NUM_QUERIES; ++i) { + arc4random_buf(ip, 16); + if (wg_aip_lookup(&sc, AF_INET6, ip) != + horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { + printf("allowedips random self-test: FAIL\n"); + goto free; + } + } + ret = true; + +free: + horrible_allowedips_free(&h); + if (peers) { + for (i = 0; i < NUM_PEERS; ++i) + wg_aip_remove_all(&sc, peers[i]); + for (i = 0; i < NUM_PEERS; ++i) + free(peers[i], M_WG); + } + free(peers, M_WG); + return ret; +} +#endif + +static struct in_addr *ip4(uint8_t a, uint8_t b, uint8_t c, uint8_t d) +{ + static struct in_addr ip; + uint8_t *split = (uint8_t *)&ip; + + split[0] = a; + split[1] = b; + split[2] = c; + split[3] = d; + return &ip; +} + +static struct in6_addr *ip6(uint32_t a, uint32_t b, uint32_t c, uint32_t d) +{ + static struct in6_addr ip; + uint32_t *split = ip.__u6_addr.__u6_addr32; + + split[0] = htobe32(a); + split[1] = htobe32(b); + split[2] = htobe32(c); + split[3] = htobe32(d); + return &ip; +} + +static struct wg_peer *init_peer(void) +{ + struct wg_peer *peer = malloc(sizeof(*peer), M_WG, M_NOWAIT | M_ZERO); + + if (!peer) + return NULL; + LIST_INIT(&peer->p_aips); + peer->p_aips_num = 0; + return peer; +} + +#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) do { \ + int _r = wg_aip_add(&sc, mem, (version) == 6 ? AF_INET6 : AF_INET, \ + ip##version(ipa, ipb, ipc, ipd), cidr); \ + if (_r) { \ + printf("allowedips self-test insertion: FAIL (%d)\n", _r); \ + success = false; \ + } \ + } while (0) + +#define maybe_fail() do { \ + ++i; \ + if (!_s) { \ + printf("allowedips self-test %zu: FAIL\n", i); \ + success = false; \ + } \ + } while (0) + +#define test(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = wg_aip_lookup(&sc, (version) == 6 ? AF_INET6 : AF_INET, \ + ip##version(ipa, ipb, ipc, ipd)) == (mem); \ + maybe_fail(); \ + } while (0) + +#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = wg_aip_lookup(&sc, (version) == 6 ? AF_INET6 : AF_INET, \ + ip##version(ipa, ipb, ipc, ipd)) != (mem); \ + maybe_fail(); \ + } while (0) + +#define test_boolean(cond) do { \ + bool _s = (cond); \ + maybe_fail(); \ + } while (0) + +#define free_all() do { \ + if (a) wg_aip_remove_all(&sc, a); \ + if (b) wg_aip_remove_all(&sc, b); \ + if (c) wg_aip_remove_all(&sc, c); \ + if (d) wg_aip_remove_all(&sc, d); \ + if (e) wg_aip_remove_all(&sc, e); \ + if (f) wg_aip_remove_all(&sc, f); \ + if (g) wg_aip_remove_all(&sc, g); \ + if (h) wg_aip_remove_all(&sc, h); \ + } while (0) + +static bool wg_allowedips_selftest(void) +{ + bool found_a = false, found_b = false, found_c = false, found_d = false, + found_e = false, found_other = false; + struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(), + *d = init_peer(), *e = init_peer(), *f = init_peer(), + *g = init_peer(), *h = init_peer(); + struct wg_softc sc = { 0 }; + struct wg_aip *iter_node; + size_t i = 0, count = 0; + bool success = false; + struct in6_addr ip; + uint64_t part; + + rn_inithead((void **)&sc.sc_aip4, offsetof(struct aip_addr, in) * NBBY); + rn_inithead((void **)&sc.sc_aip6, offsetof(struct aip_addr, in6) * NBBY); + RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip4); + RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip6); + + if (!a || !b || !c || !d || !e || !f || !g || !h) { + printf("allowedips self-test malloc: FAIL\n"); + goto free; + } + + insert(4, a, 192, 168, 4, 0, 24); + insert(4, b, 192, 168, 4, 4, 32); + insert(4, c, 192, 168, 0, 0, 16); + insert(4, d, 192, 95, 5, 64, 27); + /* replaces previous entry, and maskself is required */ + insert(4, c, 192, 95, 5, 65, 27); + insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128); + insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64); + insert(4, e, 0, 0, 0, 0, 0); + insert(6, e, 0, 0, 0, 0, 0); + /* replaces previous entry */ + insert(6, f, 0, 0, 0, 0, 0); + insert(6, g, 0x24046800, 0, 0, 0, 32); + /* maskself is required */ + insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); + insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128); + insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); + insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); + insert(4, g, 64, 15, 112, 0, 20); + /* maskself is required */ + insert(4, h, 64, 15, 123, 211, 25); + insert(4, a, 10, 0, 0, 0, 25); + insert(4, b, 10, 0, 0, 128, 25); + insert(4, a, 10, 1, 0, 0, 30); + insert(4, b, 10, 1, 0, 4, 30); + insert(4, c, 10, 1, 0, 8, 29); + insert(4, d, 10, 1, 0, 16, 29); + + success = true; + + test(4, a, 192, 168, 4, 20); + test(4, a, 192, 168, 4, 0); + test(4, b, 192, 168, 4, 4); + test(4, c, 192, 168, 200, 182); + test(4, c, 192, 95, 5, 68); + test(4, e, 192, 95, 5, 96); + test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543); + test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee); + test(6, f, 0x26075300, 0x60006b01, 0, 0); + test(6, g, 0x24046800, 0x40040806, 0, 0x1006); + test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678); + test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678); + test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678); + test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678); + test(6, h, 0x24046800, 0x40040800, 0, 0); + test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010); + test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef); + test(4, g, 64, 15, 116, 26); + test(4, g, 64, 15, 127, 3); + test(4, g, 64, 15, 123, 1); + test(4, h, 64, 15, 123, 128); + test(4, h, 64, 15, 123, 129); + test(4, a, 10, 0, 0, 52); + test(4, b, 10, 0, 0, 220); + test(4, a, 10, 1, 0, 2); + test(4, b, 10, 1, 0, 6); + test(4, c, 10, 1, 0, 10); + test(4, d, 10, 1, 0, 20); + + insert(4, a, 1, 0, 0, 0, 32); + insert(4, a, 64, 0, 0, 0, 32); + insert(4, a, 128, 0, 0, 0, 32); + insert(4, a, 192, 0, 0, 0, 32); + insert(4, a, 255, 0, 0, 0, 32); + wg_aip_remove_all(&sc, a); + test_negative(4, a, 1, 0, 0, 0); + test_negative(4, a, 64, 0, 0, 0); + test_negative(4, a, 128, 0, 0, 0); + test_negative(4, a, 192, 0, 0, 0); + test_negative(4, a, 255, 0, 0, 0); + + free_all(); + insert(4, a, 192, 168, 0, 0, 16); + insert(4, a, 192, 168, 0, 0, 24); + wg_aip_remove_all(&sc, a); + test_negative(4, a, 192, 168, 0, 1); + + for (i = 0; i < 128; ++i) { + part = htobe64(~(1LLU << (i % 64))); + memset(&ip, 0xff, 16); + memcpy((uint8_t *)&ip + (i < 64) * 8, &part, 8); + wg_aip_add(&sc, a, AF_INET6, &ip, 128); + } + + free_all(); + insert(4, a, 192, 95, 5, 93, 27); + insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128); + insert(4, a, 10, 1, 0, 20, 29); + insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83); + insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21); + LIST_FOREACH(iter_node, &a->p_aips, a_entry) { + uint8_t cidr, *ip = iter_node->a_addr.bytes; + sa_family_t family = iter_node->a_af; + if (family == AF_INET) + cidr = bitcount32(iter_node->a_mask.ip); + else if (family == AF_INET6) + cidr = in6_mask2len(&iter_node->a_mask.in6, NULL); + else + continue; + + count++; + + if (cidr == 27 && family == AF_INET && + !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) + found_a = true; + else if (cidr == 128 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), + sizeof(struct in6_addr))) + found_b = true; + else if (cidr == 29 && family == AF_INET && + !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) + found_c = true; + else if (cidr == 83 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), + sizeof(struct in6_addr))) + found_d = true; + else if (cidr == 21 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075000, 0, 0, 0), + sizeof(struct in6_addr))) + found_e = true; + else + found_other = true; + } + test_boolean(count == 5); + test_boolean(found_a); + test_boolean(found_b); + test_boolean(found_c); + test_boolean(found_d); + test_boolean(found_e); + test_boolean(!found_other); + +#ifdef WG_ALLOWEDIPS_RANDOMIZED_TEST + if (success) + success = randomized_test(); +#endif + + if (success) + printf("allowedips self-tests: pass\n"); + +free: + free_all(); + free(a, M_WG); + free(b, M_WG); + free(c, M_WG); + free(d, M_WG); + free(e, M_WG); + free(f, M_WG); + free(g, M_WG); + free(h, M_WG); + + return success; +} + +#undef test_negative +#undef test +#undef remove +#undef insert +#undef init_peer +#undef free_all |