diff options
Diffstat (limited to 'src/if_wg.c')
-rw-r--r-- | src/if_wg.c | 125 |
1 files changed, 86 insertions, 39 deletions
diff --git a/src/if_wg.c b/src/if_wg.c index 1c56ccc..01888f9 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -464,9 +464,23 @@ static void wg_peer_free_deferred(struct noise_remote *r) { struct wg_peer *peer = noise_remote_arg(r); + + /* While there are no references remaining, we may still have + * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out} + * needs to check the queue. We should wait for them and then free. */ + GROUPTASK_DRAIN(&peer->p_recv); + GROUPTASK_DRAIN(&peer->p_send); + taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv); + taskqgroup_detach(qgroup_wg_tqg, &peer->p_send); + + wg_queue_deinit(&peer->p_decrypt_serial); + wg_queue_deinit(&peer->p_encrypt_serial); + wg_queue_deinit(&peer->p_stage_queue); + counter_u64_free(peer->p_tx_bytes); counter_u64_free(peer->p_rx_bytes); rw_destroy(&peer->p_endpoint_lock); + free(peer, M_WG); } @@ -476,24 +490,18 @@ wg_peer_destroy(struct wg_peer *peer) struct wg_softc *sc = peer->p_sc; sx_assert(&sc->sc_lock, SX_XLOCKED); + /* Disable remote and timers. This will prevent any new handshakes + * occuring. */ noise_remote_disable(peer->p_remote); - wg_aip_remove_all(sc, peer); - - /* We disable all timers, so we can't call the following tasks. */ wg_timers_disable(peer); - /* Ensure the tasks have finished running */ - GROUPTASK_DRAIN(&peer->p_recv); - GROUPTASK_DRAIN(&peer->p_send); - - taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv); - taskqgroup_detach(qgroup_wg_tqg, &peer->p_send); - - wg_queue_deinit(&peer->p_decrypt_serial); - wg_queue_deinit(&peer->p_encrypt_serial); - wg_queue_deinit(&peer->p_stage_queue); + /* Now we can remove all allowed IPs so no more packets will be routed + * to the peer. */ + wg_aip_remove_all(sc, peer); - /* Final cleanup */ + /* Remove peer from the interface, then free. Some references may still + * exist to p_remote, so noise_remote_free will wait until they're all + * put to call wg_peer_free_deferred. */ sc->sc_peers_num--; TAILQ_REMOVE(&sc->sc_peers, peer, p_entry); DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id); @@ -624,7 +632,16 @@ wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a) RADIX_NODE_HEAD_RLOCK(root); node = root->rnh_matchaddr(&addr, &root->rh); - peer = node != NULL ? ((struct wg_aip *)node)->a_peer : NULL; + if (node != NULL) { + peer = ((struct wg_aip *)node)->a_peer; + /* If we have a remote, we should take a reference. The only + * cases where we don't have a remote is in the allowedips + * selftest. */ + if (peer->p_remote != NULL) + noise_remote_ref(peer->p_remote); + } else { + peer = NULL; + } RADIX_NODE_HEAD_RUNLOCK(root); return (peer); @@ -945,6 +962,18 @@ wg_timers_enable(struct wg_peer *peer) static void wg_timers_disable(struct wg_peer *peer) { + /* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be + * sure no new handshakes are created after the wait. This is because + * all callout_resets (scheduling the callout) are guarded by + * p_enabled. We can be sure all sections that read p_enabled and then + * optionally call callout_reset are finished as they are surrounded by + * NET_EPOCH_{ENTER,EXIT}. + * + * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but + * not after), we stop all callouts leaving no callouts active. + * + * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the + * performance impact is acceptable for the time being. */ ck_pr_store_bool(&peer->p_enabled, false); NET_EPOCH_WAIT(); ck_pr_store_bool(&peer->p_need_another_keepalive, false); @@ -1096,6 +1125,7 @@ wg_timers_run_send_initiation(struct wg_peer *peer, int is_retry) static void wg_timers_run_retry_handshake(void *_peer) { + struct epoch_tracker et; struct wg_peer *peer = _peer; mtx_lock(&peer->p_handshake_mtx); @@ -1117,25 +1147,32 @@ wg_timers_run_retry_handshake(void *_peer) callout_stop(&peer->p_send_keepalive); wg_queue_purge(&peer->p_stage_queue); - if (!callout_pending(&peer->p_zero_key_material)) + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled) && + !callout_pending(&peer->p_zero_key_material)) callout_reset(&peer->p_zero_key_material, MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000), wg_timers_run_zero_key_material, peer); + NET_EPOCH_EXIT(et); } } static void wg_timers_run_send_keepalive(void *_peer) { + struct epoch_tracker et; struct wg_peer *peer = _peer; wg_send_keepalive(peer); - if (ck_pr_load_bool(&peer->p_need_another_keepalive)) { + NET_EPOCH_ENTER(et); + if (ck_pr_load_bool(&peer->p_enabled) && + ck_pr_load_bool(&peer->p_need_another_keepalive)) { ck_pr_store_bool(&peer->p_need_another_keepalive, false); callout_reset(&peer->p_send_keepalive, MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000), wg_timers_run_send_keepalive, peer); } + NET_EPOCH_EXIT(et); } static void @@ -1188,12 +1225,10 @@ static void wg_send_initiation(struct wg_peer *peer) { struct wg_pkt_initiation pkt; - struct epoch_tracker et; - NET_EPOCH_ENTER(et); if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, pkt.ets) != 0) - goto out; + return; DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id); @@ -1202,20 +1237,16 @@ wg_send_initiation(struct wg_peer *peer) sizeof(pkt)-sizeof(pkt.m)); wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt)); wg_timers_event_handshake_initiated(peer); -out: - NET_EPOCH_EXIT(et); } static void wg_send_response(struct wg_peer *peer) { struct wg_pkt_response pkt; - struct epoch_tracker et; - NET_EPOCH_ENTER(et); if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx, pkt.ue, pkt.en) != 0) - goto out; + return; DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id); @@ -1224,8 +1255,6 @@ wg_send_response(struct wg_peer *peer) cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt)); -out: - NET_EPOCH_EXIT(et); } static void @@ -1455,12 +1484,14 @@ wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt) { struct wg_pkt_data data; struct wg_peer *peer; + struct noise_remote *remote; struct mbuf *m; uint32_t idx; uint8_t zeroed[NOISE_AUTHTAG_LEN] = { 0 }; int pad; - peer = noise_keypair_remote_arg(pkt->p_keypair); + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); m = pkt->p_mbuf; /* Calculate what padding we need to add then limit it to the mtu of @@ -1492,11 +1523,13 @@ wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt) pkt->p_mbuf = m; pkt->p_state = WG_PACKET_CRYPTED; GROUPTASK_ENQUEUE(&peer->p_send); + noise_remote_put(remote); return; error: pkt->p_mbuf = m; pkt->p_state = WG_PACKET_DEAD; GROUPTASK_ENQUEUE(&peer->p_send); + noise_remote_put(remote); } static void @@ -1504,12 +1537,14 @@ wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt) { struct wg_pkt_data data; struct wg_peer *peer, *allowed_peer; + struct noise_remote *remote; struct mbuf *m; struct ip *ip; struct ip6_hdr *ip6; int len; - peer = noise_keypair_remote_arg(pkt->p_keypair); + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); m = pkt->p_mbuf; /* Read index, nonce and then adjust to remove the header. */ @@ -1557,6 +1592,10 @@ wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt) goto error; } + /* We only want to compare the address, not dereference, so drop the ref. */ + if (allowed_peer != NULL) + noise_remote_put(allowed_peer->p_remote); + if (__predict_false(peer != allowed_peer)) { DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id); goto error; @@ -1567,11 +1606,13 @@ done: pkt->p_mbuf = m; pkt->p_state = WG_PACKET_CRYPTED; GROUPTASK_ENQUEUE(&peer->p_recv); + noise_remote_put(remote); return; error: pkt->p_mbuf = m; pkt->p_state = WG_PACKET_DEAD; GROUPTASK_ENQUEUE(&peer->p_recv); + noise_remote_put(remote); } static void @@ -1914,6 +1955,7 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, { const struct sockaddr_in *sin; const struct sockaddr_in6 *sin6; + struct noise_remote *remote; struct wg_pkt_data *data; struct wg_packet *pkt; struct wg_peer *peer; @@ -1970,10 +2012,12 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL) goto error; - peer = noise_keypair_remote_arg(pkt->p_keypair); + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0) if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); wg_decrypt_dispatch(sc); + noise_remote_put(remote); } else { goto error; } @@ -2023,7 +2067,6 @@ error: static int wg_transmit(struct ifnet *ifp, struct mbuf *m) { - struct epoch_tracker et; struct wg_packet *pkt = m->m_pkthdr.PH_loc.ptr; struct wg_softc *sc = ifp->if_softc; struct wg_peer *peer; @@ -2031,7 +2074,6 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m) int rc = 0; sa_family_t peer_af; - NET_EPOCH_ENTER(et); /* Work around lifetime issue in the ipv6 mld code. */ if (__predict_false((ifp->if_flags & IFF_DYING) || !sc)) { rc = ENXIO; @@ -2057,7 +2099,7 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m) if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) { DPRINTF(sc, "Packet looped"); rc = ELOOP; - goto err; + goto err_peer; } peer_af = peer->p_endpoint.e_remote.r_sa.sa_family; @@ -2065,17 +2107,17 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m) DPRINTF(sc, "No valid endpoint has been configured or " "discovered for peer %" PRIu64 "\n", peer->p_id); rc = EHOSTUNREACH; - goto err; + goto err_peer; } wg_queue_push_staged(&peer->p_stage_queue, pkt); wg_peer_send_staged(peer); - NET_EPOCH_EXIT(et); + noise_remote_put(peer->p_remote); return (0); +err_peer: + noise_remote_put(peer->p_remote); err: - NET_EPOCH_EXIT(et); if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1); - /* TODO: send ICMP unreachable? */ wg_packet_free(pkt); return (rc); } @@ -2124,8 +2166,10 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) peer = noise_remote_arg(remote); if (nvlist_exists_bool(nvl, "remove") && nvlist_get_bool(nvl, "remove")) { - if (peer != NULL) + if (remote != NULL) { wg_peer_destroy(peer); + noise_remote_put(remote); + } return (0); } if (nvlist_exists_bool(nvl, "replace-allowedips") && @@ -2205,11 +2249,14 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) if (sc->sc_ifp->if_link_state == LINK_STATE_UP) wg_timers_enable(peer); } + if (remote != NULL) + noise_remote_put(remote); return (0); - out: if (need_insert) /* If we fail, only destroy if it was new. */ wg_peer_destroy(peer); + if (remote != NULL) + noise_remote_put(remote); return (err); } |