aboutsummaryrefslogtreecommitdiffstats
path: root/src/if_wg.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/if_wg.c')
-rw-r--r--src/if_wg.c129
1 files changed, 65 insertions, 64 deletions
diff --git a/src/if_wg.c b/src/if_wg.c
index fee26ae..27e94f0 100644
--- a/src/if_wg.c
+++ b/src/if_wg.c
@@ -151,12 +151,24 @@ struct wg_endpoint {
} e_local;
};
+struct aip_addr {
+ uint8_t length;
+ union {
+ uint8_t bytes[16];
+ uint32_t ip;
+ uint32_t ip6[4];
+ struct in_addr in;
+ struct in6_addr in6;
+ };
+};
+
struct wg_aip {
struct radix_node a_nodes[2];
LIST_ENTRY(wg_aip) a_entry;
- struct sockaddr_storage a_addr;
- struct sockaddr_storage a_mask;
+ struct aip_addr a_addr;
+ struct aip_addr a_mask;
struct wg_peer *a_peer;
+ sa_family_t a_af;
};
struct wg_packet {
@@ -518,64 +530,49 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void
struct radix_node_head *root;
struct radix_node *node;
struct wg_aip *aip;
- struct sockaddr_in *sin_addr, *sin_mask;
- struct sockaddr_in6 *sin6_addr, *sin6_mask;
- bool need_free = false;
int i, ret = 0;
if ((aip = malloc(sizeof(*aip), M_WG, M_NOWAIT | M_ZERO)) == NULL)
return (ENOBUFS);
+ aip->a_peer = peer;
+ aip->a_af = af;
switch (af) {
case AF_INET:
if (cidr > 32) cidr = 32;
root = sc->sc_aip4;
-
- sin_addr = (struct sockaddr_in *)&aip->a_addr;
- sin_mask = (struct sockaddr_in *)&aip->a_mask;
-
- sin_addr->sin_len = sizeof(struct sockaddr_in);
- sin_addr->sin_family = AF_INET;
- sin_addr->sin_addr = *(const struct in_addr *)addr;
-
- sin_mask->sin_len = sizeof(struct sockaddr_in);
- sin_mask->sin_addr.s_addr =
- htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
- sin_addr->sin_addr.s_addr &= sin_mask->sin_addr.s_addr;
+ aip->a_addr.in = *(const struct in_addr *)addr;
+ aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
+ aip->a_addr.ip &= aip->a_mask.ip;
+ aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
break;
case AF_INET6:
if (cidr > 128) cidr = 128;
root = sc->sc_aip6;
-
- sin6_addr = (struct sockaddr_in6 *)&aip->a_addr;
- sin6_mask = (struct sockaddr_in6 *)&aip->a_mask;
-
- sin6_addr->sin6_len = sizeof(struct sockaddr_in6);
- sin6_addr->sin6_family = AF_INET6;
- sin6_addr->sin6_addr = *(const struct in6_addr *)addr;
-
- sin6_mask->sin6_len = sizeof(struct sockaddr_in6);
- in6_prefixlen2mask(&sin6_mask->sin6_addr, cidr);
+ aip->a_addr.in6 = *(const struct in6_addr *)addr;
+ in6_prefixlen2mask(&aip->a_mask.in6, cidr);
for (i = 0; i < 4; i++)
- sin6_addr->sin6_addr.__u6_addr.__u6_addr32[i] &=
- sin6_mask->sin6_addr.__u6_addr.__u6_addr32[i];
+ aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
+ aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
break;
default:
free(aip, M_WG);
return (EAFNOSUPPORT);
}
- aip->a_peer = peer;
-
RADIX_NODE_HEAD_LOCK(root);
node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
-
if (node == aip->a_nodes) {
LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
peer->p_aips_num++;
- } else {
- need_free = true;
- aip = (struct wg_aip *) node;
+ } else if (!node)
+ node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
+ if (!node) {
+ free(aip, M_WG);
+ return (ENOMEM);
+ } else if (node != aip->a_nodes) {
+ free(aip, M_WG);
+ aip = (struct wg_aip *)node;
if (aip->a_peer != peer) {
LIST_REMOVE(aip, a_entry);
aip->a_peer->p_aips_num--;
@@ -585,40 +582,36 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void
}
}
RADIX_NODE_HEAD_UNLOCK(root);
- if (need_free)
- free(aip, M_WG);
return (ret);
}
static struct wg_peer *
-wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *addr)
+wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
{
struct radix_node_head *root;
struct radix_node *node;
struct wg_peer *peer;
- struct sockaddr_storage ss;
- struct sockaddr_in *sin = (struct sockaddr_in *)&ss;
- struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&ss;
+ struct aip_addr addr;
RADIX_NODE_HEAD_RLOCK_TRACKER;
switch (af) {
case AF_INET:
root = sc->sc_aip4;
- sin->sin_len = sizeof(struct sockaddr_in);
- sin->sin_addr = *(struct in_addr *)addr;
+ memcpy(&addr.in, a, sizeof(addr.in));
+ addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
break;
case AF_INET6:
root = sc->sc_aip6;
- sin6->sin6_len = sizeof(struct sockaddr_in6);
- sin6->sin6_addr = *(struct in6_addr *)addr;
+ memcpy(&addr.in6, a, sizeof(addr.in6));
+ addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
break;
default:
- panic("invalid wg_aip_lookup af");
+ return NULL;
}
RADIX_NODE_HEAD_RLOCK(root);
- node = root->rnh_matchaddr((struct sockaddr *)&ss, &root->rh);
- peer = node != NULL ? ((struct wg_aip *) node)->a_peer : NULL;
+ node = root->rnh_matchaddr(&addr, &root->rh);
+ peer = node != NULL ? ((struct wg_aip *)node)->a_peer : NULL;
RADIX_NODE_HEAD_RUNLOCK(root);
return (peer);
@@ -631,9 +624,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
- if (aip->a_addr.ss_family == AF_INET) {
+ if (aip->a_af == AF_INET) {
if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
- panic("art_delete failed to delete aip %p", aip);
+ panic("failed to delete aip %p", aip);
LIST_REMOVE(aip, a_entry);
peer->p_aips_num--;
free(aip, M_WG);
@@ -643,9 +636,9 @@ wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
- if (aip->a_addr.ss_family == AF_INET6) {
+ if (aip->a_af == AF_INET6) {
if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
- panic("art_delete failed to delete aip %p", aip);
+ panic("failed to delete aip %p", aip);
LIST_REMOVE(aip, a_entry);
peer->p_aips_num--;
free(aip, M_WG);
@@ -2357,16 +2350,12 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
err = ENOMEM;
goto err_aip;
}
- if (aip->a_addr.ss_family == AF_INET) {
- struct sockaddr_in *sin = (struct sockaddr_in *)&aip->a_addr;
- nvlist_add_binary(nvl_aip, "ipv4", &sin->sin_addr, sizeof(sin->sin_addr));
- sin = (struct sockaddr_in *)&aip->a_mask;
- nvlist_add_number(nvl_aip, "cidr", __builtin_popcount(sin->sin_addr.s_addr));
- } else if (aip->a_addr.ss_family == AF_INET6) {
- struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&aip->a_addr;
- nvlist_add_binary(nvl_aip, "ipv6", &sin6->sin6_addr, sizeof(sin6->sin6_addr));
- sin6 = (struct sockaddr_in6 *)&aip->a_mask;
- nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&sin6->sin6_addr, NULL));
+ if (aip->a_af == AF_INET) {
+ nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
+ nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
+ } else if (aip->a_af == AF_INET6) {
+ nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
+ nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
}
}
nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
@@ -2613,8 +2602,8 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params)
wg_queue_init(&sc->sc_decrypt_parallel, "decp");
/* TODO check rn_inithead return value */
- rn_inithead((void **)&sc->sc_aip4, offsetof(struct sockaddr_in, sin_addr)*NBBY);
- rn_inithead((void **)&sc->sc_aip6, offsetof(struct sockaddr_in6, sin6_addr)*NBBY);
+ rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY);
+ rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY);
RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
@@ -2792,6 +2781,16 @@ wg_prison_remove(void *obj, void *data __unused)
return (0);
}
+#ifdef SELFTESTS
+#include "selftest/allowedips.c"
+static void wg_run_selftests(void)
+{
+ wg_allowedips_selftest();
+}
+#else
+static inline void wg_run_selftests(void) { }
+#endif
+
static void
wg_module_init(void)
{
@@ -2804,6 +2803,8 @@ wg_module_init(void)
ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
NULL, NULL, NULL, NULL, 0, 0);
wg_osd_jail_slot = osd_jail_register(NULL, methods);
+
+ wg_run_selftests();
}
static void