aboutsummaryrefslogtreecommitdiffstats
path: root/src/if_wg.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/if_wg.c')
-rw-r--r--  src/if_wg.c  125
1 files changed, 86 insertions, 39 deletions
diff --git a/src/if_wg.c b/src/if_wg.c
index 1c56ccc..01888f9 100644
--- a/src/if_wg.c
+++ b/src/if_wg.c
@@ -464,9 +464,23 @@ static void
wg_peer_free_deferred(struct noise_remote *r)
{
struct wg_peer *peer = noise_remote_arg(r);
+
+	/* While there are no references remaining, we may still have
+	 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
+	 * needs to check the queue). We should wait for them and then free. */
+ GROUPTASK_DRAIN(&peer->p_recv);
+ GROUPTASK_DRAIN(&peer->p_send);
+ taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
+ taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
+
+ wg_queue_deinit(&peer->p_decrypt_serial);
+ wg_queue_deinit(&peer->p_encrypt_serial);
+ wg_queue_deinit(&peer->p_stage_queue);
+
counter_u64_free(peer->p_tx_bytes);
counter_u64_free(peer->p_rx_bytes);
rw_destroy(&peer->p_endpoint_lock);
+
free(peer, M_WG);
}
@@ -476,24 +490,18 @@ wg_peer_destroy(struct wg_peer *peer)
struct wg_softc *sc = peer->p_sc;
sx_assert(&sc->sc_lock, SX_XLOCKED);
+	/* Disable remote and timers. This will prevent any new handshakes
+	 * occurring. */
noise_remote_disable(peer->p_remote);
- wg_aip_remove_all(sc, peer);
-
- /* We disable all timers, so we can't call the following tasks. */
wg_timers_disable(peer);
- /* Ensure the tasks have finished running */
- GROUPTASK_DRAIN(&peer->p_recv);
- GROUPTASK_DRAIN(&peer->p_send);
-
- taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
- taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
-
- wg_queue_deinit(&peer->p_decrypt_serial);
- wg_queue_deinit(&peer->p_encrypt_serial);
- wg_queue_deinit(&peer->p_stage_queue);
+ /* Now we can remove all allowed IPs so no more packets will be routed
+ * to the peer. */
+ wg_aip_remove_all(sc, peer);
- /* Final cleanup */
+ /* Remove peer from the interface, then free. Some references may still
+ * exist to p_remote, so noise_remote_free will wait until they're all
+ * put to call wg_peer_free_deferred. */
sc->sc_peers_num--;
TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
@@ -624,7 +632,16 @@ wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
RADIX_NODE_HEAD_RLOCK(root);
node = root->rnh_matchaddr(&addr, &root->rh);
- peer = node != NULL ? ((struct wg_aip *)node)->a_peer : NULL;
+ if (node != NULL) {
+ peer = ((struct wg_aip *)node)->a_peer;
		/* If we have a remote, we should take a reference. The only
		 * case where we don't have a remote is in the allowedips
		 * selftest. */
+ if (peer->p_remote != NULL)
+ noise_remote_ref(peer->p_remote);
+ } else {
+ peer = NULL;
+ }
RADIX_NODE_HEAD_RUNLOCK(root);
return (peer);
@@ -945,6 +962,18 @@ wg_timers_enable(struct wg_peer *peer)
static void
wg_timers_disable(struct wg_peer *peer)
{
+ /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
+ * sure no new handshakes are created after the wait. This is because
+ * all callout_resets (scheduling the callout) are guarded by
+ * p_enabled. We can be sure all sections that read p_enabled and then
+ * optionally call callout_reset are finished as they are surrounded by
+ * NET_EPOCH_{ENTER,EXIT}.
+ *
+ * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
+ * not after), we stop all callouts leaving no callouts active.
+ *
+ * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
+ * performance impact is acceptable for the time being. */
ck_pr_store_bool(&peer->p_enabled, false);
NET_EPOCH_WAIT();
ck_pr_store_bool(&peer->p_need_another_keepalive, false);
@@ -1096,6 +1125,7 @@ wg_timers_run_send_initiation(struct wg_peer *peer, int is_retry)
static void
wg_timers_run_retry_handshake(void *_peer)
{
+ struct epoch_tracker et;
struct wg_peer *peer = _peer;
mtx_lock(&peer->p_handshake_mtx);
@@ -1117,25 +1147,32 @@ wg_timers_run_retry_handshake(void *_peer)
callout_stop(&peer->p_send_keepalive);
wg_queue_purge(&peer->p_stage_queue);
- if (!callout_pending(&peer->p_zero_key_material))
+ NET_EPOCH_ENTER(et);
+ if (ck_pr_load_bool(&peer->p_enabled) &&
+ !callout_pending(&peer->p_zero_key_material))
callout_reset(&peer->p_zero_key_material,
MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
wg_timers_run_zero_key_material, peer);
+ NET_EPOCH_EXIT(et);
}
}
static void
wg_timers_run_send_keepalive(void *_peer)
{
+ struct epoch_tracker et;
struct wg_peer *peer = _peer;
wg_send_keepalive(peer);
- if (ck_pr_load_bool(&peer->p_need_another_keepalive)) {
+ NET_EPOCH_ENTER(et);
+ if (ck_pr_load_bool(&peer->p_enabled) &&
+ ck_pr_load_bool(&peer->p_need_another_keepalive)) {
ck_pr_store_bool(&peer->p_need_another_keepalive, false);
callout_reset(&peer->p_send_keepalive,
MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
wg_timers_run_send_keepalive, peer);
}
+ NET_EPOCH_EXIT(et);
}
static void
@@ -1188,12 +1225,10 @@ static void
wg_send_initiation(struct wg_peer *peer)
{
struct wg_pkt_initiation pkt;
- struct epoch_tracker et;
- NET_EPOCH_ENTER(et);
if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
pkt.es, pkt.ets) != 0)
- goto out;
+ return;
DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
@@ -1202,20 +1237,16 @@ wg_send_initiation(struct wg_peer *peer)
sizeof(pkt)-sizeof(pkt.m));
wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
wg_timers_event_handshake_initiated(peer);
-out:
- NET_EPOCH_EXIT(et);
}
static void
wg_send_response(struct wg_peer *peer)
{
struct wg_pkt_response pkt;
- struct epoch_tracker et;
- NET_EPOCH_ENTER(et);
if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
pkt.ue, pkt.en) != 0)
- goto out;
+ return;
DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
@@ -1224,8 +1255,6 @@ wg_send_response(struct wg_peer *peer)
cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
sizeof(pkt)-sizeof(pkt.m));
wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
-out:
- NET_EPOCH_EXIT(et);
}
static void
@@ -1455,12 +1484,14 @@ wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
{
struct wg_pkt_data data;
struct wg_peer *peer;
+ struct noise_remote *remote;
struct mbuf *m;
uint32_t idx;
uint8_t zeroed[NOISE_AUTHTAG_LEN] = { 0 };
int pad;
- peer = noise_keypair_remote_arg(pkt->p_keypair);
+ remote = noise_keypair_remote(pkt->p_keypair);
+ peer = noise_remote_arg(remote);
m = pkt->p_mbuf;
/* Calculate what padding we need to add then limit it to the mtu of
@@ -1492,11 +1523,13 @@ wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
pkt->p_mbuf = m;
pkt->p_state = WG_PACKET_CRYPTED;
GROUPTASK_ENQUEUE(&peer->p_send);
+ noise_remote_put(remote);
return;
error:
pkt->p_mbuf = m;
pkt->p_state = WG_PACKET_DEAD;
GROUPTASK_ENQUEUE(&peer->p_send);
+ noise_remote_put(remote);
}
static void
@@ -1504,12 +1537,14 @@ wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
{
struct wg_pkt_data data;
struct wg_peer *peer, *allowed_peer;
+ struct noise_remote *remote;
struct mbuf *m;
struct ip *ip;
struct ip6_hdr *ip6;
int len;
- peer = noise_keypair_remote_arg(pkt->p_keypair);
+ remote = noise_keypair_remote(pkt->p_keypair);
+ peer = noise_remote_arg(remote);
m = pkt->p_mbuf;
/* Read index, nonce and then adjust to remove the header. */
@@ -1557,6 +1592,10 @@ wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
goto error;
}
+ /* We only want to compare the address, not dereference, so drop the ref. */
+ if (allowed_peer != NULL)
+ noise_remote_put(allowed_peer->p_remote);
+
if (__predict_false(peer != allowed_peer)) {
DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
goto error;
@@ -1567,11 +1606,13 @@ done:
pkt->p_mbuf = m;
pkt->p_state = WG_PACKET_CRYPTED;
GROUPTASK_ENQUEUE(&peer->p_recv);
+ noise_remote_put(remote);
return;
error:
pkt->p_mbuf = m;
pkt->p_state = WG_PACKET_DEAD;
GROUPTASK_ENQUEUE(&peer->p_recv);
+ noise_remote_put(remote);
}
static void
@@ -1914,6 +1955,7 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
{
const struct sockaddr_in *sin;
const struct sockaddr_in6 *sin6;
+ struct noise_remote *remote;
struct wg_pkt_data *data;
struct wg_packet *pkt;
struct wg_peer *peer;
@@ -1970,10 +2012,12 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
goto error;
- peer = noise_keypair_remote_arg(pkt->p_keypair);
+ remote = noise_keypair_remote(pkt->p_keypair);
+ peer = noise_remote_arg(remote);
if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
wg_decrypt_dispatch(sc);
+ noise_remote_put(remote);
} else {
goto error;
}
@@ -2023,7 +2067,6 @@ error:
static int
wg_transmit(struct ifnet *ifp, struct mbuf *m)
{
- struct epoch_tracker et;
struct wg_packet *pkt = m->m_pkthdr.PH_loc.ptr;
struct wg_softc *sc = ifp->if_softc;
struct wg_peer *peer;
@@ -2031,7 +2074,6 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m)
int rc = 0;
sa_family_t peer_af;
- NET_EPOCH_ENTER(et);
/* Work around lifetime issue in the ipv6 mld code. */
if (__predict_false((ifp->if_flags & IFF_DYING) || !sc)) {
rc = ENXIO;
@@ -2057,7 +2099,7 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m)
if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
DPRINTF(sc, "Packet looped");
rc = ELOOP;
- goto err;
+ goto err_peer;
}
peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
@@ -2065,17 +2107,17 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m)
DPRINTF(sc, "No valid endpoint has been configured or "
"discovered for peer %" PRIu64 "\n", peer->p_id);
rc = EHOSTUNREACH;
- goto err;
+ goto err_peer;
}
wg_queue_push_staged(&peer->p_stage_queue, pkt);
wg_peer_send_staged(peer);
- NET_EPOCH_EXIT(et);
+ noise_remote_put(peer->p_remote);
return (0);
+err_peer:
+ noise_remote_put(peer->p_remote);
err:
- NET_EPOCH_EXIT(et);
if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
- /* TODO: send ICMP unreachable? */
wg_packet_free(pkt);
return (rc);
}
@@ -2124,8 +2166,10 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
peer = noise_remote_arg(remote);
if (nvlist_exists_bool(nvl, "remove") &&
nvlist_get_bool(nvl, "remove")) {
- if (peer != NULL)
+ if (remote != NULL) {
wg_peer_destroy(peer);
+ noise_remote_put(remote);
+ }
return (0);
}
if (nvlist_exists_bool(nvl, "replace-allowedips") &&
@@ -2205,11 +2249,14 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
if (sc->sc_ifp->if_link_state == LINK_STATE_UP)
wg_timers_enable(peer);
}
+ if (remote != NULL)
+ noise_remote_put(remote);
return (0);
-
out:
if (need_insert) /* If we fail, only destroy if it was new. */
wg_peer_destroy(peer);
+ if (remote != NULL)
+ noise_remote_put(remote);
return (err);
}