author:    Matt Dunwoodie <ncon@noconroy.net>  2021-04-03 04:05:08 +1100
committer: Matt Dunwoodie <ncon@noconroy.net>  2021-04-13 15:47:30 +1000
commit:    acce302ce645ccc1a3f7f46fb2c399c611e623d8
tree:      0598e731f3bd5b52f6c00b98353346f1f58594c1 /sys/net/if_wg.c
parent:    Add refcnt_take_if_gt()
Use SMR for wg_noise
While the largest change here is to use SMR for wg_noise, this was motivated by other deficiencies in the module. Primarily, the nonce operations should be performed in serial (wg_queue_out, wg_deliver_in) and not in parallel (wg_encap, wg_decap). This also brings in a lock-free encrypt and decrypt path, which is nice.

I suppose other improvements are that the local, remote and keypair structs are now opaque, so there is no more reaching in and fiddling with their internals.

Unfortunately, these changes make abuse of the API easier, such as calling noise_keypair_encrypt on a keypair retrieved with noise_keypair_lookup (instead of noise_keypair_current), as the two perform different checks. Additionally, we have to trust that the nonce passed to noise_keypair_encrypt is non-repeating (retrieved with noise_keypair_nonce_next), and that noise_keypair_nonce_check is used to validate received nonces.

One area that could use a little more adjustment is the *_free functions. They are used to call a function once it is safe to free a parent data structure (one holding a struct noise_{local,remote} *). This is currently used for lifetimes in the system and allows a consumer of wg_noise to opaquely manage lifetimes based on the reference counting of the local, remote and keypair objects. It is fine for now, but maybe revisit later.
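For readers following the diff below, the serial-nonce / parallel-crypto split can be condensed roughly as follows. This is only a sketch based on the wg_noise calls visible in this commit; the real prototypes and return values live in the wg_noise sources, the helper names example_prepare_encap/example_deliver_in_check are invented for illustration, and the errno values are chosen arbitrarily here (wg_queue_out and wg_deliver_in themselves return void).

/*
 * Hypothetical condensation of wg_queue_out(): runs in serial, per peer,
 * before packets are handed to the parallel crypt taskq.
 */
static int
example_prepare_encap(struct wg_peer *peer, struct wg_packet *pkt)
{
	struct noise_keypair *keypair;

	/* Use noise_keypair_current() here, never noise_keypair_lookup();
	 * the two perform different checks. */
	if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) {
		wg_timers_event_want_initiation(&peer->p_timers);
		return ENOENT;
	}
	pkt->p_keypair = noise_keypair_ref(keypair);

	/* Reserve a non-repeating nonce now, while still in serial order,
	 * so reordering on the parallel taskq cannot reuse a counter. */
	if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) {
		noise_keypair_put(pkt->p_keypair);
		noise_keypair_put(keypair);
		return ESTALE;		/* counter exhausted, rekey needed */
	}
	noise_keypair_put(keypair);

	/* wg_encap() later encrypts, possibly in parallel:
	 * noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, buf, len) */
	return 0;
}

/*
 * Hypothetical condensation of the check in wg_deliver_in(): runs back in
 * serial order, after noise_keypair_decrypt() succeeded in wg_decap().
 */
static int
example_deliver_in_check(struct wg_packet *pkt)
{
	/* Only now is the nonce checked against the replay window, so the
	 * completion order of the parallel workers cannot cause false drops. */
	if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
		return EPERM;		/* replayed or stale nonce: drop */
	return 0;
}

The same reference-counted lifetime idea shows up in the teardown paths further down the diff: wg_peer_destroy() passes wg_peer_free to noise_remote_free(), and wg_clone_destroy() passes wg_clone_free to noise_local_free(), so the enclosing wg_peer and wg_softc are only freed once the noise layer says it is safe.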
Diffstat (limited to 'sys/net/if_wg.c')
-rw-r--r--  sys/net/if_wg.c  562
1 file changed, 166 insertions, 396 deletions
diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c
index 807d08a7b6e..dddce93f54f 100644
--- a/sys/net/if_wg.c
+++ b/sys/net/if_wg.c
@@ -49,8 +49,6 @@
#include <netinet/udp.h>
#include <netinet/in_pcb.h>
-#include <crypto/siphash.h>
-
#define DEFAULT_MTU 1420
#define MAX_STAGED_PKT 128
@@ -59,11 +57,6 @@
#define MAX_QUEUED_HANDSHAKES 4096
-#define HASHTABLE_PEER_SIZE (1 << 11)
-#define HASHTABLE_INDEX_SIZE (1 << 13)
-#define MAX_PEERS_PER_IFACE (1 << 20)
-
-#define REKEY_TIMEOUT 5
#define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
#define KEEPALIVE_TIMEOUT 10
#define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
@@ -135,13 +128,6 @@ struct wg_endpoint {
} e_local;
};
-struct wg_index {
- LIST_ENTRY(wg_index) i_entry;
- SLIST_ENTRY(wg_index) i_unused_entry;
- uint32_t i_key;
- struct noise_remote *i_value;
-};
-
struct wg_timers {
/* t_lock is for blocking wg_timers_event_* when setting t_disabled. */
struct rwlock t_lock;
@@ -156,7 +142,6 @@ struct wg_timers {
struct timeout t_persistent_keepalive;
struct mutex t_handshake_mtx;
- struct timespec t_handshake_last_sent; /* nanouptime */
struct timespec t_handshake_complete; /* nanotime */
int t_handshake_retries;
};
@@ -177,7 +162,8 @@ struct wg_packet {
SIMPLEQ_ENTRY(wg_packet) p_serial;
SIMPLEQ_ENTRY(wg_packet) p_parallel;
struct wg_endpoint p_endpoint;
- struct wg_peer *p_peer;
+ struct noise_keypair *p_keypair;
+ uint64_t p_nonce;
struct mbuf *p_mbuf;
int p_mtu;
enum wg_ring_state {
@@ -194,12 +180,11 @@ struct wg_queue {
};
struct wg_peer {
- LIST_ENTRY(wg_peer) p_pubkey_entry;
- TAILQ_ENTRY(wg_peer) p_seq_entry;
+ TAILQ_ENTRY(wg_peer) p_entry;
uint64_t p_id;
struct wg_softc *p_sc;
- struct noise_remote p_remote;
+ struct noise_remote *p_remote;
struct cookie_maker p_cookie;
struct wg_timers p_timers;
@@ -220,9 +205,6 @@ struct wg_peer {
struct wg_queue p_encap_serial;
struct wg_queue p_decap_serial;
- SLIST_HEAD(,wg_index) p_unused_index;
- struct wg_index p_index[3];
-
LIST_HEAD(,wg_aip) p_aip;
SLIST_ENTRY(wg_peer) p_start_list;
@@ -231,13 +213,14 @@ struct wg_peer {
struct wg_softc {
struct ifnet sc_if;
- SIPHASH_KEY sc_secret;
struct rwlock sc_lock;
- struct noise_local sc_local;
+ struct noise_local *sc_local;
struct cookie_checker sc_cookie;
in_port_t sc_udp_port;
int sc_udp_rtable;
+ TAILQ_HEAD(,wg_peer) sc_peers;
+ size_t sc_peer_num;
struct rwlock sc_so_lock;
struct socket *sc_so4;
@@ -251,16 +234,6 @@ struct wg_softc {
struct art_root *sc_aip6;
#endif
- struct rwlock sc_peer_lock;
- size_t sc_peer_num;
- LIST_HEAD(,wg_peer) *sc_peer;
- TAILQ_HEAD(,wg_peer) sc_peer_seq;
- u_long sc_peer_mask;
-
- struct mutex sc_index_mtx;
- LIST_HEAD(,wg_index) *sc_index;
- u_long sc_index_mask;
-
struct task sc_handshake;
struct mbuf_queue sc_handshake_queue;
@@ -273,8 +246,6 @@ struct wg_softc {
struct wg_peer *
wg_peer_create(struct wg_softc *, uint8_t[WG_KEY_SIZE],
uint8_t[WG_KEY_SIZE]);
-struct wg_peer *
- wg_peer_lookup(struct wg_softc *, const uint8_t[WG_KEY_SIZE]);
void wg_peer_destroy(struct wg_peer *);
void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
void wg_peer_set_sockaddr(struct wg_peer *, struct sockaddr *);
@@ -319,7 +290,6 @@ void wg_timers_event_handshake_responded(struct wg_timers *);
void wg_timers_event_handshake_complete(struct wg_timers *);
void wg_timers_event_session_derived(struct wg_timers *);
void wg_timers_event_want_initiation(struct wg_timers *);
-void wg_timers_event_reset_handshake_last_sent(struct wg_timers *);
void wg_timers_run_send_initiation(void *, int);
void wg_timers_run_retry_handshake(void *);
@@ -354,14 +324,6 @@ struct wg_packet *
struct wg_packet *
wg_queue_parallel_dequeue(struct wg_queue *);
-struct noise_remote *
- wg_remote_get(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]);
-uint32_t
- wg_index_set(void *, struct noise_remote *);
-struct noise_remote *
- wg_index_get(void *, uint32_t);
-void wg_index_drop(void *, uint32_t);
-
struct mbuf *
wg_input(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *,
int);
@@ -398,21 +360,15 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE],
uint8_t psk[WG_KEY_SIZE])
{
struct wg_peer *peer;
- uint64_t idx;
rw_assert_wrlock(&sc->sc_lock);
- if (sc->sc_peer_num >= MAX_PEERS_PER_IFACE)
- return NULL;
-
if ((peer = pool_get(&wg_peer_pool, PR_NOWAIT)) == NULL)
return NULL;
peer->p_id = peer_counter++;
peer->p_sc = sc;
- noise_remote_init(&peer->p_remote, public, &sc->sc_local);
- noise_remote_set_psk(&peer->p_remote, psk);
cookie_maker_init(&peer->p_cookie, public);
wg_timers_init(&peer->p_timers);
@@ -433,51 +389,28 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE],
wg_queue_init(&peer->p_encap_serial);
wg_queue_init(&peer->p_decap_serial);
- SLIST_INIT(&peer->p_unused_index);
- SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0],
- i_unused_entry);
- SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1],
- i_unused_entry);
- SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2],
- i_unused_entry);
-
LIST_INIT(&peer->p_aip);
peer->p_start_onlist = 0;
- idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE);
- idx &= sc->sc_peer_mask;
+ if ((peer->p_remote = noise_remote_alloc(sc->sc_local, peer, public, psk)) == NULL) {
+ pool_put(&wg_peer_pool, peer);
+ return NULL;
+ }
- rw_enter_write(&sc->sc_peer_lock);
- LIST_INSERT_HEAD(&sc->sc_peer[idx], peer, p_pubkey_entry);
- TAILQ_INSERT_TAIL(&sc->sc_peer_seq, peer, p_seq_entry);
+ DPRINTF(sc, "Peer %llu created\n", peer->p_id);
+ TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
sc->sc_peer_num++;
- rw_exit_write(&sc->sc_peer_lock);
- DPRINTF(sc, "Peer %llu created\n", peer->p_id);
return peer;
}
-struct wg_peer *
-wg_peer_lookup(struct wg_softc *sc, const uint8_t public[WG_KEY_SIZE])
+void
+wg_peer_free(struct noise_remote *r)
{
- uint8_t peer_key[WG_KEY_SIZE];
- struct wg_peer *peer;
- uint64_t idx;
-
- idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE);
- idx &= sc->sc_peer_mask;
-
- rw_enter_read(&sc->sc_peer_lock);
- LIST_FOREACH(peer, &sc->sc_peer[idx], p_pubkey_entry) {
- noise_remote_keys(&peer->p_remote, peer_key, NULL);
- if (timingsafe_bcmp(peer_key, public, WG_KEY_SIZE) == 0)
- goto done;
- }
- peer = NULL;
-done:
- rw_exit_read(&sc->sc_peer_lock);
- return peer;
+ struct wg_peer *peer;
+ peer = noise_remote_arg(r);
+ pool_put(&wg_peer_pool, peer);
}
void
@@ -488,46 +421,16 @@ wg_peer_destroy(struct wg_peer *peer)
rw_assert_wrlock(&sc->sc_lock);
- /*
- * Remove peer from the pubkey hashtable and disable all timeouts.
- * After this, and flushing wg_handshake_taskq, then no more handshakes
- * can be started.
- */
- rw_enter_write(&sc->sc_peer_lock);
- LIST_REMOVE(peer, p_pubkey_entry);
- TAILQ_REMOVE(&sc->sc_peer_seq, peer, p_seq_entry);
+ TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
sc->sc_peer_num--;
- rw_exit_write(&sc->sc_peer_lock);
-
- wg_timers_disable(&peer->p_timers);
-
- taskq_barrier(wg_handshake_taskq);
- /*
- * Now we drop all allowed ips, to drop all outgoing packets to the
- * peer. Then drop all the indexes to drop all incoming packets to the
- * peer. Then we can flush if_snd, wg_crypt_taskq and then nettq to
- * ensure no more references to the peer exist.
- */
LIST_FOREACH_SAFE(aip, &peer->p_aip, a_entry, taip)
wg_aip_remove(sc, peer, &aip->a_data);
- noise_remote_clear(&peer->p_remote);
-
- NET_LOCK();
- while (!ifq_empty(&sc->sc_if.if_snd)) {
- NET_UNLOCK();
- tsleep_nsec(sc, PWAIT, "wg_ifq", 1000);
- NET_LOCK();
- }
- NET_UNLOCK();
-
- taskq_barrier(wg_crypt_taskq);
- taskq_barrier(net_tq(sc->sc_if.if_index));
+ wg_timers_disable(&peer->p_timers);
+ noise_remote_free(peer->p_remote, wg_peer_free);
DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id);
- explicit_bzero(peer, sizeof(*peer));
- pool_put(&wg_peer_pool, peer);
}
void
@@ -917,8 +820,6 @@ wg_tag_get(struct mbuf *m)
* tx: response, rx: response
* wg_timers_event_want_initiation:
* tx: data failed, old keys expiring
- * wg_timers_event_reset_handshake_last_sent:
- * anytime we may immediately want a new handshake
*/
void
wg_timers_init(struct wg_timers *t)
@@ -987,28 +888,6 @@ wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time)
mtx_leave(&t->t_handshake_mtx);
}
-int
-wg_timers_expired_handshake_last_sent(struct wg_timers *t)
-{
- struct timespec uptime;
- struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 };
-
- getnanouptime(&uptime);
- timespecadd(&t->t_handshake_last_sent, &expire, &expire);
- return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
-}
-
-int
-wg_timers_check_handshake_last_sent(struct wg_timers *t)
-{
- int ret;
- mtx_enter(&t->t_handshake_mtx);
- if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT)
- getnanouptime(&t->t_handshake_last_sent);
- mtx_leave(&t->t_handshake_mtx);
- return ret;
-}
-
void
wg_timers_event_data_sent(struct wg_timers *t)
{
@@ -1070,14 +949,6 @@ wg_timers_event_handshake_initiated(struct wg_timers *t)
}
void
-wg_timers_event_handshake_responded(struct wg_timers *t)
-{
- mtx_enter(&t->t_handshake_mtx);
- getnanouptime(&t->t_handshake_last_sent);
- mtx_leave(&t->t_handshake_mtx);
-}
-
-void
wg_timers_event_handshake_complete(struct wg_timers *t)
{
rw_enter_read(&t->t_lock);
@@ -1111,21 +982,13 @@ wg_timers_event_want_initiation(struct wg_timers *t)
}
void
-wg_timers_event_reset_handshake_last_sent(struct wg_timers *t)
-{
- mtx_enter(&t->t_handshake_mtx);
- t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1);
- mtx_leave(&t->t_handshake_mtx);
-}
-
-void
wg_timers_run_send_initiation(void *_t, int is_retry)
{
struct wg_timers *t = _t;
struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
if (!is_retry)
t->t_handshake_retries = 0;
- if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT)
+ if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
task_add(wg_handshake_taskq, &peer->p_send_initiation);
}
@@ -1227,15 +1090,13 @@ wg_send_initiation(void *_peer)
struct wg_peer *peer = _peer;
struct wg_pkt_initiation pkt;
- if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT)
+ if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es,
+ pkt.ets) != 0)
return;
DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n",
peer->p_id);
- if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es,
- pkt.ets) != 0)
- return;
pkt.t = WG_PKT_INITIATION;
cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
sizeof(pkt)-sizeof(pkt.m));
@@ -1251,16 +1112,14 @@ wg_send_response(struct wg_peer *peer)
DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n",
peer->p_id);
- if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx,
+ if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
pkt.ue, pkt.en) != 0)
return;
- if (noise_remote_begin_session(&peer->p_remote) != 0)
- return;
wg_timers_event_session_derived(&peer->p_timers);
pkt.t = WG_PKT_RESPONSE;
cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
sizeof(pkt)-sizeof(pkt.m));
- wg_timers_event_handshake_responded(&peer->p_timers);
+
wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
}
@@ -1302,7 +1161,6 @@ wg_send_keepalive(void *_peer)
}
pkt->p_mbuf = m;
- pkt->p_peer = peer;
pkt->p_mtu = 0;
m->m_pkthdr.ph_cookie = pkt;
@@ -1318,7 +1176,7 @@ void
wg_peer_clear_secrets(void *_peer)
{
struct wg_peer *peer = _peer;
- noise_remote_clear(&peer->p_remote);
+ noise_remote_keypairs_clear(peer->p_remote);
}
void
@@ -1330,7 +1188,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
struct wg_endpoint *e;
struct wg_peer *peer;
struct mbuf *m;
- struct noise_remote *remote;
+ struct noise_keypair *keypair;
+ struct noise_remote *remote = NULL;
int res, underload = 0;
static struct timeval wg_last_underload; /* microuptime */
@@ -1371,13 +1230,13 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
panic("unexpected response: %d\n", res);
}
- if (noise_consume_initiation(&sc->sc_local, &remote,
+ if (noise_consume_initiation(sc->sc_local, &remote,
init->s_idx, init->ue, init->es, init->ets) != 0) {
DPRINTF(sc, "Invalid handshake initiation\n");
goto error;
}
- peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+ peer = noise_remote_arg(remote);
DPRINTF(sc, "Receiving handshake initiation from peer %llu\n",
peer->p_id);
@@ -1405,45 +1264,39 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
panic("unexpected response: %d\n", res);
}
- if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) {
- DPRINTF(sc, "Unknown handshake response\n");
- goto error;
- }
-
- peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
-
- if (noise_consume_response(remote, resp->s_idx, resp->r_idx,
- resp->ue, resp->en) != 0) {
+ if (noise_consume_response(sc->sc_local, &remote,
+ resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
DPRINTF(sc, "Invalid handshake response\n");
goto error;
}
- DPRINTF(sc, "Receiving handshake response from peer %llu\n",
- peer->p_id);
+ peer = noise_remote_arg(remote);
+ DPRINTF(sc, "Receiving handshake response from peer %llu\n", peer->p_id);
wg_peer_set_endpoint(peer, e);
- if (noise_remote_begin_session(&peer->p_remote) == 0) {
- wg_timers_event_session_derived(&peer->p_timers);
- wg_timers_event_handshake_complete(&peer->p_timers);
- }
+ wg_timers_event_session_derived(&peer->p_timers);
+ wg_timers_event_handshake_complete(&peer->p_timers);
break;
case WG_PKT_COOKIE:
cook = mtod(m, struct wg_pkt_cookie *);
- if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) {
- DPRINTF(sc, "Unknown cookie index\n");
- goto error;
+ if ((remote = noise_remote_index_lookup(sc->sc_local, cook->r_idx)) == NULL) {
+ if ((keypair = noise_keypair_lookup(sc->sc_local, cook->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown cookie index\n");
+ goto error;
+ }
+ remote = noise_keypair_remote(keypair);
+ noise_keypair_put(keypair);
}
- peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+ peer = noise_remote_arg(remote);
if (cookie_maker_consume_payload(&peer->p_cookie,
- cook->nonce, cook->ec) != 0) {
+ cook->nonce, cook->ec) == 0)
+ DPRINTF(sc, "Receiving cookie response\n");
+ else
DPRINTF(sc, "Could not decrypt cookie response\n");
- goto error;
- }
- DPRINTF(sc, "Receiving cookie response\n");
goto error;
default:
panic("invalid packet in handshake queue");
@@ -1453,6 +1306,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
wg_peer_counters_add(peer, 0, m->m_pkthdr.len);
error:
+ if (remote != NULL)
+ noise_remote_put(remote);
m_freem(m);
pool_put(&wg_packet_pool, pkt);
}
@@ -1484,12 +1339,16 @@ wg_handshake_worker(void *_sc)
void
wg_encap(struct wg_softc *sc, struct wg_packet *pkt)
{
- struct wg_pkt_data data;
+ struct noise_remote *remote;
+ struct wg_pkt_data *data;
struct wg_peer *peer;
struct mbuf *m, *ms;
- int res, pad, off, len;
+ uint32_t idx;
+ int pad, off, len;
- peer = pkt->p_peer;
+ remote = noise_keypair_remote(pkt->p_keypair);
+ peer = noise_remote_arg(remote);
+ noise_remote_put(remote);
m = pkt->p_mbuf;
/* Calculate what padding we need to add then limit it to the mtu of
@@ -1506,7 +1365,7 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt)
bzero(mtod(ms, uint8_t *) + off, pad);
}
- /* TODO teach noise_remote_encrypt about mbufs. Currently we have to
+ /* TODO teach noise_keypair_encrypt about mbufs. Currently we have to
* resort to m_pullup to create an encryptable buffer. */
len = m->m_pkthdr.len;
if (m_makespace(m, len, NOISE_AUTHTAG_LEN, &off) == NULL)
@@ -1515,30 +1374,20 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt)
goto error;
/* Do encryption */
- res = noise_remote_encrypt(&peer->p_remote, &data.r_idx, &data.nonce,
- mtod(m, uint8_t *), len);
-
- if (__predict_false(res == EINVAL)) {
- goto error_free;
- } else if (__predict_false(res == ESTALE)) {
- wg_timers_event_want_initiation(&peer->p_timers);
- } else if (__predict_false(res != 0)) {
- panic("unexpected result: %d\n", res);
- }
+ noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, mtod(m, uint8_t *), len);
/* A packet with length 0 is a keepalive packet */
if (__predict_false(len == 0))
- DPRINTF(sc, "Sending keepalive packet to peer %llu\n",
- peer->p_id);
+ DPRINTF(sc, "Sending keepalive packet to peer %llu\n", peer->p_id);
/* Put header into packet */
if ((m = m_prepend(m, sizeof(struct wg_pkt_data), M_NOWAIT)) == NULL)
goto error;
- data.t = WG_PKT_DATA;
- data.nonce = htole64(data.nonce);
- memcpy(mtod(m, void *), &data, sizeof(struct wg_pkt_data));
-
+ data = mtod(m, struct wg_pkt_data *);
+ data->t = WG_PKT_DATA;
+ data->r_idx = idx;
+ data->nonce = htole64(pkt->p_nonce);
/*
* We would count ifc_opackets, ifc_obytes of m here, except if_snd
* already does that for us, so no need to worry about it.
@@ -1564,6 +1413,7 @@ error:
void
wg_decap(struct wg_softc *sc, struct wg_packet *pkt)
{
+ struct noise_remote *remote;
struct wg_pkt_data data;
struct wg_peer *peer, *allowed_peer;
struct mbuf *m;
@@ -1571,7 +1421,9 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt)
struct ip6_hdr *ip6;
int res, len;
- peer = pkt->p_peer;
+ remote = noise_keypair_remote(pkt->p_keypair);
+ peer = noise_remote_arg(remote);
+ noise_remote_put(remote);
m = pkt->p_mbuf;
len = m->m_pkthdr.len;
@@ -1579,28 +1431,25 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt)
memcpy(&data, mtod(m, void *), sizeof(struct wg_pkt_data));
m_adj(m, sizeof(struct wg_pkt_data));
- /* TODO teach noise_remote_decrypt about mbufs. Currently we have to
+ /* TODO teach noise_keypair_decrypt about mbufs. Currently we have to
* resort to m_pullup to create an decryptable buffer. */
if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) {
goto error;
}
- res = noise_remote_decrypt(&peer->p_remote, data.r_idx,
- le64toh(data.nonce), mtod(m, void *), m->m_pkthdr.len);
+ pkt->p_nonce = letoh64(data.nonce);
+ res = noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, mtod(m, void *), m->m_pkthdr.len);
if (__predict_false(res == EINVAL)) {
goto error_free;
} else if (__predict_false(res == ECONNRESET)) {
wg_timers_event_handshake_complete(&peer->p_timers);
- } else if (__predict_false(res == ESTALE)) {
- wg_timers_event_want_initiation(&peer->p_timers);
} else if (__predict_false(res != 0)) {
panic("unexpected response: %d\n", res);
}
m_adj(m, -NOISE_AUTHTAG_LEN);
- wg_peer_set_endpoint(peer, &pkt->p_endpoint);
wg_peer_counters_add(peer, 0, len);
counters_pkt(sc->sc_if.if_counters, ifc_ipackets, ifc_ibytes,
@@ -1735,8 +1584,11 @@ wg_deliver_out(void *_peer)
m_freem(m);
}
+ noise_keypair_put(pkt->p_keypair);
pool_put(&wg_packet_pool, pkt);
}
+ if (noise_keep_key_fresh_send(peer->p_remote))
+ wg_timers_event_want_initiation(&peer->p_timers);
}
void
@@ -1750,14 +1602,20 @@ wg_deliver_in(void *_peer)
while ((pkt = wg_queue_serial_dequeue(&peer->p_decap_serial)) != NULL) {
m = pkt->p_mbuf;
if (pkt->p_state == WG_PACKET_CRYPTED) {
+ if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0) {
+ m_freem(m);
+ goto put;
+ }
+
wg_timers_event_any_authenticated_packet_received(
&peer->p_timers);
wg_timers_event_any_authenticated_packet_traversal(
&peer->p_timers);
+ wg_peer_set_endpoint(peer, &pkt->p_endpoint);
if (m->m_pkthdr.len == 0) {
m_freem(m);
- continue;
+ goto put;
}
#if NBPFILTER > 0
@@ -1778,13 +1636,15 @@ wg_deliver_in(void *_peer)
NET_UNLOCK();
wg_timers_event_data_received(&peer->p_timers);
-
} else {
m_freem(m);
}
-
+put:
+ noise_keypair_put(pkt->p_keypair);
pool_put(&wg_packet_pool, pkt);
}
+ if (noise_keep_key_fresh_recv(peer->p_remote))
+ wg_timers_event_want_initiation(&peer->p_timers);
}
void
@@ -1808,6 +1668,7 @@ wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_pack
} else {
mtx_leave(&serial->q_mtx);
m_freem(pkt->p_mbuf);
+ noise_keypair_put(pkt->p_keypair);
pool_put(&wg_packet_pool, pkt);
return ENOBUFS;
}
@@ -1837,11 +1698,12 @@ wg_queue_in(struct wg_softc *sc, struct wg_peer *peer, struct wg_packet *pkt)
void
wg_queue_out(struct wg_softc *sc, struct wg_peer *peer)
{
+ struct noise_keypair *keypair;
struct wg_packet *pkt;
struct mbuf_list ml;
struct mbuf *m;
- if (noise_remote_ready(&peer->p_remote) != 0) {
+ if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) {
wg_timers_event_want_initiation(&peer->p_timers);
return;
}
@@ -1850,11 +1712,21 @@ wg_queue_out(struct wg_softc *sc, struct wg_peer *peer)
while ((m = ml_dequeue(&ml)) != NULL) {
pkt = m->m_pkthdr.ph_cookie;
+ pkt->p_keypair = noise_keypair_ref(keypair);
+
+ if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) {
+ ml_purge(&ml);
+ pool_put(&wg_packet_pool, pkt);
+ m_freem(m);
+ break;
+ }
if (wg_queue_both(&sc->sc_encap_parallel, &peer->p_encap_serial, pkt) != 0)
counters_inc(sc->sc_if.if_counters, ifc_oqdrops);
}
+ noise_keypair_put(keypair);
+
task_add(wg_crypt_taskq, &sc->sc_encap);
}
@@ -1886,91 +1758,6 @@ wg_queue_parallel_dequeue(struct wg_queue *parallel)
return pkt;
}
-struct noise_remote *
-wg_remote_get(void *_sc, uint8_t public[NOISE_PUBLIC_KEY_LEN])
-{
- struct wg_peer *peer;
- struct wg_softc *sc = _sc;
- if ((peer = wg_peer_lookup(sc, public)) == NULL)
- return NULL;
- return &peer->p_remote;
-}
-
-uint32_t
-wg_index_set(void *_sc, struct noise_remote *remote)
-{
- struct wg_peer *peer;
- struct wg_softc *sc = _sc;
- struct wg_index *index, *iter;
- uint32_t key;
-
- /*
- * We can modify this without a lock as wg_index_set, wg_index_drop are
- * guaranteed to be serialised (per remote).
- */
- peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
- index = SLIST_FIRST(&peer->p_unused_index);
- KASSERT(index != NULL);
- SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry);
-
- index->i_value = remote;
-
- mtx_enter(&sc->sc_index_mtx);
-assign_id:
- key = index->i_key = arc4random();
- key &= sc->sc_index_mask;
- LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
- if (iter->i_key == index->i_key)
- goto assign_id;
-
- LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry);
-
- mtx_leave(&sc->sc_index_mtx);
-
- /* Likewise, no need to lock for index here. */
- return index->i_key;
-}
-
-struct noise_remote *
-wg_index_get(void *_sc, uint32_t key0)
-{
- struct wg_softc *sc = _sc;
- struct wg_index *iter;
- struct noise_remote *remote = NULL;
- uint32_t key = key0 & sc->sc_index_mask;
-
- mtx_enter(&sc->sc_index_mtx);
- LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
- if (iter->i_key == key0) {
- remote = iter->i_value;
- break;
- }
- mtx_leave(&sc->sc_index_mtx);
- return remote;
-}
-
-void
-wg_index_drop(void *_sc, uint32_t key0)
-{
- struct wg_softc *sc = _sc;
- struct wg_index *iter;
- struct wg_peer *peer = NULL;
- uint32_t key = key0 & sc->sc_index_mask;
-
- mtx_enter(&sc->sc_index_mtx);
- LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
- if (iter->i_key == key0) {
- LIST_REMOVE(iter, i_entry);
- break;
- }
- mtx_leave(&sc->sc_index_mtx);
-
- /* We expect an index match */
- KASSERT(iter != NULL);
- peer = CONTAINER_OF(iter->i_value, struct wg_peer, p_remote);
- SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry);
-}
-
struct mbuf *
wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6,
void *_uh, int hlen)
@@ -1978,7 +1765,6 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6,
struct wg_pkt_data *data;
struct noise_remote *remote;
struct wg_packet *pkt;
- struct wg_peer *peer;
struct wg_softc *sc = _sc;
struct udphdr *uh = _uh;
@@ -2046,13 +1832,13 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6,
}
data = mtod(m, struct wg_pkt_data *);
- if ((remote = wg_index_get(sc, data->r_idx)) == NULL)
+ if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
goto error_mbuf;
- peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
- pkt->p_peer = peer;
+ remote = noise_keypair_remote(pkt->p_keypair);
pkt->p_mbuf = m;
- wg_queue_in(sc, peer, pkt);
+ wg_queue_in(sc, noise_remote_arg(remote), pkt);
+ noise_remote_put(remote);
} else {
counters_inc(sc->sc_if.if_counters, ifc_ierrors);
goto error_mbuf;
@@ -2090,7 +1876,6 @@ wg_qstart(struct ifqueue *ifq)
peer = t->t_peer;
pkt->p_mbuf = m;
- pkt->p_peer = peer;
pkt->p_mtu = t->t_mtu;
m->m_pkthdr.ph_cookie = pkt;
@@ -2205,12 +1990,14 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
struct wg_peer *peer, *tpeer;
struct wg_aip *aip, *taip;
+ struct noise_remote *remote;
+
in_port_t port;
int rtable;
uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
size_t i, j;
- int ret, has_identity;
+ int ret;
if ((ret = suser(curproc)) != 0)
return ret;
@@ -2222,27 +2009,23 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
goto error;
if (iface_o.i_flags & WG_INTERFACE_REPLACE_PEERS)
- TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer)
+ TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
wg_peer_destroy(peer);
if (iface_o.i_flags & WG_INTERFACE_HAS_PRIVATE &&
- (noise_local_keys(&sc->sc_local, NULL, private) ||
+ (noise_local_keys(sc->sc_local, NULL, private) ||
timingsafe_bcmp(private, iface_o.i_private, WG_KEY_SIZE))) {
if (curve25519_generate_public(public, iface_o.i_private)) {
- if ((peer = wg_peer_lookup(sc, public)) != NULL)
- wg_peer_destroy(peer);
- }
- noise_local_lock_identity(&sc->sc_local);
- has_identity = noise_local_set_private(&sc->sc_local,
- iface_o.i_private);
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
- noise_remote_precompute(&peer->p_remote);
- wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
- noise_remote_expire_current(&peer->p_remote);
+ if ((remote = noise_remote_lookup(sc->sc_local, public)) != NULL) {
+ wg_peer_destroy(noise_remote_arg(remote));
+ noise_remote_put(remote);
+ }
}
- cookie_checker_update(&sc->sc_cookie,
- has_identity == 0 ? public : NULL);
- noise_local_unlock_identity(&sc->sc_local);
+ noise_local_private(sc->sc_local, iface_o.i_private);
+ if (noise_local_keys(sc->sc_local, public, NULL) == 0)
+ cookie_checker_update(&sc->sc_cookie, public);
+ else
+ cookie_checker_update(&sc->sc_cookie, NULL);
}
if (iface_o.i_flags & WG_INTERFACE_HAS_PORT)
@@ -2256,7 +2039,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
rtable = sc->sc_udp_rtable;
if (port != sc->sc_udp_port || rtable != sc->sc_udp_rtable) {
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry)
+ TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
wg_peer_clear_src(peer);
if (sc->sc_if.if_flags & IFF_RUNNING)
@@ -2287,12 +2070,12 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
}
/* Get local public and check that peer key doesn't match */
- if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
+ if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
bcmp(public, peer_o.p_public, WG_KEY_SIZE) == 0)
goto next_peer;
/* Lookup peer, or create if it doesn't exist */
- if ((peer = wg_peer_lookup(sc, peer_o.p_public)) == NULL) {
+ if ((remote = noise_remote_lookup(sc->sc_local, peer_o.p_public)) == NULL) {
/* If we want to delete, no need creating a new one.
* Also, don't create a new one if we only want to
* update. */
@@ -2307,6 +2090,9 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
ret = ENOMEM;
goto error;
}
+ } else {
+ peer = noise_remote_arg(remote);
+ noise_remote_put(remote);
}
/* Remove peer and continue if specified */
@@ -2319,7 +2105,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
wg_peer_set_sockaddr(peer, &peer_o.p_sa);
if (peer_o.p_flags & WG_PEER_HAS_PSK)
- noise_remote_set_psk(&peer->p_remote, peer_o.p_psk);
+ noise_remote_set_psk(peer->p_remote, peer_o.p_psk);
if (peer_o.p_flags & WG_PEER_HAS_PKA)
wg_timers_set_persistent_keepalive(&peer->p_timers,
@@ -2394,7 +2180,7 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data)
if (!is_suser)
goto copy_out_iface;
- if (noise_local_keys(&sc->sc_local, iface_o.i_public,
+ if (noise_local_keys(sc->sc_local, iface_o.i_public,
iface_o.i_private) == 0) {
iface_o.i_flags |= WG_INTERFACE_HAS_PUBLIC;
iface_o.i_flags |= WG_INTERFACE_HAS_PRIVATE;
@@ -2407,12 +2193,12 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data)
peer_count = 0;
peer_p = &iface_p->i_peers[0];
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
bzero(&peer_o, sizeof(peer_o));
peer_o.p_flags = WG_PEER_HAS_PUBLIC;
peer_o.p_protocol_version = 1;
- if (noise_remote_keys(&peer->p_remote, peer_o.p_public,
+ if (noise_remote_keys(peer->p_remote, peer_o.p_public,
peer_o.p_psk) == 0)
peer_o.p_flags |= WG_PEER_HAS_PSK;
@@ -2530,7 +2316,7 @@ wg_up(struct wg_softc *sc)
*/
ret = wg_bind(sc, &sc->sc_udp_port, &sc->sc_udp_rtable);
if (ret == 0) {
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry)
+ TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
wg_timers_enable(&peer->p_timers);
}
rw_exit_write(&sc->sc_lock);
@@ -2558,15 +2344,15 @@ wg_down(struct wg_softc *sc)
* that isn't granularly locked.
*/
rw_enter_read(&sc->sc_lock);
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
mq_purge(&peer->p_stage_queue);
wg_timers_disable(&peer->p_timers);
}
taskq_barrier(wg_handshake_taskq);
- TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
- noise_remote_clear(&peer->p_remote);
- wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
+ TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
+ noise_remote_handshake_clear(peer->p_remote);
+ noise_remote_keypairs_clear(peer->p_remote);
}
wg_unbind(sc);
@@ -2579,7 +2365,6 @@ wg_clone_create(struct if_clone *ifc, int unit)
{
struct ifnet *ifp;
struct wg_softc *sc;
- struct noise_upcall local_upcall;
KERNEL_ASSERT_LOCKED();
@@ -2604,20 +2389,15 @@ wg_clone_create(struct if_clone *ifc, int unit)
if ((sc = malloc(sizeof(*sc), M_DEVBUF, M_NOWAIT | M_ZERO)) == NULL)
goto ret_00;
- local_upcall.u_arg = sc;
- local_upcall.u_remote_get = wg_remote_get;
- local_upcall.u_index_set = wg_index_set;
- local_upcall.u_index_drop = wg_index_drop;
-
- /* sc_if is initialised after everything else */
- arc4random_buf(&sc->sc_secret, sizeof(sc->sc_secret));
-
rw_init(&sc->sc_lock, "wg");
- noise_local_init(&sc->sc_local, &local_upcall);
- if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0)
+ if ((sc->sc_local = noise_local_alloc(sc)) == NULL)
goto ret_01;
+ if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0)
+ goto ret_02;
sc->sc_udp_port = 0;
sc->sc_udp_rtable = 0;
+ TAILQ_INIT(&sc->sc_peers);
+ sc->sc_peer_num = 0;
rw_init(&sc->sc_so_lock, "wg_so");
sc->sc_so4 = NULL;
@@ -2627,24 +2407,11 @@ wg_clone_create(struct if_clone *ifc, int unit)
sc->sc_aip_num = 0;
if ((sc->sc_aip4 = art_alloc(0, 32, 0)) == NULL)
- goto ret_02;
+ goto ret_03;
#ifdef INET6
if ((sc->sc_aip6 = art_alloc(0, 128, 0)) == NULL)
- goto ret_03;
-#endif
-
- rw_init(&sc->sc_peer_lock, "wg_peer");
- sc->sc_peer_num = 0;
- if ((sc->sc_peer = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF,
- M_NOWAIT, &sc->sc_peer_mask)) == NULL)
goto ret_04;
-
- TAILQ_INIT(&sc->sc_peer_seq);
-
- mtx_init(&sc->sc_index_mtx, IPL_NET);
- if ((sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF,
- M_NOWAIT, &sc->sc_index_mask)) == NULL)
- goto ret_05;
+#endif
task_set(&sc->sc_handshake, wg_handshake_worker, sc);
mq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES, IPL_NET);
@@ -2684,22 +2451,42 @@ wg_clone_create(struct if_clone *ifc, int unit)
DPRINTF(sc, "Interface created\n");
return 0;
-ret_05:
- hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF);
-ret_04:
+
#ifdef INET6
free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6));
-ret_03:
+ret_04:
#endif
free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4));
-ret_02:
+ret_03:
cookie_checker_deinit(&sc->sc_cookie);
+ret_02:
+ noise_local_put(sc->sc_local);
ret_01:
free(sc, M_DEVBUF, sizeof(*sc));
ret_00:
return ENOBUFS;
}
+void
+wg_clone_free(struct noise_local *l)
+{
+ struct wg_softc *sc = noise_local_arg(l);
+ wg_counter--;
+ if (wg_counter == 0) {
+ KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL);
+ taskq_destroy(wg_handshake_taskq);
+ taskq_destroy(wg_crypt_taskq);
+ wg_handshake_taskq = NULL;
+ wg_crypt_taskq = NULL;
+ }
+#ifdef INET6
+ free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6));
+#endif
+ free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4));
+ cookie_checker_deinit(&sc->sc_cookie);
+ free(sc, M_DEVBUF, sizeof(*sc));
+}
+
int
wg_clone_destroy(struct ifnet *ifp)
{
@@ -2709,33 +2496,16 @@ wg_clone_destroy(struct ifnet *ifp)
KERNEL_ASSERT_LOCKED();
rw_enter_write(&sc->sc_lock);
- TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer)
+ TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
wg_peer_destroy(peer);
+
rw_exit_write(&sc->sc_lock);
wg_unbind(sc);
if_detach(ifp);
- wg_counter--;
- if (wg_counter == 0) {
- KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL);
- taskq_destroy(wg_handshake_taskq);
- taskq_destroy(wg_crypt_taskq);
- wg_handshake_taskq = NULL;
- wg_crypt_taskq = NULL;
- }
-
+ noise_local_free(sc->sc_local, wg_clone_free);
DPRINTF(sc, "Destroyed interface\n");
-
- hashfree(sc->sc_index, HASHTABLE_INDEX_SIZE, M_DEVBUF);
- hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF);
-#ifdef INET6
- free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6));
-#endif
- free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4));
- cookie_checker_deinit(&sc->sc_cookie);
- noise_local_deinit(&sc->sc_local);
- free(sc, M_DEVBUF, sizeof(*sc));
return 0;
}