diff options
Diffstat (limited to 'sys/net/if_wg.c')
-rw-r--r-- | sys/net/if_wg.c | 562 |
1 files changed, 166 insertions, 396 deletions
diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c index 807d08a7b6e..dddce93f54f 100644 --- a/sys/net/if_wg.c +++ b/sys/net/if_wg.c @@ -49,8 +49,6 @@ #include <netinet/udp.h> #include <netinet/in_pcb.h> -#include <crypto/siphash.h> - #define DEFAULT_MTU 1420 #define MAX_STAGED_PKT 128 @@ -59,11 +57,6 @@ #define MAX_QUEUED_HANDSHAKES 4096 -#define HASHTABLE_PEER_SIZE (1 << 11) -#define HASHTABLE_INDEX_SIZE (1 << 13) -#define MAX_PEERS_PER_IFACE (1 << 20) - -#define REKEY_TIMEOUT 5 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */ #define KEEPALIVE_TIMEOUT 10 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT) @@ -135,13 +128,6 @@ struct wg_endpoint { } e_local; }; -struct wg_index { - LIST_ENTRY(wg_index) i_entry; - SLIST_ENTRY(wg_index) i_unused_entry; - uint32_t i_key; - struct noise_remote *i_value; -}; - struct wg_timers { /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */ struct rwlock t_lock; @@ -156,7 +142,6 @@ struct wg_timers { struct timeout t_persistent_keepalive; struct mutex t_handshake_mtx; - struct timespec t_handshake_last_sent; /* nanouptime */ struct timespec t_handshake_complete; /* nanotime */ int t_handshake_retries; }; @@ -177,7 +162,8 @@ struct wg_packet { SIMPLEQ_ENTRY(wg_packet) p_serial; SIMPLEQ_ENTRY(wg_packet) p_parallel; struct wg_endpoint p_endpoint; - struct wg_peer *p_peer; + struct noise_keypair *p_keypair; + uint64_t p_nonce; struct mbuf *p_mbuf; int p_mtu; enum wg_ring_state { @@ -194,12 +180,11 @@ struct wg_queue { }; struct wg_peer { - LIST_ENTRY(wg_peer) p_pubkey_entry; - TAILQ_ENTRY(wg_peer) p_seq_entry; + TAILQ_ENTRY(wg_peer) p_entry; uint64_t p_id; struct wg_softc *p_sc; - struct noise_remote p_remote; + struct noise_remote *p_remote; struct cookie_maker p_cookie; struct wg_timers p_timers; @@ -220,9 +205,6 @@ struct wg_peer { struct wg_queue p_encap_serial; struct wg_queue p_decap_serial; - SLIST_HEAD(,wg_index) p_unused_index; - struct wg_index p_index[3]; - LIST_HEAD(,wg_aip) p_aip; SLIST_ENTRY(wg_peer) p_start_list; @@ -231,13 +213,14 @@ struct wg_peer { struct wg_softc { struct ifnet sc_if; - SIPHASH_KEY sc_secret; struct rwlock sc_lock; - struct noise_local sc_local; + struct noise_local *sc_local; struct cookie_checker sc_cookie; in_port_t sc_udp_port; int sc_udp_rtable; + TAILQ_HEAD(,wg_peer) sc_peers; + size_t sc_peer_num; struct rwlock sc_so_lock; struct socket *sc_so4; @@ -251,16 +234,6 @@ struct wg_softc { struct art_root *sc_aip6; #endif - struct rwlock sc_peer_lock; - size_t sc_peer_num; - LIST_HEAD(,wg_peer) *sc_peer; - TAILQ_HEAD(,wg_peer) sc_peer_seq; - u_long sc_peer_mask; - - struct mutex sc_index_mtx; - LIST_HEAD(,wg_index) *sc_index; - u_long sc_index_mask; - struct task sc_handshake; struct mbuf_queue sc_handshake_queue; @@ -273,8 +246,6 @@ struct wg_softc { struct wg_peer * wg_peer_create(struct wg_softc *, uint8_t[WG_KEY_SIZE], uint8_t[WG_KEY_SIZE]); -struct wg_peer * - wg_peer_lookup(struct wg_softc *, const uint8_t[WG_KEY_SIZE]); void wg_peer_destroy(struct wg_peer *); void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *); void wg_peer_set_sockaddr(struct wg_peer *, struct sockaddr *); @@ -319,7 +290,6 @@ void wg_timers_event_handshake_responded(struct wg_timers *); void wg_timers_event_handshake_complete(struct wg_timers *); void wg_timers_event_session_derived(struct wg_timers *); void wg_timers_event_want_initiation(struct wg_timers *); -void wg_timers_event_reset_handshake_last_sent(struct wg_timers *); void wg_timers_run_send_initiation(void *, int); void wg_timers_run_retry_handshake(void *); @@ -354,14 +324,6 @@ struct wg_packet * struct wg_packet * wg_queue_parallel_dequeue(struct wg_queue *); -struct noise_remote * - wg_remote_get(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]); -uint32_t - wg_index_set(void *, struct noise_remote *); -struct noise_remote * - wg_index_get(void *, uint32_t); -void wg_index_drop(void *, uint32_t); - struct mbuf * wg_input(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *, int); @@ -398,21 +360,15 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE], uint8_t psk[WG_KEY_SIZE]) { struct wg_peer *peer; - uint64_t idx; rw_assert_wrlock(&sc->sc_lock); - if (sc->sc_peer_num >= MAX_PEERS_PER_IFACE) - return NULL; - if ((peer = pool_get(&wg_peer_pool, PR_NOWAIT)) == NULL) return NULL; peer->p_id = peer_counter++; peer->p_sc = sc; - noise_remote_init(&peer->p_remote, public, &sc->sc_local); - noise_remote_set_psk(&peer->p_remote, psk); cookie_maker_init(&peer->p_cookie, public); wg_timers_init(&peer->p_timers); @@ -433,51 +389,28 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE], wg_queue_init(&peer->p_encap_serial); wg_queue_init(&peer->p_decap_serial); - SLIST_INIT(&peer->p_unused_index); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2], - i_unused_entry); - LIST_INIT(&peer->p_aip); peer->p_start_onlist = 0; - idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE); - idx &= sc->sc_peer_mask; + if ((peer->p_remote = noise_remote_alloc(sc->sc_local, peer, public, psk)) == NULL) { + pool_put(&wg_peer_pool, peer); + return NULL; + } - rw_enter_write(&sc->sc_peer_lock); - LIST_INSERT_HEAD(&sc->sc_peer[idx], peer, p_pubkey_entry); - TAILQ_INSERT_TAIL(&sc->sc_peer_seq, peer, p_seq_entry); + DPRINTF(sc, "Peer %llu created\n", peer->p_id); + TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry); sc->sc_peer_num++; - rw_exit_write(&sc->sc_peer_lock); - DPRINTF(sc, "Peer %llu created\n", peer->p_id); return peer; } -struct wg_peer * -wg_peer_lookup(struct wg_softc *sc, const uint8_t public[WG_KEY_SIZE]) +void +wg_peer_free(struct noise_remote *r) { - uint8_t peer_key[WG_KEY_SIZE]; - struct wg_peer *peer; - uint64_t idx; - - idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE); - idx &= sc->sc_peer_mask; - - rw_enter_read(&sc->sc_peer_lock); - LIST_FOREACH(peer, &sc->sc_peer[idx], p_pubkey_entry) { - noise_remote_keys(&peer->p_remote, peer_key, NULL); - if (timingsafe_bcmp(peer_key, public, WG_KEY_SIZE) == 0) - goto done; - } - peer = NULL; -done: - rw_exit_read(&sc->sc_peer_lock); - return peer; + struct wg_peer *peer; + peer = noise_remote_arg(r); + pool_put(&wg_peer_pool, peer); } void @@ -488,46 +421,16 @@ wg_peer_destroy(struct wg_peer *peer) rw_assert_wrlock(&sc->sc_lock); - /* - * Remove peer from the pubkey hashtable and disable all timeouts. - * After this, and flushing wg_handshake_taskq, then no more handshakes - * can be started. - */ - rw_enter_write(&sc->sc_peer_lock); - LIST_REMOVE(peer, p_pubkey_entry); - TAILQ_REMOVE(&sc->sc_peer_seq, peer, p_seq_entry); + TAILQ_REMOVE(&sc->sc_peers, peer, p_entry); sc->sc_peer_num--; - rw_exit_write(&sc->sc_peer_lock); - - wg_timers_disable(&peer->p_timers); - - taskq_barrier(wg_handshake_taskq); - /* - * Now we drop all allowed ips, to drop all outgoing packets to the - * peer. Then drop all the indexes to drop all incoming packets to the - * peer. Then we can flush if_snd, wg_crypt_taskq and then nettq to - * ensure no more references to the peer exist. - */ LIST_FOREACH_SAFE(aip, &peer->p_aip, a_entry, taip) wg_aip_remove(sc, peer, &aip->a_data); - noise_remote_clear(&peer->p_remote); - - NET_LOCK(); - while (!ifq_empty(&sc->sc_if.if_snd)) { - NET_UNLOCK(); - tsleep_nsec(sc, PWAIT, "wg_ifq", 1000); - NET_LOCK(); - } - NET_UNLOCK(); - - taskq_barrier(wg_crypt_taskq); - taskq_barrier(net_tq(sc->sc_if.if_index)); + wg_timers_disable(&peer->p_timers); + noise_remote_free(peer->p_remote, wg_peer_free); DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id); - explicit_bzero(peer, sizeof(*peer)); - pool_put(&wg_peer_pool, peer); } void @@ -917,8 +820,6 @@ wg_tag_get(struct mbuf *m) * tx: response, rx: response * wg_timers_event_want_initiation: * tx: data failed, old keys expiring - * wg_timers_event_reset_handshake_last_sent: - * anytime we may immediately want a new handshake */ void wg_timers_init(struct wg_timers *t) @@ -987,28 +888,6 @@ wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time) mtx_leave(&t->t_handshake_mtx); } -int -wg_timers_expired_handshake_last_sent(struct wg_timers *t) -{ - struct timespec uptime; - struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 }; - - getnanouptime(&uptime); - timespecadd(&t->t_handshake_last_sent, &expire, &expire); - return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; -} - -int -wg_timers_check_handshake_last_sent(struct wg_timers *t) -{ - int ret; - mtx_enter(&t->t_handshake_mtx); - if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT) - getnanouptime(&t->t_handshake_last_sent); - mtx_leave(&t->t_handshake_mtx); - return ret; -} - void wg_timers_event_data_sent(struct wg_timers *t) { @@ -1070,14 +949,6 @@ wg_timers_event_handshake_initiated(struct wg_timers *t) } void -wg_timers_event_handshake_responded(struct wg_timers *t) -{ - mtx_enter(&t->t_handshake_mtx); - getnanouptime(&t->t_handshake_last_sent); - mtx_leave(&t->t_handshake_mtx); -} - -void wg_timers_event_handshake_complete(struct wg_timers *t) { rw_enter_read(&t->t_lock); @@ -1111,21 +982,13 @@ wg_timers_event_want_initiation(struct wg_timers *t) } void -wg_timers_event_reset_handshake_last_sent(struct wg_timers *t) -{ - mtx_enter(&t->t_handshake_mtx); - t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1); - mtx_leave(&t->t_handshake_mtx); -} - -void wg_timers_run_send_initiation(void *_t, int is_retry) { struct wg_timers *t = _t; struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers); if (!is_retry) t->t_handshake_retries = 0; - if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT) + if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT) task_add(wg_handshake_taskq, &peer->p_send_initiation); } @@ -1227,15 +1090,13 @@ wg_send_initiation(void *_peer) struct wg_peer *peer = _peer; struct wg_pkt_initiation pkt; - if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT) + if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, + pkt.ets) != 0) return; DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n", peer->p_id); - if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, - pkt.ets) != 0) - return; pkt.t = WG_PKT_INITIATION; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); @@ -1251,16 +1112,14 @@ wg_send_response(struct wg_peer *peer) DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n", peer->p_id); - if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx, + if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx, pkt.ue, pkt.en) != 0) return; - if (noise_remote_begin_session(&peer->p_remote) != 0) - return; wg_timers_event_session_derived(&peer->p_timers); pkt.t = WG_PKT_RESPONSE; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); - wg_timers_event_handshake_responded(&peer->p_timers); + wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt)); } @@ -1302,7 +1161,6 @@ wg_send_keepalive(void *_peer) } pkt->p_mbuf = m; - pkt->p_peer = peer; pkt->p_mtu = 0; m->m_pkthdr.ph_cookie = pkt; @@ -1318,7 +1176,7 @@ void wg_peer_clear_secrets(void *_peer) { struct wg_peer *peer = _peer; - noise_remote_clear(&peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); } void @@ -1330,7 +1188,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) struct wg_endpoint *e; struct wg_peer *peer; struct mbuf *m; - struct noise_remote *remote; + struct noise_keypair *keypair; + struct noise_remote *remote = NULL; int res, underload = 0; static struct timeval wg_last_underload; /* microuptime */ @@ -1371,13 +1230,13 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) panic("unexpected response: %d\n", res); } - if (noise_consume_initiation(&sc->sc_local, &remote, + if (noise_consume_initiation(sc->sc_local, &remote, init->s_idx, init->ue, init->es, init->ets) != 0) { DPRINTF(sc, "Invalid handshake initiation\n"); goto error; } - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", peer->p_id); @@ -1405,45 +1264,39 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) panic("unexpected response: %d\n", res); } - if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) { - DPRINTF(sc, "Unknown handshake response\n"); - goto error; - } - - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - - if (noise_consume_response(remote, resp->s_idx, resp->r_idx, - resp->ue, resp->en) != 0) { + if (noise_consume_response(sc->sc_local, &remote, + resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) { DPRINTF(sc, "Invalid handshake response\n"); goto error; } - DPRINTF(sc, "Receiving handshake response from peer %llu\n", - peer->p_id); + peer = noise_remote_arg(remote); + DPRINTF(sc, "Receiving handshake response from peer %llu\n", peer->p_id); wg_peer_set_endpoint(peer, e); - if (noise_remote_begin_session(&peer->p_remote) == 0) { - wg_timers_event_session_derived(&peer->p_timers); - wg_timers_event_handshake_complete(&peer->p_timers); - } + wg_timers_event_session_derived(&peer->p_timers); + wg_timers_event_handshake_complete(&peer->p_timers); break; case WG_PKT_COOKIE: cook = mtod(m, struct wg_pkt_cookie *); - if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) { - DPRINTF(sc, "Unknown cookie index\n"); - goto error; + if ((remote = noise_remote_index_lookup(sc->sc_local, cook->r_idx)) == NULL) { + if ((keypair = noise_keypair_lookup(sc->sc_local, cook->r_idx)) == NULL) { + DPRINTF(sc, "Unknown cookie index\n"); + goto error; + } + remote = noise_keypair_remote(keypair); + noise_keypair_put(keypair); } - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); if (cookie_maker_consume_payload(&peer->p_cookie, - cook->nonce, cook->ec) != 0) { + cook->nonce, cook->ec) == 0) + DPRINTF(sc, "Receiving cookie response\n"); + else DPRINTF(sc, "Could not decrypt cookie response\n"); - goto error; - } - DPRINTF(sc, "Receiving cookie response\n"); goto error; default: panic("invalid packet in handshake queue"); @@ -1453,6 +1306,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); wg_peer_counters_add(peer, 0, m->m_pkthdr.len); error: + if (remote != NULL) + noise_remote_put(remote); m_freem(m); pool_put(&wg_packet_pool, pkt); } @@ -1484,12 +1339,16 @@ wg_handshake_worker(void *_sc) void wg_encap(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_data data; + struct noise_remote *remote; + struct wg_pkt_data *data; struct wg_peer *peer; struct mbuf *m, *ms; - int res, pad, off, len; + uint32_t idx; + int pad, off, len; - peer = pkt->p_peer; + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + noise_remote_put(remote); m = pkt->p_mbuf; /* Calculate what padding we need to add then limit it to the mtu of @@ -1506,7 +1365,7 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) bzero(mtod(ms, uint8_t *) + off, pad); } - /* TODO teach noise_remote_encrypt about mbufs. Currently we have to + /* TODO teach noise_keypair_encrypt about mbufs. Currently we have to * resort to m_pullup to create an encryptable buffer. */ len = m->m_pkthdr.len; if (m_makespace(m, len, NOISE_AUTHTAG_LEN, &off) == NULL) @@ -1515,30 +1374,20 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) goto error; /* Do encryption */ - res = noise_remote_encrypt(&peer->p_remote, &data.r_idx, &data.nonce, - mtod(m, uint8_t *), len); - - if (__predict_false(res == EINVAL)) { - goto error_free; - } else if (__predict_false(res == ESTALE)) { - wg_timers_event_want_initiation(&peer->p_timers); - } else if (__predict_false(res != 0)) { - panic("unexpected result: %d\n", res); - } + noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, mtod(m, uint8_t *), len); /* A packet with length 0 is a keepalive packet */ if (__predict_false(len == 0)) - DPRINTF(sc, "Sending keepalive packet to peer %llu\n", - peer->p_id); + DPRINTF(sc, "Sending keepalive packet to peer %llu\n", peer->p_id); /* Put header into packet */ if ((m = m_prepend(m, sizeof(struct wg_pkt_data), M_NOWAIT)) == NULL) goto error; - data.t = WG_PKT_DATA; - data.nonce = htole64(data.nonce); - memcpy(mtod(m, void *), &data, sizeof(struct wg_pkt_data)); - + data = mtod(m, struct wg_pkt_data *); + data->t = WG_PKT_DATA; + data->r_idx = idx; + data->nonce = htole64(pkt->p_nonce); /* * We would count ifc_opackets, ifc_obytes of m here, except if_snd * already does that for us, so no need to worry about it. @@ -1564,6 +1413,7 @@ error: void wg_decap(struct wg_softc *sc, struct wg_packet *pkt) { + struct noise_remote *remote; struct wg_pkt_data data; struct wg_peer *peer, *allowed_peer; struct mbuf *m; @@ -1571,7 +1421,9 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) struct ip6_hdr *ip6; int res, len; - peer = pkt->p_peer; + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + noise_remote_put(remote); m = pkt->p_mbuf; len = m->m_pkthdr.len; @@ -1579,28 +1431,25 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) memcpy(&data, mtod(m, void *), sizeof(struct wg_pkt_data)); m_adj(m, sizeof(struct wg_pkt_data)); - /* TODO teach noise_remote_decrypt about mbufs. Currently we have to + /* TODO teach noise_keypair_decrypt about mbufs. Currently we have to * resort to m_pullup to create an decryptable buffer. */ if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) { goto error; } - res = noise_remote_decrypt(&peer->p_remote, data.r_idx, - le64toh(data.nonce), mtod(m, void *), m->m_pkthdr.len); + pkt->p_nonce = letoh64(data.nonce); + res = noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, mtod(m, void *), m->m_pkthdr.len); if (__predict_false(res == EINVAL)) { goto error_free; } else if (__predict_false(res == ECONNRESET)) { wg_timers_event_handshake_complete(&peer->p_timers); - } else if (__predict_false(res == ESTALE)) { - wg_timers_event_want_initiation(&peer->p_timers); } else if (__predict_false(res != 0)) { panic("unexpected response: %d\n", res); } m_adj(m, -NOISE_AUTHTAG_LEN); - wg_peer_set_endpoint(peer, &pkt->p_endpoint); wg_peer_counters_add(peer, 0, len); counters_pkt(sc->sc_if.if_counters, ifc_ipackets, ifc_ibytes, @@ -1735,8 +1584,11 @@ wg_deliver_out(void *_peer) m_freem(m); } + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); } + if (noise_keep_key_fresh_send(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } void @@ -1750,14 +1602,20 @@ wg_deliver_in(void *_peer) while ((pkt = wg_queue_serial_dequeue(&peer->p_decap_serial)) != NULL) { m = pkt->p_mbuf; if (pkt->p_state == WG_PACKET_CRYPTED) { + if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0) { + m_freem(m); + goto put; + } + wg_timers_event_any_authenticated_packet_received( &peer->p_timers); wg_timers_event_any_authenticated_packet_traversal( &peer->p_timers); + wg_peer_set_endpoint(peer, &pkt->p_endpoint); if (m->m_pkthdr.len == 0) { m_freem(m); - continue; + goto put; } #if NBPFILTER > 0 @@ -1778,13 +1636,15 @@ wg_deliver_in(void *_peer) NET_UNLOCK(); wg_timers_event_data_received(&peer->p_timers); - } else { m_freem(m); } - +put: + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); } + if (noise_keep_key_fresh_recv(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } void @@ -1808,6 +1668,7 @@ wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_pack } else { mtx_leave(&serial->q_mtx); m_freem(pkt->p_mbuf); + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); return ENOBUFS; } @@ -1837,11 +1698,12 @@ wg_queue_in(struct wg_softc *sc, struct wg_peer *peer, struct wg_packet *pkt) void wg_queue_out(struct wg_softc *sc, struct wg_peer *peer) { + struct noise_keypair *keypair; struct wg_packet *pkt; struct mbuf_list ml; struct mbuf *m; - if (noise_remote_ready(&peer->p_remote) != 0) { + if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) { wg_timers_event_want_initiation(&peer->p_timers); return; } @@ -1850,11 +1712,21 @@ wg_queue_out(struct wg_softc *sc, struct wg_peer *peer) while ((m = ml_dequeue(&ml)) != NULL) { pkt = m->m_pkthdr.ph_cookie; + pkt->p_keypair = noise_keypair_ref(keypair); + + if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) { + ml_purge(&ml); + pool_put(&wg_packet_pool, pkt); + m_freem(m); + break; + } if (wg_queue_both(&sc->sc_encap_parallel, &peer->p_encap_serial, pkt) != 0) counters_inc(sc->sc_if.if_counters, ifc_oqdrops); } + noise_keypair_put(keypair); + task_add(wg_crypt_taskq, &sc->sc_encap); } @@ -1886,91 +1758,6 @@ wg_queue_parallel_dequeue(struct wg_queue *parallel) return pkt; } -struct noise_remote * -wg_remote_get(void *_sc, uint8_t public[NOISE_PUBLIC_KEY_LEN]) -{ - struct wg_peer *peer; - struct wg_softc *sc = _sc; - if ((peer = wg_peer_lookup(sc, public)) == NULL) - return NULL; - return &peer->p_remote; -} - -uint32_t -wg_index_set(void *_sc, struct noise_remote *remote) -{ - struct wg_peer *peer; - struct wg_softc *sc = _sc; - struct wg_index *index, *iter; - uint32_t key; - - /* - * We can modify this without a lock as wg_index_set, wg_index_drop are - * guaranteed to be serialised (per remote). - */ - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - index = SLIST_FIRST(&peer->p_unused_index); - KASSERT(index != NULL); - SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry); - - index->i_value = remote; - - mtx_enter(&sc->sc_index_mtx); -assign_id: - key = index->i_key = arc4random(); - key &= sc->sc_index_mask; - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == index->i_key) - goto assign_id; - - LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry); - - mtx_leave(&sc->sc_index_mtx); - - /* Likewise, no need to lock for index here. */ - return index->i_key; -} - -struct noise_remote * -wg_index_get(void *_sc, uint32_t key0) -{ - struct wg_softc *sc = _sc; - struct wg_index *iter; - struct noise_remote *remote = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - mtx_enter(&sc->sc_index_mtx); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - remote = iter->i_value; - break; - } - mtx_leave(&sc->sc_index_mtx); - return remote; -} - -void -wg_index_drop(void *_sc, uint32_t key0) -{ - struct wg_softc *sc = _sc; - struct wg_index *iter; - struct wg_peer *peer = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - mtx_enter(&sc->sc_index_mtx); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - LIST_REMOVE(iter, i_entry); - break; - } - mtx_leave(&sc->sc_index_mtx); - - /* We expect an index match */ - KASSERT(iter != NULL); - peer = CONTAINER_OF(iter->i_value, struct wg_peer, p_remote); - SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); -} - struct mbuf * wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, void *_uh, int hlen) @@ -1978,7 +1765,6 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, struct wg_pkt_data *data; struct noise_remote *remote; struct wg_packet *pkt; - struct wg_peer *peer; struct wg_softc *sc = _sc; struct udphdr *uh = _uh; @@ -2046,13 +1832,13 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, } data = mtod(m, struct wg_pkt_data *); - if ((remote = wg_index_get(sc, data->r_idx)) == NULL) + if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL) goto error_mbuf; - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - pkt->p_peer = peer; + remote = noise_keypair_remote(pkt->p_keypair); pkt->p_mbuf = m; - wg_queue_in(sc, peer, pkt); + wg_queue_in(sc, noise_remote_arg(remote), pkt); + noise_remote_put(remote); } else { counters_inc(sc->sc_if.if_counters, ifc_ierrors); goto error_mbuf; @@ -2090,7 +1876,6 @@ wg_qstart(struct ifqueue *ifq) peer = t->t_peer; pkt->p_mbuf = m; - pkt->p_peer = peer; pkt->p_mtu = t->t_mtu; m->m_pkthdr.ph_cookie = pkt; @@ -2205,12 +1990,14 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) struct wg_peer *peer, *tpeer; struct wg_aip *aip, *taip; + struct noise_remote *remote; + in_port_t port; int rtable; uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE]; size_t i, j; - int ret, has_identity; + int ret; if ((ret = suser(curproc)) != 0) return ret; @@ -2222,27 +2009,23 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) goto error; if (iface_o.i_flags & WG_INTERFACE_REPLACE_PEERS) - TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer) + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) wg_peer_destroy(peer); if (iface_o.i_flags & WG_INTERFACE_HAS_PRIVATE && - (noise_local_keys(&sc->sc_local, NULL, private) || + (noise_local_keys(sc->sc_local, NULL, private) || timingsafe_bcmp(private, iface_o.i_private, WG_KEY_SIZE))) { if (curve25519_generate_public(public, iface_o.i_private)) { - if ((peer = wg_peer_lookup(sc, public)) != NULL) - wg_peer_destroy(peer); - } - noise_local_lock_identity(&sc->sc_local); - has_identity = noise_local_set_private(&sc->sc_local, - iface_o.i_private); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { - noise_remote_precompute(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent(&peer->p_timers); - noise_remote_expire_current(&peer->p_remote); + if ((remote = noise_remote_lookup(sc->sc_local, public)) != NULL) { + wg_peer_destroy(noise_remote_arg(remote)); + noise_remote_put(remote); + } } - cookie_checker_update(&sc->sc_cookie, - has_identity == 0 ? public : NULL); - noise_local_unlock_identity(&sc->sc_local); + noise_local_private(sc->sc_local, iface_o.i_private); + if (noise_local_keys(sc->sc_local, public, NULL) == 0) + cookie_checker_update(&sc->sc_cookie, public); + else + cookie_checker_update(&sc->sc_cookie, NULL); } if (iface_o.i_flags & WG_INTERFACE_HAS_PORT) @@ -2256,7 +2039,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) rtable = sc->sc_udp_rtable; if (port != sc->sc_udp_port || rtable != sc->sc_udp_rtable) { - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) wg_peer_clear_src(peer); if (sc->sc_if.if_flags & IFF_RUNNING) @@ -2287,12 +2070,12 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) } /* Get local public and check that peer key doesn't match */ - if (noise_local_keys(&sc->sc_local, public, NULL) == 0 && + if (noise_local_keys(sc->sc_local, public, NULL) == 0 && bcmp(public, peer_o.p_public, WG_KEY_SIZE) == 0) goto next_peer; /* Lookup peer, or create if it doesn't exist */ - if ((peer = wg_peer_lookup(sc, peer_o.p_public)) == NULL) { + if ((remote = noise_remote_lookup(sc->sc_local, peer_o.p_public)) == NULL) { /* If we want to delete, no need creating a new one. * Also, don't create a new one if we only want to * update. */ @@ -2307,6 +2090,9 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) ret = ENOMEM; goto error; } + } else { + peer = noise_remote_arg(remote); + noise_remote_put(remote); } /* Remove peer and continue if specified */ @@ -2319,7 +2105,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) wg_peer_set_sockaddr(peer, &peer_o.p_sa); if (peer_o.p_flags & WG_PEER_HAS_PSK) - noise_remote_set_psk(&peer->p_remote, peer_o.p_psk); + noise_remote_set_psk(peer->p_remote, peer_o.p_psk); if (peer_o.p_flags & WG_PEER_HAS_PKA) wg_timers_set_persistent_keepalive(&peer->p_timers, @@ -2394,7 +2180,7 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data) if (!is_suser) goto copy_out_iface; - if (noise_local_keys(&sc->sc_local, iface_o.i_public, + if (noise_local_keys(sc->sc_local, iface_o.i_public, iface_o.i_private) == 0) { iface_o.i_flags |= WG_INTERFACE_HAS_PUBLIC; iface_o.i_flags |= WG_INTERFACE_HAS_PRIVATE; @@ -2407,12 +2193,12 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data) peer_count = 0; peer_p = &iface_p->i_peers[0]; - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { bzero(&peer_o, sizeof(peer_o)); peer_o.p_flags = WG_PEER_HAS_PUBLIC; peer_o.p_protocol_version = 1; - if (noise_remote_keys(&peer->p_remote, peer_o.p_public, + if (noise_remote_keys(peer->p_remote, peer_o.p_public, peer_o.p_psk) == 0) peer_o.p_flags |= WG_PEER_HAS_PSK; @@ -2530,7 +2316,7 @@ wg_up(struct wg_softc *sc) */ ret = wg_bind(sc, &sc->sc_udp_port, &sc->sc_udp_rtable); if (ret == 0) { - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) wg_timers_enable(&peer->p_timers); } rw_exit_write(&sc->sc_lock); @@ -2558,15 +2344,15 @@ wg_down(struct wg_softc *sc) * that isn't granularly locked. */ rw_enter_read(&sc->sc_lock); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { mq_purge(&peer->p_stage_queue); wg_timers_disable(&peer->p_timers); } taskq_barrier(wg_handshake_taskq); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { - noise_remote_clear(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent(&peer->p_timers); + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + noise_remote_handshake_clear(peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); } wg_unbind(sc); @@ -2579,7 +2365,6 @@ wg_clone_create(struct if_clone *ifc, int unit) { struct ifnet *ifp; struct wg_softc *sc; - struct noise_upcall local_upcall; KERNEL_ASSERT_LOCKED(); @@ -2604,20 +2389,15 @@ wg_clone_create(struct if_clone *ifc, int unit) if ((sc = malloc(sizeof(*sc), M_DEVBUF, M_NOWAIT | M_ZERO)) == NULL) goto ret_00; - local_upcall.u_arg = sc; - local_upcall.u_remote_get = wg_remote_get; - local_upcall.u_index_set = wg_index_set; - local_upcall.u_index_drop = wg_index_drop; - - /* sc_if is initialised after everything else */ - arc4random_buf(&sc->sc_secret, sizeof(sc->sc_secret)); - rw_init(&sc->sc_lock, "wg"); - noise_local_init(&sc->sc_local, &local_upcall); - if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0) + if ((sc->sc_local = noise_local_alloc(sc)) == NULL) goto ret_01; + if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0) + goto ret_02; sc->sc_udp_port = 0; sc->sc_udp_rtable = 0; + TAILQ_INIT(&sc->sc_peers); + sc->sc_peer_num = 0; rw_init(&sc->sc_so_lock, "wg_so"); sc->sc_so4 = NULL; @@ -2627,24 +2407,11 @@ wg_clone_create(struct if_clone *ifc, int unit) sc->sc_aip_num = 0; if ((sc->sc_aip4 = art_alloc(0, 32, 0)) == NULL) - goto ret_02; + goto ret_03; #ifdef INET6 if ((sc->sc_aip6 = art_alloc(0, 128, 0)) == NULL) - goto ret_03; -#endif - - rw_init(&sc->sc_peer_lock, "wg_peer"); - sc->sc_peer_num = 0; - if ((sc->sc_peer = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF, - M_NOWAIT, &sc->sc_peer_mask)) == NULL) goto ret_04; - - TAILQ_INIT(&sc->sc_peer_seq); - - mtx_init(&sc->sc_index_mtx, IPL_NET); - if ((sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, - M_NOWAIT, &sc->sc_index_mask)) == NULL) - goto ret_05; +#endif task_set(&sc->sc_handshake, wg_handshake_worker, sc); mq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES, IPL_NET); @@ -2684,22 +2451,42 @@ wg_clone_create(struct if_clone *ifc, int unit) DPRINTF(sc, "Interface created\n"); return 0; -ret_05: - hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF); -ret_04: + #ifdef INET6 free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); -ret_03: +ret_04: #endif free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); -ret_02: +ret_03: cookie_checker_deinit(&sc->sc_cookie); +ret_02: + noise_local_put(sc->sc_local); ret_01: free(sc, M_DEVBUF, sizeof(*sc)); ret_00: return ENOBUFS; } +void +wg_clone_free(struct noise_local *l) +{ + struct wg_softc *sc = noise_local_arg(l); + wg_counter--; + if (wg_counter == 0) { + KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL); + taskq_destroy(wg_handshake_taskq); + taskq_destroy(wg_crypt_taskq); + wg_handshake_taskq = NULL; + wg_crypt_taskq = NULL; + } +#ifdef INET6 + free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); +#endif + free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); + cookie_checker_deinit(&sc->sc_cookie); + free(sc, M_DEVBUF, sizeof(*sc)); +} + int wg_clone_destroy(struct ifnet *ifp) { @@ -2709,33 +2496,16 @@ wg_clone_destroy(struct ifnet *ifp) KERNEL_ASSERT_LOCKED(); rw_enter_write(&sc->sc_lock); - TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer) + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) wg_peer_destroy(peer); + rw_exit_write(&sc->sc_lock); wg_unbind(sc); if_detach(ifp); - wg_counter--; - if (wg_counter == 0) { - KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL); - taskq_destroy(wg_handshake_taskq); - taskq_destroy(wg_crypt_taskq); - wg_handshake_taskq = NULL; - wg_crypt_taskq = NULL; - } - + noise_local_free(sc->sc_local, wg_clone_free); DPRINTF(sc, "Destroyed interface\n"); - - hashfree(sc->sc_index, HASHTABLE_INDEX_SIZE, M_DEVBUF); - hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF); -#ifdef INET6 - free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); -#endif - free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); - cookie_checker_deinit(&sc->sc_cookie); - noise_local_deinit(&sc->sc_local); - free(sc, M_DEVBUF, sizeof(*sc)); return 0; } |