diff options
Diffstat (limited to 'src/if_wg.c')
-rw-r--r-- | src/if_wg.c | 569 |
1 files changed, 152 insertions, 417 deletions
diff --git a/src/if_wg.c b/src/if_wg.c index 6f4d225..70f6b4b 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -89,13 +89,7 @@ __FBSDID("$FreeBSD$"); #define MAX_QUEUED_HANDSHAKES 4096 -#define HASHTABLE_PEER_SIZE (1 << 11) -#define HASHTABLE_INDEX_SIZE (1 << 13) -#define MAX_PEERS_PER_IFACE (1 << 20) - -#define REKEY_TIMEOUT 5 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */ -#define KEEPALIVE_TIMEOUT 10 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT) #define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT) #define UNDERLOAD_TIMEOUT 1 @@ -163,19 +157,13 @@ struct wg_endpoint { struct wg_tag { struct m_tag t_tag; struct wg_endpoint t_endpoint; - struct wg_peer *t_peer; + struct noise_keypair *t_keypair; + uint64_t t_nonce; struct mbuf *t_mbuf; int t_done; int t_mtu; }; -struct wg_index { - LIST_ENTRY(wg_index) i_entry; - SLIST_ENTRY(wg_index) i_unused_entry; - uint32_t i_key; - struct noise_remote *i_value; -}; - struct wg_timers { /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */ struct rwlock t_lock; @@ -190,7 +178,6 @@ struct wg_timers { struct callout t_persistent_keepalive; struct mtx t_handshake_mtx; - struct timespec t_handshake_last_sent; struct timespec t_handshake_complete; volatile int t_handshake_retries; }; @@ -209,26 +196,21 @@ struct wg_queue { }; struct wg_peer { - CK_LIST_ENTRY(wg_peer) p_hash_entry; - CK_LIST_ENTRY(wg_peer) p_entry; + TAILQ_ENTRY(wg_peer) p_entry; uint64_t p_id; struct wg_softc *p_sc; - struct noise_remote p_remote; + struct noise_remote *p_remote; struct cookie_maker p_cookie; struct wg_timers p_timers; struct rwlock p_endpoint_lock; struct wg_endpoint p_endpoint; - SLIST_HEAD(,wg_index) p_unused_index; - struct wg_index p_index[3]; - struct wg_queue p_stage_queue; struct wg_queue p_encap_queue; struct wg_queue p_decap_queue; - struct grouptask p_clear_secrets; struct grouptask p_send_initiation; struct grouptask p_send_keepalive; struct grouptask p_send; @@ -238,8 +220,6 @@ struct wg_peer { counter_u64_t p_rx_bytes; CK_LIST_HEAD(, wg_aip) p_aips; - struct mtx p_lock; - struct epoch_context p_ctx; }; enum route_direction { @@ -264,15 +244,6 @@ struct wg_allowedip { uint8_t cidr; }; -struct wg_hashtable { - struct mtx h_mtx; - SIPHASH_KEY h_secret; - CK_LIST_HEAD(, wg_peer) h_peers_list; - CK_LIST_HEAD(, wg_peer) *h_peers; - u_long h_peers_mask; - size_t h_num_peers; -}; - struct wg_socket { struct mtx so_mtx; struct socket *so_so4; @@ -289,13 +260,15 @@ struct wg_softc { struct ucred *sc_ucred; struct wg_socket sc_socket; - struct wg_hashtable sc_hashtable; struct wg_aip_table sc_aips; + TAILQ_HEAD(,wg_peer) sc_peers; + size_t sc_peers_num; + struct mbufq sc_handshake_queue; struct grouptask sc_handshake; - struct noise_local sc_local; + struct noise_local *sc_local; struct cookie_checker sc_cookie; struct buf_ring *sc_encap_ring; @@ -304,12 +277,7 @@ struct wg_softc { struct grouptask *sc_encrypt; struct grouptask *sc_decrypt; - struct rwlock sc_index_lock; - LIST_HEAD(,wg_index) *sc_index; - u_long sc_index_mask; - struct sx sc_lock; - volatile u_int sc_peer_count; }; #define WGF_DYING 0x0001 @@ -369,11 +337,9 @@ static void wg_timers_event_any_authenticated_packet_sent(struct wg_timers *); static void wg_timers_event_any_authenticated_packet_received(struct wg_timers *); static void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *); static void wg_timers_event_handshake_initiated(struct wg_timers *); -static void wg_timers_event_handshake_responded(struct wg_timers *); static void wg_timers_event_handshake_complete(struct wg_timers *); static void wg_timers_event_session_derived(struct wg_timers *); static void wg_timers_event_want_initiation(struct wg_timers *); -static void wg_timers_event_reset_handshake_last_sent(struct wg_timers *); static void wg_timers_run_send_initiation(struct wg_timers *, int); static void wg_timers_run_retry_handshake(struct wg_timers *); static void wg_timers_run_send_keepalive(struct wg_timers *); @@ -385,8 +351,6 @@ static void wg_timers_enable(struct wg_timers *); static void wg_timers_disable(struct wg_timers *); static void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t); static void wg_timers_get_last_handshake(struct wg_timers *, struct wg_timespec64 *); -static int wg_timers_expired_handshake_last_sent(struct wg_timers *); -static int wg_timers_check_handshake_last_sent(struct wg_timers *); static void wg_queue_init(struct wg_queue *, const char *); static void wg_queue_deinit(struct wg_queue *); static void wg_queue_purge(struct wg_queue *); @@ -404,14 +368,9 @@ static int wg_peer_remove(struct radix_node *, void *); static void wg_peer_remove_all(struct wg_softc *); static int wg_aip_delete(struct wg_aip_table *, struct wg_peer *); static struct wg_peer *wg_aip_lookup(struct wg_aip_table *, struct mbuf *, enum route_direction); -static void wg_hashtable_init(struct wg_hashtable *); -static void wg_hashtable_destroy(struct wg_hashtable *); -static void wg_hashtable_peer_insert(struct wg_hashtable *, struct wg_peer *); -static struct wg_peer *wg_peer_lookup(struct wg_softc *, const uint8_t [32]); -static void wg_hashtable_peer_remove(struct wg_hashtable *, struct wg_peer *); static int wg_cookie_validate_packet(struct cookie_checker *, struct mbuf *, int); static struct wg_peer *wg_peer_alloc(struct wg_softc *); -static void wg_peer_free_deferred(epoch_context_t); +static void wg_peer_free_deferred(struct noise_remote *); static void wg_peer_destroy(struct wg_peer *); static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t); static void wg_send_initiation(struct wg_peer *); @@ -430,10 +389,6 @@ static void wg_decap(struct wg_softc *, struct mbuf *); static void wg_softc_handshake_receive(struct wg_softc *); static void wg_softc_decrypt(struct wg_softc *); static void wg_softc_encrypt(struct wg_softc *); -static struct noise_remote *wg_remote_get(struct wg_softc *, uint8_t [NOISE_PUBLIC_KEY_LEN]); -static uint32_t wg_index_set(struct wg_softc *, struct noise_remote *); -static struct noise_remote *wg_index_get(struct wg_softc *, uint32_t); -static void wg_index_drop(struct wg_softc *, uint32_t); static int wg_update_endpoint_addrs(struct wg_endpoint *, const struct sockaddr *, struct ifnet *); static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); static void wg_encrypt_dispatch(struct wg_softc *); @@ -480,9 +435,6 @@ wg_peer_alloc(struct wg_softc *sc) taskqgroup_attach(qgroup_wg_tqg, &peer->p_send_initiation, peer, NULL, NULL, "wg initiation"); GROUPTASK_INIT(&peer->p_send_keepalive, 0, (gtask_fn_t *)wg_send_keepalive, peer); taskqgroup_attach(qgroup_wg_tqg, &peer->p_send_keepalive, peer, NULL, NULL, "wg keepalive"); - GROUPTASK_INIT(&peer->p_clear_secrets, 0, (gtask_fn_t *)noise_remote_clear, &peer->p_remote); - taskqgroup_attach(qgroup_wg_tqg, &peer->p_clear_secrets, - &peer->p_remote, NULL, NULL, "wg clear secrets"); GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer); taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send"); @@ -494,91 +446,13 @@ wg_peer_alloc(struct wg_softc *sc) peer->p_tx_bytes = counter_u64_alloc(M_WAITOK); peer->p_rx_bytes = counter_u64_alloc(M_WAITOK); - SLIST_INIT(&peer->p_unused_index); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2], - i_unused_entry); - return (peer); } -#define WG_HASHTABLE_PEER_FOREACH(peer, i, ht) \ - for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \ - LIST_FOREACH(peer, &(ht)->h_peers[i], p_hash_entry) -#define WG_HASHTABLE_PEER_FOREACH_SAFE(peer, i, ht, tpeer) \ - for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \ - CK_LIST_FOREACH_SAFE(peer, &(ht)->h_peers[i], p_hash_entry, tpeer) -static void -wg_hashtable_init(struct wg_hashtable *ht) -{ - mtx_init(&ht->h_mtx, "hash lock", NULL, MTX_DEF); - arc4random_buf(&ht->h_secret, sizeof(ht->h_secret)); - ht->h_num_peers = 0; - ht->h_peers = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF, - &ht->h_peers_mask); -} - -static void -wg_hashtable_destroy(struct wg_hashtable *ht) -{ - MPASS(ht->h_num_peers == 0); - mtx_destroy(&ht->h_mtx); - hashdestroy(ht->h_peers, M_DEVBUF, ht->h_peers_mask); -} - static void -wg_hashtable_peer_insert(struct wg_hashtable *ht, struct wg_peer *peer) +wg_peer_free_deferred(struct noise_remote *r) { - uint64_t key; - - key = siphash24(&ht->h_secret, peer->p_remote.r_public, - sizeof(peer->p_remote.r_public)); - - mtx_lock(&ht->h_mtx); - ht->h_num_peers++; - CK_LIST_INSERT_HEAD(&ht->h_peers[key & ht->h_peers_mask], peer, p_hash_entry); - CK_LIST_INSERT_HEAD(&ht->h_peers_list, peer, p_entry); - mtx_unlock(&ht->h_mtx); -} - -static struct wg_peer * -wg_peer_lookup(struct wg_softc *sc, - const uint8_t pubkey[WG_KEY_SIZE]) -{ - struct wg_hashtable *ht = &sc->sc_hashtable; - uint64_t key; - struct wg_peer *i = NULL; - - key = siphash24(&ht->h_secret, pubkey, WG_KEY_SIZE); - - mtx_lock(&ht->h_mtx); - CK_LIST_FOREACH(i, &ht->h_peers[key & ht->h_peers_mask], p_hash_entry) { - if (timingsafe_bcmp(i->p_remote.r_public, pubkey, - WG_KEY_SIZE) == 0) - break; - } - mtx_unlock(&ht->h_mtx); - - return i; -} - -static void -wg_hashtable_peer_remove(struct wg_hashtable *ht, struct wg_peer *peer) -{ - mtx_lock(&ht->h_mtx); - ht->h_num_peers--; - CK_LIST_REMOVE(peer, p_hash_entry); - CK_LIST_REMOVE(peer, p_entry); - mtx_unlock(&ht->h_mtx); -} - -static void -wg_peer_free_deferred(epoch_context_t ctx) -{ - struct wg_peer *peer = __containerof(ctx, struct wg_peer, p_ctx); + struct wg_peer *peer = noise_remote_arg(r); counter_u64_free(peer->p_tx_bytes); counter_u64_free(peer->p_rx_bytes); rw_destroy(&peer->p_timers.t_lock); @@ -599,13 +473,11 @@ wg_peer_destroy(struct wg_peer *peer) wg_timers_disable(&peer->p_timers); /* Ensure the tasks have finished running */ - GROUPTASK_DRAIN(&peer->p_clear_secrets); GROUPTASK_DRAIN(&peer->p_send_initiation); GROUPTASK_DRAIN(&peer->p_send_keepalive); GROUPTASK_DRAIN(&peer->p_recv); GROUPTASK_DRAIN(&peer->p_send); - taskqgroup_detach(qgroup_wg_tqg, &peer->p_clear_secrets); taskqgroup_detach(qgroup_wg_tqg, &peer->p_send_initiation); taskqgroup_detach(qgroup_wg_tqg, &peer->p_send_keepalive); taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv); @@ -616,10 +488,10 @@ wg_peer_destroy(struct wg_peer *peer) wg_queue_deinit(&peer->p_stage_queue); /* Final cleanup */ - --peer->p_sc->sc_peer_count; - noise_remote_clear(&peer->p_remote); + peer->p_sc->sc_peers_num--; + TAILQ_REMOVE(&peer->p_sc->sc_peers, peer, p_entry); DPRINTF(peer->p_sc, "Peer %llu destroyed\n", (unsigned long long)peer->p_id); - NET_EPOCH_CALL(wg_peer_free_deferred, &peer->p_ctx); + noise_remote_free(peer->p_remote, wg_peer_free_deferred); } static void @@ -888,11 +760,8 @@ wg_peer_remove_all(struct wg_softc *sc) sx_assert(&sc->sc_lock, SX_XLOCKED); - CK_LIST_FOREACH_SAFE(peer, &sc->sc_hashtable.h_peers_list, - p_entry, tpeer) { - wg_hashtable_peer_remove(&sc->sc_hashtable, peer); + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) wg_peer_destroy(peer); - } } static int @@ -1253,29 +1122,6 @@ wg_timers_get_last_handshake(struct wg_timers *t, struct wg_timespec64 *time) rw_runlock(&t->t_lock); } -static int -wg_timers_expired_handshake_last_sent(struct wg_timers *t) -{ - struct timespec uptime; - struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 }; - - getnanouptime(&uptime); - timespecadd(&t->t_handshake_last_sent, &expire, &expire); - return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; -} - -static int -wg_timers_check_handshake_last_sent(struct wg_timers *t) -{ - int ret; - - rw_wlock(&t->t_lock); - if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT) - getnanouptime(&t->t_handshake_last_sent); - rw_wunlock(&t->t_lock); - return (ret); -} - /* Should be called after an authenticated data packet is sent. */ static void wg_timers_event_data_sent(struct wg_timers *t) @@ -1354,14 +1200,6 @@ wg_timers_event_handshake_initiated(struct wg_timers *t) rw_runlock(&t->t_lock); } -static void -wg_timers_event_handshake_responded(struct wg_timers *t) -{ - rw_wlock(&t->t_lock); - getnanouptime(&t->t_handshake_last_sent); - rw_wunlock(&t->t_lock); -} - /* * Should be called after a handshake response message is received and processed * or when getting key confirmation via the first data message. @@ -1405,20 +1243,12 @@ wg_timers_event_want_initiation(struct wg_timers *t) } static void -wg_timers_event_reset_handshake_last_sent(struct wg_timers *t) -{ - rw_wlock(&t->t_lock); - t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1); - rw_wunlock(&t->t_lock); -} - -static void wg_timers_run_send_initiation(struct wg_timers *t, int is_retry) { struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); if (!is_retry) t->t_handshake_retries = 0; - if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT) + if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT) GROUPTASK_ENQUEUE(&peer->p_send_initiation); } @@ -1489,7 +1319,7 @@ wg_timers_run_zero_key_material(struct wg_timers *t) DPRINTF(peer->p_sc, "Zeroing out all keys for peer %llu, since we " "haven't received a new one in %d seconds\n", (unsigned long long)peer->p_id, REJECT_AFTER_TIME * 3); - GROUPTASK_ENQUEUE(&peer->p_clear_secrets); + noise_remote_keypairs_clear(peer->p_remote); } static void @@ -1520,15 +1350,14 @@ wg_send_initiation(struct wg_peer *peer) struct wg_pkt_initiation pkt; struct epoch_tracker et; - if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT) - return; - DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n", - (unsigned long long)peer->p_id); - NET_EPOCH_ENTER(et); - if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, + if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, pkt.ets) != 0) goto out; + + DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n", + (unsigned long long)peer->p_id); + pkt.t = WG_PKT_INITIATION; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); @@ -1545,21 +1374,17 @@ wg_send_response(struct wg_peer *peer) struct epoch_tracker et; NET_EPOCH_ENTER(et); + if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx, + pkt.ue, pkt.en) != 0) + goto out; DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n", (unsigned long long)peer->p_id); - if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx, - pkt.ue, pkt.en) != 0) - goto out; - if (noise_remote_begin_session(&peer->p_remote) != 0) - goto out; - wg_timers_event_session_derived(&peer->p_timers); pkt.t = WG_PKT_RESPONSE; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); - wg_timers_event_handshake_responded(&peer->p_timers); wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt)); out: NET_EPOCH_EXIT(et); @@ -1600,7 +1425,6 @@ wg_send_keepalive(struct wg_peer *peer) m_freem(m); return; } - t->t_peer = peer; t->t_mbuf = NULL; t->t_done = 0; t->t_mtu = 0; /* MTU == 0 OK for keepalive */ @@ -1673,10 +1497,6 @@ wg_handshake(struct wg_softc *sc, struct mbuf *m) res = wg_cookie_validate_packet(&sc->sc_cookie, m, underload); - if (res && res != EAGAIN) { - printf("validate_packet got %d\n", res); - goto free; - } if (res == EINVAL) { DPRINTF(sc, "Invalid initiation MAC\n"); goto free; @@ -1699,13 +1519,13 @@ wg_handshake(struct wg_softc *sc, struct mbuf *m) wg_send_cookie(sc, &init->m, init->s_idx, m); goto free; } - if (noise_consume_initiation(&sc->sc_local, &remote, + if (noise_consume_initiation(sc->sc_local, &remote, init->s_idx, init->ue, init->es, init->ets) != 0) { DPRINTF(sc, "Invalid handshake initiation"); goto free; } - peer = __containerof(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", (unsigned long long)peer->p_id); counter_u64_add(peer->p_rx_bytes, sizeof(*init)); @@ -1721,46 +1541,43 @@ wg_handshake(struct wg_softc *sc, struct mbuf *m) wg_send_cookie(sc, &resp->m, resp->s_idx, m); goto free; } - - if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) { - DPRINTF(sc, "Unknown handshake response\n"); - goto free; - } - peer = __containerof(remote, struct wg_peer, p_remote); - if (noise_consume_response(remote, resp->s_idx, resp->r_idx, - resp->ue, resp->en) != 0) { + if (noise_consume_response(sc->sc_local, &remote, resp->s_idx, + resp->r_idx, resp->ue, resp->en) != 0) { DPRINTF(sc, "Invalid handshake response\n"); goto free; } + peer = noise_remote_arg(remote); + DPRINTF(sc, "Receiving handshake response from peer %llu\n", (unsigned long long)peer->p_id); + + wg_timers_event_session_derived(&peer->p_timers); + wg_timers_event_handshake_complete(&peer->p_timers); + counter_u64_add(peer->p_rx_bytes, sizeof(*resp)); if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*resp)); wg_peer_set_endpoint_from_tag(peer, t); - if (noise_remote_begin_session(&peer->p_remote) == 0) { - wg_timers_event_session_derived(&peer->p_timers); - wg_timers_event_handshake_complete(&peer->p_timers); - } break; case WG_PKT_COOKIE: cook = mtod(m, struct wg_pkt_cookie *); - if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) { + if ((remote = noise_remote_index_lookup(sc->sc_local, + cook->r_idx)) == NULL) { DPRINTF(sc, "Unknown cookie index\n"); goto free; } - peer = __containerof(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); if (cookie_maker_consume_payload(&peer->p_cookie, - cook->nonce, cook->ec) != 0) { + cook->nonce, cook->ec) == 0) { + DPRINTF(sc, "Receiving cookie response\n"); + } else { DPRINTF(sc, "Could not decrypt cookie response\n"); - goto free; } - - DPRINTF(sc, "Receiving cookie response\n"); + noise_remote_put(remote); goto free; default: goto free; @@ -1788,15 +1605,19 @@ wg_encap(struct wg_softc *sc, struct mbuf *m) { struct wg_pkt_data *data; size_t padding_len, plaintext_len, out_len; + struct noise_remote *remote; struct mbuf *mc; struct wg_peer *peer; struct wg_tag *t; uint64_t nonce; - int res, allocation_order; + int allocation_order; NET_EPOCH_ASSERT(); t = wg_tag_get(m); - peer = t->t_peer; + remote = noise_keypair_remote(t->t_keypair); + peer = noise_remote_arg(remote); + /* We can put the remote as we still hold a keypair ref */ + noise_remote_put(remote); plaintext_len = MIN(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu); padding_len = plaintext_len - m->m_pkthdr.len; @@ -1822,23 +1643,11 @@ wg_encap(struct wg_softc *sc, struct mbuf *m) data->t = WG_PKT_DATA; - res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce, + noise_keypair_encrypt(t->t_keypair, &data->r_idx, t->t_nonce, data->buf, plaintext_len); - nonce = htole64(nonce); /* Wire format is little endian. */ - memcpy(data->nonce, &nonce, sizeof(data->nonce)); - if (__predict_false(res)) { - if (res == EINVAL) { - wg_timers_event_want_initiation(&peer->p_timers); - m_freem(mc); - goto error; - } else if (res == ESTALE) { - wg_timers_event_want_initiation(&peer->p_timers); - } else { - m_freem(mc); - goto error; - } - } + nonce = htole64(t->t_nonce); /* Wire format is little endian. */ + memcpy(data->nonce, &nonce, sizeof(data->nonce)); /* A packet with length 0 is a keepalive packet */ if (m->m_pkthdr.len == 0) @@ -1863,6 +1672,7 @@ wg_decap(struct wg_softc *sc, struct mbuf *m) { struct wg_pkt_data *data; struct wg_peer *peer, *routed_peer; + struct noise_remote *remote; struct wg_tag *t; size_t plaintext_len; uint8_t version; @@ -1874,21 +1684,22 @@ wg_decap(struct wg_softc *sc, struct mbuf *m) plaintext_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data); t = wg_tag_get(m); - peer = t->t_peer; + remote = noise_keypair_remote(t->t_keypair); + peer = noise_remote_arg(remote); + /* We can put the remote as we still hold a keypair ref */ + noise_remote_put(remote); memcpy(&nonce, data->nonce, sizeof(nonce)); - nonce = le64toh(nonce); /* Wire format is little endian. */ + t->t_nonce = le64toh(nonce); /* Wire format is little endian. */ - res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce, - data->buf, plaintext_len); + res = noise_keypair_decrypt(t->t_keypair, t->t_nonce, data->buf, + plaintext_len); if (__predict_false(res)) { if (res == EINVAL) { goto error; } else if (res == ECONNRESET) { wg_timers_event_handshake_complete(&peer->p_timers); - } else if (res == ESTALE) { - wg_timers_event_want_initiation(&peer->p_timers); } else { panic("unexpected response: %d\n", res); } @@ -1994,6 +1805,7 @@ wg_deliver_out(struct wg_peer *peer) continue; } len = t->t_mbuf->m_pkthdr.len; + noise_keypair_put(t->t_keypair); ret = wg_send(peer->p_sc, &endpoint, t->t_mbuf); if (ret == 0) { @@ -2010,6 +1822,9 @@ wg_deliver_out(struct wg_peer *peer) wg_peer_get_endpoint(peer, &endpoint); } m_freem(m); + + if (noise_keep_key_fresh_send(peer->p_remote) == 0) + wg_timers_event_want_initiation(&peer->p_timers); } NET_EPOCH_EXIT(et); @@ -2039,6 +1854,14 @@ wg_deliver_in(struct wg_peer *peer) } MPASS(m == t->t_mbuf); + if (noise_keypair_nonce_check(t->t_keypair, t->t_nonce) != 0) { + if_inc_counter(ifp, IFCOUNTER_IERRORS, 1); + noise_keypair_put(t->t_keypair); + m_freem(m); + continue; + } + noise_keypair_put(t->t_keypair); + wg_timers_event_any_authenticated_packet_received( &peer->p_timers); wg_timers_event_any_authenticated_packet_traversal( @@ -2048,6 +1871,9 @@ wg_deliver_in(struct wg_peer *peer) if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); + if (noise_keep_key_fresh_recv(peer->p_remote) == 0) + wg_timers_event_want_initiation(&peer->p_timers); + if (m->m_pkthdr.len == 0) { m_freem(m); continue; @@ -2129,10 +1955,11 @@ wg_queue_out(struct wg_peer *peer) struct buf_ring *parallel = peer->p_sc->sc_encap_ring; struct wg_queue *serial = &peer->p_encap_queue; struct wg_tag *t; + struct noise_keypair *keypair; struct mbufq staged; struct mbuf *m; - if (noise_remote_ready(&peer->p_remote) != 0) { + if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) { if (wg_queue_len(&peer->p_stage_queue)) wg_timers_event_want_initiation(&peer->p_timers); return; @@ -2153,10 +1980,18 @@ wg_queue_out(struct wg_peer *peer) m_freem(m); continue; } - t->t_peer = peer; + if (noise_keypair_nonce_next(keypair, &t->t_nonce) != 0) { + /* TODO if we get here, it means we are about to + * overflow this keypair sending nonce. We should place + * this back on the staged queue. */ + m_freem(m); + continue; + } + t->t_keypair = noise_keypair_ref(keypair); mtx_lock(&serial->q_mtx); if (mbufq_enqueue(&serial->q, m) != 0) { m_freem(m); + noise_keypair_put(keypair); if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); } else { m->m_flags |= M_ENQUEUED; @@ -2215,90 +2050,6 @@ wg_queue_purge(struct wg_queue *q) mtx_unlock(&q->q_mtx); } -/* TODO Indexes */ -static struct noise_remote * -wg_remote_get(struct wg_softc *sc, uint8_t public[NOISE_PUBLIC_KEY_LEN]) -{ - struct wg_peer *peer; - - if ((peer = wg_peer_lookup(sc, public)) == NULL) - return (NULL); - return (&peer->p_remote); -} - -static uint32_t -wg_index_set(struct wg_softc *sc, struct noise_remote *remote) -{ - struct wg_index *index, *iter; - struct wg_peer *peer; - uint32_t key; - - /* We can modify this without a lock as wg_index_set, wg_index_drop are - * guaranteed to be serialised (per remote). */ - peer = __containerof(remote, struct wg_peer, p_remote); - index = SLIST_FIRST(&peer->p_unused_index); - MPASS(index != NULL); - SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry); - - index->i_value = remote; - - rw_wlock(&sc->sc_index_lock); -assign_id: - key = index->i_key = arc4random(); - key &= sc->sc_index_mask; - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == index->i_key) - goto assign_id; - - LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry); - - rw_wunlock(&sc->sc_index_lock); - - /* Likewise, no need to lock for index here. */ - return index->i_key; -} - -static struct noise_remote * -wg_index_get(struct wg_softc *sc, uint32_t key0) -{ - struct wg_index *iter; - struct noise_remote *remote = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - rw_enter_read(&sc->sc_index_lock); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - remote = iter->i_value; - break; - } - rw_exit_read(&sc->sc_index_lock); - return remote; -} - -static void -wg_index_drop(struct wg_softc *sc, uint32_t key0) -{ - struct wg_index *iter; - struct wg_peer *peer = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - rw_enter_write(&sc->sc_index_lock); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - LIST_REMOVE(iter, i_entry); - break; - } - rw_exit_write(&sc->sc_index_lock); - - if (iter == NULL) - return; - - /* We expect a peer */ - peer = __containerof(iter->i_value, struct wg_peer, p_remote); - MPASS(peer != NULL); - SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); -} - static int wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa, struct ifnet *rcvif) @@ -2339,6 +2090,7 @@ wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, struct wg_softc *sc = _sc; struct mbuf *m; int pktlen, pkttype; + struct noise_keypair *keypair; struct noise_remote *remote; struct wg_tag *t; void *data; @@ -2406,21 +2158,23 @@ wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, } else if (pktlen >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN && pkttype == WG_PKT_DATA) { pkt_data = data; - remote = wg_index_get(sc, pkt_data->r_idx); - if (remote == NULL) { + keypair = noise_keypair_lookup(sc->sc_local, pkt_data->r_idx); + if (keypair == NULL) { if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); m_freem(m); } else if (buf_ring_count(sc->sc_decap_ring) > MAX_QUEUED_PKT) { if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + noise_keypair_put(keypair); m_freem(m); } else { - t->t_peer = __containerof(remote, struct wg_peer, - p_remote); + t->t_keypair = keypair; t->t_mbuf = NULL; t->t_done = 0; - wg_queue_in(t->t_peer, m); + remote = noise_keypair_remote(keypair); + wg_queue_in(noise_remote_arg(remote), m); wg_decrypt_dispatch(sc); + noise_remote_put(remote); } } else { free: @@ -2465,7 +2219,6 @@ wg_transmit(struct ifnet *ifp, struct mbuf *m) rc = EHOSTUNREACH; goto err; } - t->t_peer = peer; t->t_mbuf = NULL; t->t_done = 0; t->t_mtu = ifp->if_mtu; @@ -2494,10 +2247,11 @@ static int wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) { uint8_t public[WG_KEY_SIZE]; - const void *pub_key; + const void *pub_key, *preshared_key = NULL; const struct sockaddr *endpoint; int err; size_t size; + struct noise_remote *remote; struct wg_peer *peer = NULL; bool need_insert = false; @@ -2510,17 +2264,16 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) if (size != WG_KEY_SIZE) { return (EINVAL); } - if (noise_local_keys(&sc->sc_local, public, NULL) == 0 && + if (noise_local_keys(sc->sc_local, public, NULL) == 0 && bcmp(public, pub_key, WG_KEY_SIZE) == 0) { return (0); // Silently ignored; not actually a failure. } - peer = wg_peer_lookup(sc, pub_key); + if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL) + peer = noise_remote_arg(remote); if (nvlist_exists_bool(nvl, "remove") && nvlist_get_bool(nvl, "remove")) { - if (peer != NULL) { - wg_hashtable_peer_remove(&sc->sc_hashtable, peer); + if (peer != NULL) wg_peer_destroy(peer); - } return (0); } if (nvlist_exists_bool(nvl, "replace-allowedips") && @@ -2530,14 +2283,9 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) wg_aip_delete(&peer->p_sc->sc_aips, peer); } if (peer == NULL) { - if (sc->sc_peer_count >= MAX_PEERS_PER_IFACE) - return (E2BIG); - sc->sc_peer_count++; - need_insert = true; peer = wg_peer_alloc(sc); MPASS(peer != NULL); - noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local); cookie_maker_init(&peer->p_cookie, pub_key); } if (nvlist_exists_binary(nvl, "endpoint")) { @@ -2549,14 +2297,13 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) memcpy(&peer->p_endpoint.e_remote, endpoint, size); } if (nvlist_exists_binary(nvl, "preshared-key")) { - const void *key; - - key = nvlist_get_binary(nvl, "preshared-key", &size); + preshared_key = nvlist_get_binary(nvl, "preshared-key", &size); if (size != WG_KEY_SIZE) { err = EINVAL; goto out; } - noise_remote_set_psk(&peer->p_remote, key); + if (!need_insert) + noise_remote_set_psk(peer->p_remote, preshared_key); } if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) { uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval"); @@ -2606,7 +2353,11 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) } } if (need_insert) { - wg_hashtable_peer_insert(&sc->sc_hashtable, peer); + if ((peer->p_remote = noise_remote_alloc(sc->sc_local, peer, + pub_key, preshared_key)) == NULL) + goto out; + TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry); + sc->sc_peers_num++; if (sc->sc_ifp->if_link_state == LINK_STATE_UP) wg_timers_enable(&peer->p_timers); } @@ -2672,19 +2423,18 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) goto out; } - if (noise_local_keys(&sc->sc_local, NULL, private) != 0 || + if (noise_local_keys(sc->sc_local, NULL, private) != 0 || timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) { - struct noise_local *local; struct wg_peer *peer; - struct wg_hashtable *ht = &sc->sc_hashtable; - bool has_identity; if (curve25519_generate_public(public, key)) { /* Peer conflict: remove conflicting peer. */ - if ((peer = wg_peer_lookup(sc, public)) != - NULL) { - wg_hashtable_peer_remove(ht, peer); + struct noise_remote *remote; + if ((remote = noise_remote_lookup(sc->sc_local, + public)) != NULL) { + peer = noise_remote_arg(remote); wg_peer_destroy(peer); + noise_remote_put(remote); } } @@ -2692,21 +2442,12 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) * Set the private key and invalidate all existing * handshakes. */ - local = &sc->sc_local; - noise_local_lock_identity(local); /* Note: we might be removing the private key. */ - has_identity = noise_local_set_private(local, key) == 0; - mtx_lock(&ht->h_mtx); - CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { - noise_remote_precompute(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent( - &peer->p_timers); - noise_remote_expire_current(&peer->p_remote); - } - mtx_unlock(&ht->h_mtx); - cookie_checker_update(&sc->sc_cookie, - has_identity ? public : NULL); - noise_local_unlock_identity(local); + noise_local_private(sc->sc_local, key); + if (noise_local_keys(sc->sc_local, NULL, NULL) == 0) + cookie_checker_update(&sc->sc_cookie, public); + else + cookie_checker_update(&sc->sc_cookie, NULL); } } if (nvlist_exists_number(nvl, "user-cookie")) { @@ -2742,7 +2483,9 @@ out: static int wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) { - uint8_t public_key[WG_KEY_SIZE] = { 0 }, preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 }; + uint8_t public_key[WG_KEY_SIZE] = { 0 }; + uint8_t private_key[WG_KEY_SIZE] = { 0 }; + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 }; nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips; size_t size, peer_count, aip_count, i, j; struct wg_timespec64 ts64; @@ -2761,16 +2504,17 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port); if (sc->sc_socket.so_user_cookie != 0) nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie); - if (sc->sc_local.l_has_identity) { - nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE); + if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) { + nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE); if (wgc_privileged(sc)) - nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE); + nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE); + explicit_bzero(private_key, sizeof(private_key)); } - peer_count = sc->sc_hashtable.h_num_peers; + peer_count = sc->sc_peers_num; if (peer_count) { nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); i = 0; - CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { if (i >= peer_count) panic("peers changed from under us"); @@ -2780,7 +2524,7 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) goto err_peer; } - (void)noise_remote_keys(&peer->p_remote, public_key, preshared_key); + (void)noise_remote_keys(peer->p_remote, public_key, preshared_key); nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key)); if (wgc_privileged(sc)) nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key)); @@ -2941,7 +2685,6 @@ out: static int wg_up(struct wg_softc *sc) { - struct wg_hashtable *ht = &sc->sc_hashtable; struct ifnet *ifp = sc->sc_ifp; struct wg_peer *peer; int rc = EBUSY; @@ -2959,13 +2702,10 @@ wg_up(struct wg_softc *sc) rc = wg_socket_init(sc, sc->sc_socket.so_port); if (rc == 0) { - mtx_lock(&ht->h_mtx); - CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { wg_timers_enable(&peer->p_timers); wg_queue_out(peer); } - mtx_unlock(&ht->h_mtx); - if_link_state_change(sc->sc_ifp, LINK_STATE_UP); } else { ifp->if_drv_flags &= ~IFF_DRV_RUNNING; @@ -2979,7 +2719,6 @@ out: static void wg_down(struct wg_softc *sc) { - struct wg_hashtable *ht = &sc->sc_hashtable; struct ifnet *ifp = sc->sc_ifp; struct wg_peer *peer; @@ -2990,21 +2729,17 @@ wg_down(struct wg_softc *sc) } ifp->if_drv_flags &= ~IFF_DRV_RUNNING; - mtx_lock(&ht->h_mtx); - CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { wg_queue_purge(&peer->p_stage_queue); wg_timers_disable(&peer->p_timers); } - mtx_unlock(&ht->h_mtx); mbufq_drain(&sc->sc_handshake_queue); - mtx_lock(&ht->h_mtx); - CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { - noise_remote_clear(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent(&peer->p_timers); + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + noise_remote_handshake_clear(peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); } - mtx_unlock(&ht->h_mtx); if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN); wg_socket_uninit(sc); @@ -3045,7 +2780,6 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) { struct wg_softc *sc; struct ifnet *ifp; - struct noise_upcall noise_upcall; sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO); sc->sc_ucred = crhold(curthread->td_ucred); @@ -3054,14 +2788,15 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) ifp->if_softc = sc; if_initname(ifp, wgname, unit); - noise_upcall.u_arg = sc; - noise_upcall.u_remote_get = - (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get; - noise_upcall.u_index_set = - (uint32_t (*)(void *, struct noise_remote *))wg_index_set; - noise_upcall.u_index_drop = - (void (*)(void *, uint32_t))wg_index_drop; - noise_local_init(&sc->sc_local, &noise_upcall); + TAILQ_INIT(&sc->sc_peers); + sc->sc_peers_num = 0; + + if ((sc->sc_local = noise_local_alloc(sc)) == NULL) { + free(sc, M_WG); + return ENOMEM; + } + + /* TODO check checker_init return value */ cookie_checker_init(&sc->sc_cookie, ratelimit_zone); sc->sc_socket.so_port = 0; @@ -3071,8 +2806,6 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES); sx_init(&sc->sc_lock, "wg softc lock"); - rw_init(&sc->sc_index_lock, "wg index lock"); - sc->sc_peer_count = 0; sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); GROUPTASK_INIT(&sc->sc_handshake, 0, @@ -3080,8 +2813,6 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation"); crypto_taskq_setup(sc); - wg_hashtable_init(&sc->sc_hashtable); - sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask); wg_aip_init(&sc->sc_aips); if_setmtu(ifp, ETHERMTU - 80); @@ -3108,6 +2839,15 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) } static void +wg_clone_deferred_free(struct noise_local *l) +{ + struct wg_softc *sc = noise_local_arg(l); + + free(sc, M_WG); + atomic_add_int(&clone_count, -1); +} + +static void wg_clone_destroy(struct ifnet *ifp) { struct wg_softc *sc = ifp->if_softc; @@ -3144,24 +2884,19 @@ wg_clone_destroy(struct ifnet *ifp) epoch_drain_callbacks(net_epoch_preempt); sx_xunlock(&sc->sc_lock); sx_destroy(&sc->sc_lock); - rw_destroy(&sc->sc_index_lock); taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake); crypto_taskq_destroy(sc); buf_ring_free(sc->sc_encap_ring, M_WG); buf_ring_free(sc->sc_decap_ring, M_WG); wg_aip_destroy(&sc->sc_aips); - wg_hashtable_destroy(&sc->sc_hashtable); if (cred != NULL) crfree(cred); if_detach(sc->sc_ifp); if_free(sc->sc_ifp); - /* Ensure any local/private keys are cleaned up */ - explicit_bzero(sc, sizeof(*sc)); - free(sc, M_WG); - atomic_add_int(&clone_count, -1); + noise_local_free(sc->sc_local, wg_clone_deferred_free); } static void |