aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-04-21 15:21:52 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2021-04-22 00:02:28 -0600
commit2018285962935293a0da19f16378169d9521342f (patch)
tree59baa87dd4640aab1d18fb8fd0adc1d9036be823
parentwg_cookie: ensure gc is called regularly (diff)
downloadwireguard-freebsd-2018285962935293a0da19f16378169d9521342f.tar.xz
wireguard-freebsd-2018285962935293a0da19f16378169d9521342f.zip
if_wg: port allowedips selftest from Linux code and fix bugs
And then fix broken allowedips implementation for the static unit tests to pass. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--TODO.md3
-rw-r--r--src/if_wg.c129
-rw-r--r--src/selftest/allowedips.c608
3 files changed, 674 insertions, 66 deletions
diff --git a/TODO.md b/TODO.md
index a69ac25..9f56bdb 100644
--- a/TODO.md
+++ b/TODO.md
@@ -7,8 +7,7 @@
- Work out `priv_check` from vnet perspective. (There's no `ns_capable()` on
FreeBSD, just `capable()`, which makes it a bit weird for one jail to have
permissions in another.)
-- Audit allowedips / radix tree checks, and make sure it's actually behaving as
- expected. (It might be useful to port [this selftest](https://git.zx2c4.com/wireguard-linux/tree/drivers/net/wireguard/selftest/allowedips.c).)
+- Port ratelimiter and counter [selftests](https://git.zx2c4.com/wireguard-linux/tree/drivers/net/wireguard/selftest).
- Make code style consistent with one FreeBSD way, rather than a mix of styles.
- Run ratelimiter gc in a properly scheduled manner.
- Make sure noise state machine is correct.
diff --git a/src/if_wg.c b/src/if_wg.c
index fee26ae..27e94f0 100644
--- a/src/if_wg.c
+++ b/src/if_wg.c
@@ -151,12 +151,24 @@ struct wg_endpoint {
} e_local;
};
+struct aip_addr {
+ uint8_t length;
+ union {
+ uint8_t bytes[16];
+ uint32_t ip;
+ uint32_t ip6[4];
+ struct in_addr in;
+ struct in6_addr in6;
+ };
+};
+
struct wg_aip {
struct radix_node a_nodes[2];
LIST_ENTRY(wg_aip) a_entry;
- struct sockaddr_storage a_addr;
- struct sockaddr_storage a_mask;
+ struct aip_addr a_addr;
+ struct aip_addr a_mask;
struct wg_peer *a_peer;
+ sa_family_t a_af;
};
struct wg_packet {
@@ -518,64 +530,49 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void
struct radix_node_head *root;
struct radix_node *node;
struct wg_aip *aip;
- struct sockaddr_in *sin_addr, *sin_mask;
- struct sockaddr_in6 *sin6_addr, *sin6_mask;
- bool need_free = false;
int i, ret = 0;
if ((aip = malloc(sizeof(*aip), M_WG, M_NOWAIT | M_ZERO)) == NULL)
return (ENOBUFS);
+ aip->a_peer = peer;
+ aip->a_af = af;
switch (af) {
case AF_INET:
if (cidr > 32) cidr = 32;
root = sc->sc_aip4;
-
- sin_addr = (struct sockaddr_in *)&aip->a_addr;
- sin_mask = (struct sockaddr_in *)&aip->a_mask;
-
- sin_addr->sin_len = sizeof(struct sockaddr_in);
- sin_addr->sin_family = AF_INET;
- sin_addr->sin_addr = *(const struct in_addr *)addr;
-
- sin_mask->sin_len = sizeof(struct sockaddr_in);
- sin_mask->sin_addr.s_addr =
- htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
- sin_addr->sin_addr.s_addr &= sin_mask->sin_addr.s_addr;
+ aip->a_addr.in = *(const struct in_addr *)addr;
+ aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
+ aip->a_addr.ip &= aip->a_mask.ip;
+ aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
break;
case AF_INET6:
if (cidr > 128) cidr = 128;
root = sc->sc_aip6;
-
- sin6_addr = (struct sockaddr_in6 *)&aip->a_addr;
- sin6_mask = (struct sockaddr_in6 *)&aip->a_mask;
-
- sin6_addr->sin6_len = sizeof(struct sockaddr_in6);
- sin6_addr->sin6_family = AF_INET6;
- sin6_addr->sin6_addr = *(const struct in6_addr *)addr;
-
- sin6_mask->sin6_len = sizeof(struct sockaddr_in6);
- in6_prefixlen2mask(&sin6_mask->sin6_addr, cidr);
+ aip->a_addr.in6 = *(const struct in6_addr *)addr;
+ in6_prefixlen2mask(&aip->a_mask.in6, cidr);
for (i = 0; i < 4; i++)
- sin6_addr->sin6_addr.__u6_addr.__u6_addr32[i] &=
- sin6_mask->sin6_addr.__u6_addr.__u6_addr32[i];
+ aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
+ aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
break;
default:
free(aip, M_WG);
return (EAFNOSUPPORT);
}
- aip->a_peer = peer;
-
RADIX_NODE_HEAD_LOCK(root);
node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
-
if (node == aip->a_nodes) {
LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
peer->p_aips_num++;
- } else {
- need_free = true;
- aip = (struct wg_aip *) node;
+ } else if (!node)
+ node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
+ if (!node) {
+ free(aip, M_WG);
+ return (ENOMEM);
+ } else if (node != aip->a_nodes) {
+ free(aip, M_WG);
+ aip = (struct wg_aip *)node;
if (aip->a_peer != peer) {
LIST_REMOVE(aip, a_entry);
aip->a_peer->p_aips_num--;
@@ -585,40 +582,36 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void
}
}
RADIX_NODE_HEAD_UNLOCK(root);
- if (need_free)
- free(aip, M_WG);
return (ret);
}
static struct wg_peer *
-wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *addr)
+wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
{
struct radix_node_head *root;
struct radix_node *node;
struct wg_peer *peer;
- struct sockaddr_storage ss;
- struct sockaddr_in *sin = (struct sockaddr_in *)&ss;
- struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&ss;
+ struct aip_addr addr;
RADIX_NODE_HEAD_RLOCK_TRACKER;
switch (af) {
case AF_INET:
root = sc->sc_aip4;
- sin->sin_len = sizeof(struct sockaddr_in);
- sin->sin_addr = *(struct in_addr *)addr;
+ memcpy(&addr.in, a, sizeof(addr.in));
+ addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
break;
case AF_INET6:
root = sc->sc_aip6;
- sin6->sin6_len = sizeof(struct sockaddr_in6);
- sin6->sin6_addr = *(struct in6_addr *)addr;
+ memcpy(&addr.in6, a, sizeof(addr.in6));
+ addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
break;
default:
- panic("invalid wg_aip_lookup af");
+ return NULL;
}
RADIX_NODE_HEAD_RLOCK(root);
- node = root->rnh_matchaddr((struct sockaddr *)&ss, &root->rh);
- peer = node != NULL ? ((struct wg_aip *) node)->a_peer : NULL;
+ node = root->rnh_matchaddr(&addr, &root->rh);
+ peer = node != NULL ? ((struct wg_aip *)node)->a_peer : NULL;
RADIX_NODE_HEAD_RUNLOCK(root);
return (peer);
@@ -631,9 +624,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
- if (aip->a_addr.ss_family == AF_INET) {
+ if (aip->a_af == AF_INET) {
if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
- panic("art_delete failed to delete aip %p", aip);
+ panic("failed to delete aip %p", aip);
LIST_REMOVE(aip, a_entry);
peer->p_aips_num--;
free(aip, M_WG);
@@ -643,9 +636,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
- if (aip->a_addr.ss_family == AF_INET6) {
+ if (aip->a_af == AF_INET6) {
if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
- panic("art_delete failed to delete aip %p", aip);
+ panic("failed to delete aip %p", aip);
LIST_REMOVE(aip, a_entry);
peer->p_aips_num--;
free(aip, M_WG);
@@ -2357,16 +2350,12 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
err = ENOMEM;
goto err_aip;
}
- if (aip->a_addr.ss_family == AF_INET) {
- struct sockaddr_in *sin = (struct sockaddr_in *)&aip->a_addr;
- nvlist_add_binary(nvl_aip, "ipv4", &sin->sin_addr, sizeof(sin->sin_addr));
- sin = (struct sockaddr_in *)&aip->a_mask;
- nvlist_add_number(nvl_aip, "cidr", __builtin_popcount(sin->sin_addr.s_addr));
- } else if (aip->a_addr.ss_family == AF_INET6) {
- struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&aip->a_addr;
- nvlist_add_binary(nvl_aip, "ipv6", &sin6->sin6_addr, sizeof(sin6->sin6_addr));
- sin6 = (struct sockaddr_in6 *)&aip->a_mask;
- nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&sin6->sin6_addr, NULL));
+ if (aip->a_af == AF_INET) {
+ nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
+ nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
+ } else if (aip->a_af == AF_INET6) {
+ nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
+ nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
}
}
nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
@@ -2613,8 +2602,8 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params)
wg_queue_init(&sc->sc_decrypt_parallel, "decp");
/* TODO check rn_inithead return value */
- rn_inithead((void **)&sc->sc_aip4, offsetof(struct sockaddr_in, sin_addr)*NBBY);
- rn_inithead((void **)&sc->sc_aip6, offsetof(struct sockaddr_in6, sin6_addr)*NBBY);
+ rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY);
+ rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY);
RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
@@ -2792,6 +2781,16 @@ wg_prison_remove(void *obj, void *data __unused)
return (0);
}
+#ifdef SELFTESTS
+#include "selftest/allowedips.c"
+static void wg_run_selftests(void)
+{
+ wg_allowedips_selftest();
+}
+#else
+static inline void wg_run_selftests(void) { }
+#endif
+
static void
wg_module_init(void)
{
@@ -2804,6 +2803,8 @@ wg_module_init(void)
ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
NULL, NULL, NULL, NULL, 0, 0);
wg_osd_jail_slot = osd_jail_register(NULL, methods);
+
+ wg_run_selftests();
}
static void
diff --git a/src/selftest/allowedips.c b/src/selftest/allowedips.c
new file mode 100644
index 0000000..9678790
--- /dev/null
+++ b/src/selftest/allowedips.c
@@ -0,0 +1,608 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include <sys/cdefs.h>
+#include <sys/param.h>
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/queue.h>
+#include <sys/endian.h>
+#include <netinet/in.h>
+#include <net/radix.h>
+
+#ifdef WG_ALLOWEDIPS_RANDOMIZED_TEST
+enum {
+ NUM_PEERS = 2000,
+ NUM_RAND_ROUTES = 400,
+ NUM_MUTATED_ROUTES = 100,
+ NUM_QUERIES = NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30
+};
+
+struct horrible_allowedips {
+ LIST_HEAD(, horrible_allowedips_node) head;
+};
+
+struct horrible_allowedips_node {
+ LIST_ENTRY(horrible_allowedips_node) table;
+ struct aip_addr ip;
+ struct aip_addr mask;
+ uint8_t ip_version;
+ void *value;
+};
+
+static void horrible_allowedips_init(struct horrible_allowedips *table)
+{
+ LIST_INIT(&table->head);
+}
+
+static void horrible_allowedips_free(struct horrible_allowedips *table)
+{
+ struct horrible_allowedips_node *node, *temp_node;
+
+ LIST_FOREACH_SAFE(node, &table->head, table, temp_node) {
+ LIST_REMOVE(node, table);
+ free(node, M_WG);
+ }
+}
+
+static inline struct aip_addr horrible_cidr_to_mask(uint8_t cidr)
+{
+ struct aip_addr mask;
+
+ memset(&mask.in6, 0x00, 128 / 8);
+ memset(&mask.in6, 0xff, cidr / 8);
+ if (cidr % 32)
+ mask.ip6[cidr / 32] = (uint32_t)htonl(
+ (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
+ return mask;
+}
+
+static inline uint8_t horrible_mask_to_cidr(struct aip_addr subnet)
+{
+ return bitcount32(subnet.ip6[0]) + bitcount32(subnet.ip6[1]) +
+ bitcount32(subnet.ip6[2]) + bitcount32(subnet.ip6[3]);
+}
+
+static inline void
+horrible_mask_self(struct horrible_allowedips_node *node)
+{
+ if (node->ip_version == 4) {
+ node->ip.ip &= node->mask.ip;
+ } else if (node->ip_version == 6) {
+ node->ip.ip6[0] &= node->mask.ip6[0];
+ node->ip.ip6[1] &= node->mask.ip6[1];
+ node->ip.ip6[2] &= node->mask.ip6[2];
+ node->ip.ip6[3] &= node->mask.ip6[3];
+ }
+}
+
+static inline bool
+horrible_match_v4(const struct horrible_allowedips_node *node,
+ struct in_addr *ip)
+{
+ return (ip->s_addr & node->mask.ip) == node->ip.ip;
+}
+
+static inline bool
+horrible_match_v6(const struct horrible_allowedips_node *node,
+ struct in6_addr *ip)
+{
+ return (ip->__u6_addr.__u6_addr32[0] & node->mask.ip6[0]) ==
+ node->ip.ip6[0] &&
+ (ip->__u6_addr.__u6_addr32[1] & node->mask.ip6[1]) ==
+ node->ip.ip6[1] &&
+ (ip->__u6_addr.__u6_addr32[2] & node->mask.ip6[2]) ==
+ node->ip.ip6[2] &&
+ (ip->__u6_addr.__u6_addr32[3] & node->mask.ip6[3]) ==
+ node->ip.ip6[3];
+}
+
+static void
+horrible_insert_ordered(struct horrible_allowedips *table,
+ struct horrible_allowedips_node *node)
+{
+ struct horrible_allowedips_node *other = NULL, *where = NULL;
+ uint8_t my_cidr = horrible_mask_to_cidr(node->mask);
+
+ LIST_FOREACH(other, &table->head, table) {
+ if (!memcmp(&other->mask, &node->mask,
+ sizeof(struct aip_addr)) &&
+ !memcmp(&other->ip, &node->ip,
+ sizeof(struct aip_addr)) &&
+ other->ip_version == node->ip_version) {
+ other->value = node->value;
+ free(node, M_WG);
+ return;
+ }
+ where = other;
+ if (horrible_mask_to_cidr(other->mask) <= my_cidr)
+ break;
+ }
+ if (!other && !where)
+ LIST_INSERT_HEAD(&table->head, node, table);
+ else if (!other)
+ LIST_INSERT_AFTER(where, node, table);
+ else
+ LIST_INSERT_BEFORE(where, node, table);
+}
+
+static int
+horrible_allowedips_insert_v4(struct horrible_allowedips *table,
+ struct in_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_allowedips_node *node = malloc(sizeof(*node), M_WG, M_NOWAIT | M_ZERO);
+
+ if (!node)
+ return ENOMEM;
+ node->ip.in = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 4;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+
+static int
+horrible_allowedips_insert_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_allowedips_node *node = malloc(sizeof(*node), M_WG, M_NOWAIT | M_ZERO);
+
+ if (!node)
+ return ENOMEM;
+ node->ip.in6 = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 6;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+
+static void *
+horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
+ struct in_addr *ip)
+{
+ struct horrible_allowedips_node *node;
+ void *ret = NULL;
+
+ LIST_FOREACH(node, &table->head, table) {
+ if (node->ip_version != 4)
+ continue;
+ if (horrible_match_v4(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ }
+ return ret;
+}
+
+static void *
+horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip)
+{
+ struct horrible_allowedips_node *node;
+ void *ret = NULL;
+
+ LIST_FOREACH(node, &table->head, table) {
+ if (node->ip_version != 6)
+ continue;
+ if (horrible_match_v6(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ }
+ return ret;
+}
+
+static bool randomized_test(void)
+{
+ unsigned int i, j, k, mutate_amount, cidr;
+ uint8_t ip[16], mutate_mask[16], mutated[16];
+ struct wg_peer **peers, *peer;
+ struct horrible_allowedips h;
+ struct wg_softc sc = { 0 };
+ bool ret = false;
+
+ rn_inithead((void **)&sc.sc_aip4, offsetof(struct aip_addr, in) * NBBY);
+ rn_inithead((void **)&sc.sc_aip6, offsetof(struct aip_addr, in6) * NBBY);
+ RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip4);
+ RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip6);
+ horrible_allowedips_init(&h);
+
+ peers = mallocarray(NUM_PEERS, sizeof(*peers), M_WG, M_NOWAIT | M_ZERO);
+ if (!peers) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ for (i = 0; i < NUM_PEERS; ++i) {
+ peers[i] = malloc(sizeof(*peers[i]), M_WG, M_NOWAIT | M_ZERO);
+ if (!peers[i]) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ LIST_INIT(&peers[i]->p_aips);
+ peers[i]->p_aips_num = 0;
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ arc4random_buf(ip, 4);
+ cidr = arc4random_uniform(32) + 1;
+ peer = peers[arc4random_uniform(NUM_PEERS)];
+ if (wg_aip_add(&sc, peer, AF_INET, ip, cidr)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
+ cidr, peer)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 4);
+ arc4random_buf(mutate_mask, 4);
+ mutate_amount = arc4random_uniform(32);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ arc4random_uniform(256));
+ cidr = arc4random_uniform(32) + 1;
+ peer = peers[arc4random_uniform(NUM_PEERS)];
+ if (wg_aip_add(&sc, peer, AF_INET, mutated, cidr)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ if (horrible_allowedips_insert_v4(&h,
+ (struct in_addr *)mutated, cidr, peer)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ }
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ arc4random_buf(ip, 16);
+ cidr = arc4random_uniform(128) + 1;
+ peer = peers[arc4random_uniform(NUM_PEERS)];
+ if (wg_aip_add(&sc, peer, AF_INET6, ip, cidr)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
+ cidr, peer)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 16);
+ arc4random_buf(mutate_mask, 16);
+ mutate_amount = arc4random_uniform(128);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ arc4random_uniform(256));
+ cidr = arc4random_uniform(128) + 1;
+ peer = peers[arc4random_uniform(NUM_PEERS)];
+ if (wg_aip_add(&sc, peer, AF_INET6, mutated, cidr)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ if (horrible_allowedips_insert_v6(
+ &h, (struct in6_addr *)mutated, cidr,
+ peer)) {
+ printf("allowedips random self-test malloc: FAIL\n");
+ goto free;
+ }
+ }
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ arc4random_buf(ip, 4);
+ if (wg_aip_lookup(&sc, AF_INET, ip) !=
+ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
+ printf("allowedips random self-test: FAIL\n");
+ goto free;
+ }
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ arc4random_buf(ip, 16);
+ if (wg_aip_lookup(&sc, AF_INET6, ip) !=
+ horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
+ printf("allowedips random self-test: FAIL\n");
+ goto free;
+ }
+ }
+ ret = true;
+
+free:
+ horrible_allowedips_free(&h);
+ if (peers) {
+ for (i = 0; i < NUM_PEERS; ++i)
+ wg_aip_remove_all(&sc, peers[i]);
+ for (i = 0; i < NUM_PEERS; ++i)
+ free(peers[i], M_WG);
+ }
+ free(peers, M_WG);
+ return ret;
+}
+#endif
+
+static struct in_addr *ip4(uint8_t a, uint8_t b, uint8_t c, uint8_t d)
+{
+ static struct in_addr ip;
+ uint8_t *split = (uint8_t *)&ip;
+
+ split[0] = a;
+ split[1] = b;
+ split[2] = c;
+ split[3] = d;
+ return &ip;
+}
+
+static struct in6_addr *ip6(uint32_t a, uint32_t b, uint32_t c, uint32_t d)
+{
+ static struct in6_addr ip;
+ uint32_t *split = ip.__u6_addr.__u6_addr32;
+
+ split[0] = htobe32(a);
+ split[1] = htobe32(b);
+ split[2] = htobe32(c);
+ split[3] = htobe32(d);
+ return &ip;
+}
+
+static struct wg_peer *init_peer(void)
+{
+ struct wg_peer *peer = malloc(sizeof(*peer), M_WG, M_NOWAIT | M_ZERO);
+
+ if (!peer)
+ return NULL;
+ LIST_INIT(&peer->p_aips);
+ peer->p_aips_num = 0;
+ return peer;
+}
+
+#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) do { \
+ int _r = wg_aip_add(&sc, mem, (version) == 6 ? AF_INET6 : AF_INET, \
+ ip##version(ipa, ipb, ipc, ipd), cidr); \
+ if (_r) { \
+ printf("allowedips self-test insertion: FAIL (%d)\n", _r); \
+ success = false; \
+ } \
+ } while (0)
+
+#define maybe_fail() do { \
+ ++i; \
+ if (!_s) { \
+ printf("allowedips self-test %zu: FAIL\n", i); \
+ success = false; \
+ } \
+ } while (0)
+
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = wg_aip_lookup(&sc, (version) == 6 ? AF_INET6 : AF_INET, \
+ ip##version(ipa, ipb, ipc, ipd)) == (mem); \
+ maybe_fail(); \
+ } while (0)
+
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = wg_aip_lookup(&sc, (version) == 6 ? AF_INET6 : AF_INET, \
+ ip##version(ipa, ipb, ipc, ipd)) != (mem); \
+ maybe_fail(); \
+ } while (0)
+
+#define test_boolean(cond) do { \
+ bool _s = (cond); \
+ maybe_fail(); \
+ } while (0)
+
+#define free_all() do { \
+ if (a) wg_aip_remove_all(&sc, a); \
+ if (b) wg_aip_remove_all(&sc, b); \
+ if (c) wg_aip_remove_all(&sc, c); \
+ if (d) wg_aip_remove_all(&sc, d); \
+ if (e) wg_aip_remove_all(&sc, e); \
+ if (f) wg_aip_remove_all(&sc, f); \
+ if (g) wg_aip_remove_all(&sc, g); \
+ if (h) wg_aip_remove_all(&sc, h); \
+ } while (0)
+
+static bool wg_allowedips_selftest(void)
+{
+ bool found_a = false, found_b = false, found_c = false, found_d = false,
+ found_e = false, found_other = false;
+ struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(),
+ *d = init_peer(), *e = init_peer(), *f = init_peer(),
+ *g = init_peer(), *h = init_peer();
+ struct wg_softc sc = { 0 };
+ struct wg_aip *iter_node;
+ size_t i = 0, count = 0;
+ bool success = false;
+ struct in6_addr ip;
+ uint64_t part;
+
+ rn_inithead((void **)&sc.sc_aip4, offsetof(struct aip_addr, in) * NBBY);
+ rn_inithead((void **)&sc.sc_aip6, offsetof(struct aip_addr, in6) * NBBY);
+ RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip4);
+ RADIX_NODE_HEAD_LOCK_INIT(sc.sc_aip6);
+
+ if (!a || !b || !c || !d || !e || !f || !g || !h) {
+ printf("allowedips self-test malloc: FAIL\n");
+ goto free;
+ }
+
+ insert(4, a, 192, 168, 4, 0, 24);
+ insert(4, b, 192, 168, 4, 4, 32);
+ insert(4, c, 192, 168, 0, 0, 16);
+ insert(4, d, 192, 95, 5, 64, 27);
+ /* replaces previous entry, and maskself is required */
+ insert(4, c, 192, 95, 5, 65, 27);
+ insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
+ insert(4, e, 0, 0, 0, 0, 0);
+ insert(6, e, 0, 0, 0, 0, 0);
+ /* replaces previous entry */
+ insert(6, f, 0, 0, 0, 0, 0);
+ insert(6, g, 0x24046800, 0, 0, 0, 32);
+ /* maskself is required */
+ insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
+ insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
+ insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ insert(4, g, 64, 15, 112, 0, 20);
+ /* maskself is required */
+ insert(4, h, 64, 15, 123, 211, 25);
+ insert(4, a, 10, 0, 0, 0, 25);
+ insert(4, b, 10, 0, 0, 128, 25);
+ insert(4, a, 10, 1, 0, 0, 30);
+ insert(4, b, 10, 1, 0, 4, 30);
+ insert(4, c, 10, 1, 0, 8, 29);
+ insert(4, d, 10, 1, 0, 16, 29);
+
+ success = true;
+
+ test(4, a, 192, 168, 4, 20);
+ test(4, a, 192, 168, 4, 0);
+ test(4, b, 192, 168, 4, 4);
+ test(4, c, 192, 168, 200, 182);
+ test(4, c, 192, 95, 5, 68);
+ test(4, e, 192, 95, 5, 96);
+ test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
+ test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
+ test(6, f, 0x26075300, 0x60006b01, 0, 0);
+ test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
+ test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0, 0);
+ test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
+ test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
+ test(4, g, 64, 15, 116, 26);
+ test(4, g, 64, 15, 127, 3);
+ test(4, g, 64, 15, 123, 1);
+ test(4, h, 64, 15, 123, 128);
+ test(4, h, 64, 15, 123, 129);
+ test(4, a, 10, 0, 0, 52);
+ test(4, b, 10, 0, 0, 220);
+ test(4, a, 10, 1, 0, 2);
+ test(4, b, 10, 1, 0, 6);
+ test(4, c, 10, 1, 0, 10);
+ test(4, d, 10, 1, 0, 20);
+
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 64, 0, 0, 0, 32);
+ insert(4, a, 128, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 32);
+ insert(4, a, 255, 0, 0, 0, 32);
+ wg_aip_remove_all(&sc, a);
+ test_negative(4, a, 1, 0, 0, 0);
+ test_negative(4, a, 64, 0, 0, 0);
+ test_negative(4, a, 128, 0, 0, 0);
+ test_negative(4, a, 192, 0, 0, 0);
+ test_negative(4, a, 255, 0, 0, 0);
+
+ free_all();
+ insert(4, a, 192, 168, 0, 0, 16);
+ insert(4, a, 192, 168, 0, 0, 24);
+ wg_aip_remove_all(&sc, a);
+ test_negative(4, a, 192, 168, 0, 1);
+
+ for (i = 0; i < 128; ++i) {
+ part = htobe64(~(1LLU << (i % 64)));
+ memset(&ip, 0xff, 16);
+ memcpy((uint8_t *)&ip + (i < 64) * 8, &part, 8);
+ wg_aip_add(&sc, a, AF_INET6, &ip, 128);
+ }
+
+ free_all();
+ insert(4, a, 192, 95, 5, 93, 27);
+ insert(6, a, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(4, a, 10, 1, 0, 20, 29);
+ insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83);
+ insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21);
+ LIST_FOREACH(iter_node, &a->p_aips, a_entry) {
+ uint8_t cidr, *ip = iter_node->a_addr.bytes;
+ sa_family_t family = iter_node->a_af;
+ if (family == AF_INET)
+ cidr = bitcount32(iter_node->a_mask.ip);
+ else if (family == AF_INET6)
+ cidr = in6_mask2len(&iter_node->a_mask.in6, NULL);
+ else
+ continue;
+
+ count++;
+
+ if (cidr == 27 && family == AF_INET &&
+ !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
+ found_a = true;
+ else if (cidr == 128 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
+ sizeof(struct in6_addr)))
+ found_b = true;
+ else if (cidr == 29 && family == AF_INET &&
+ !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
+ found_c = true;
+ else if (cidr == 83 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
+ sizeof(struct in6_addr)))
+ found_d = true;
+ else if (cidr == 21 && family == AF_INET6 &&
+ !memcmp(ip, ip6(0x26075000, 0, 0, 0),
+ sizeof(struct in6_addr)))
+ found_e = true;
+ else
+ found_other = true;
+ }
+ test_boolean(count == 5);
+ test_boolean(found_a);
+ test_boolean(found_b);
+ test_boolean(found_c);
+ test_boolean(found_d);
+ test_boolean(found_e);
+ test_boolean(!found_other);
+
+#ifdef WG_ALLOWEDIPS_RANDOMIZED_TEST
+ if (success)
+ success = randomized_test();
+#endif
+
+ if (success)
+ printf("allowedips self-tests: pass\n");
+
+free:
+ free_all();
+ free(a, M_WG);
+ free(b, M_WG);
+ free(c, M_WG);
+ free(d, M_WG);
+ free(e, M_WG);
+ free(f, M_WG);
+ free(g, M_WG);
+ free(h, M_WG);
+
+ return success;
+}
+
+#undef test_negative
+#undef test
+#undef remove
+#undef insert
+#undef init_peer
+#undef free_all