From acce302ce645ccc1a3f7f46fb2c399c611e623d8 Mon Sep 17 00:00:00 2001 From: Matt Dunwoodie Date: Sat, 3 Apr 2021 04:05:08 +1100 Subject: Use SMR for wg_noise While the largest change here is to use SMR for wg_noise, this was motivated by other deficiencies in the module. Primarily, the nonce operations should be performed in serial (wg_queue_out, wg_deliver_in) and not parallel (wg_encap, wg_decap). This also brings in a lock-free encrypt and decrypt path, which is nice. I suppose other improvements are that local, remote and keypair structs are opaque, so no more reaching in and fiddling with things. Unfortunately, these changes make abuse of the API easier (such as calling noise_keypair_encrypt on a keypair retrieved with noise_keypair_lookup (instead of noise_keypair_current) as they have different checks). Additionally, we have to trust that the nonce passed to noise_keypair_encrypt is non repeating (retrieved with noise_keypair_nonce_next), and noise_keypair_nonce_check is valid on received nonces. One area that could use a little bit more adjustment is the *_free functions. They are used to call a function once it is safe to free a parent datastructure (one holding struct noise_{local,remote} *). This is currently used for lifetimes in the system and allows a consumer of wg_noise to opaquely manage lifetimes based on the reference counting of noise, remote and keypair. It is fine for now, but maybe revisit later. --- sys/net/if_wg.c | 562 ++++++------------ sys/net/wg_noise.c | 1644 +++++++++++++++++++++++++++------------------------- sys/net/wg_noise.h | 200 +++---- 3 files changed, 1091 insertions(+), 1315 deletions(-) diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c index 807d08a7b6e..dddce93f54f 100644 --- a/sys/net/if_wg.c +++ b/sys/net/if_wg.c @@ -49,8 +49,6 @@ #include #include -#include - #define DEFAULT_MTU 1420 #define MAX_STAGED_PKT 128 @@ -59,11 +57,6 @@ #define MAX_QUEUED_HANDSHAKES 4096 -#define HASHTABLE_PEER_SIZE (1 << 11) -#define HASHTABLE_INDEX_SIZE (1 << 13) -#define MAX_PEERS_PER_IFACE (1 << 20) - -#define REKEY_TIMEOUT 5 #define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */ #define KEEPALIVE_TIMEOUT 10 #define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT) @@ -135,13 +128,6 @@ struct wg_endpoint { } e_local; }; -struct wg_index { - LIST_ENTRY(wg_index) i_entry; - SLIST_ENTRY(wg_index) i_unused_entry; - uint32_t i_key; - struct noise_remote *i_value; -}; - struct wg_timers { /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */ struct rwlock t_lock; @@ -156,7 +142,6 @@ struct wg_timers { struct timeout t_persistent_keepalive; struct mutex t_handshake_mtx; - struct timespec t_handshake_last_sent; /* nanouptime */ struct timespec t_handshake_complete; /* nanotime */ int t_handshake_retries; }; @@ -177,7 +162,8 @@ struct wg_packet { SIMPLEQ_ENTRY(wg_packet) p_serial; SIMPLEQ_ENTRY(wg_packet) p_parallel; struct wg_endpoint p_endpoint; - struct wg_peer *p_peer; + struct noise_keypair *p_keypair; + uint64_t p_nonce; struct mbuf *p_mbuf; int p_mtu; enum wg_ring_state { @@ -194,12 +180,11 @@ struct wg_queue { }; struct wg_peer { - LIST_ENTRY(wg_peer) p_pubkey_entry; - TAILQ_ENTRY(wg_peer) p_seq_entry; + TAILQ_ENTRY(wg_peer) p_entry; uint64_t p_id; struct wg_softc *p_sc; - struct noise_remote p_remote; + struct noise_remote *p_remote; struct cookie_maker p_cookie; struct wg_timers p_timers; @@ -220,9 +205,6 @@ struct wg_peer { struct wg_queue p_encap_serial; struct wg_queue p_decap_serial; - SLIST_HEAD(,wg_index) p_unused_index; - struct wg_index p_index[3]; - LIST_HEAD(,wg_aip) p_aip; SLIST_ENTRY(wg_peer) p_start_list; @@ -231,13 +213,14 @@ struct wg_peer { struct wg_softc { struct ifnet sc_if; - SIPHASH_KEY sc_secret; struct rwlock sc_lock; - struct noise_local sc_local; + struct noise_local *sc_local; struct cookie_checker sc_cookie; in_port_t sc_udp_port; int sc_udp_rtable; + TAILQ_HEAD(,wg_peer) sc_peers; + size_t sc_peer_num; struct rwlock sc_so_lock; struct socket *sc_so4; @@ -251,16 +234,6 @@ struct wg_softc { struct art_root *sc_aip6; #endif - struct rwlock sc_peer_lock; - size_t sc_peer_num; - LIST_HEAD(,wg_peer) *sc_peer; - TAILQ_HEAD(,wg_peer) sc_peer_seq; - u_long sc_peer_mask; - - struct mutex sc_index_mtx; - LIST_HEAD(,wg_index) *sc_index; - u_long sc_index_mask; - struct task sc_handshake; struct mbuf_queue sc_handshake_queue; @@ -273,8 +246,6 @@ struct wg_softc { struct wg_peer * wg_peer_create(struct wg_softc *, uint8_t[WG_KEY_SIZE], uint8_t[WG_KEY_SIZE]); -struct wg_peer * - wg_peer_lookup(struct wg_softc *, const uint8_t[WG_KEY_SIZE]); void wg_peer_destroy(struct wg_peer *); void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *); void wg_peer_set_sockaddr(struct wg_peer *, struct sockaddr *); @@ -319,7 +290,6 @@ void wg_timers_event_handshake_responded(struct wg_timers *); void wg_timers_event_handshake_complete(struct wg_timers *); void wg_timers_event_session_derived(struct wg_timers *); void wg_timers_event_want_initiation(struct wg_timers *); -void wg_timers_event_reset_handshake_last_sent(struct wg_timers *); void wg_timers_run_send_initiation(void *, int); void wg_timers_run_retry_handshake(void *); @@ -354,14 +324,6 @@ struct wg_packet * struct wg_packet * wg_queue_parallel_dequeue(struct wg_queue *); -struct noise_remote * - wg_remote_get(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]); -uint32_t - wg_index_set(void *, struct noise_remote *); -struct noise_remote * - wg_index_get(void *, uint32_t); -void wg_index_drop(void *, uint32_t); - struct mbuf * wg_input(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *, int); @@ -398,21 +360,15 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE], uint8_t psk[WG_KEY_SIZE]) { struct wg_peer *peer; - uint64_t idx; rw_assert_wrlock(&sc->sc_lock); - if (sc->sc_peer_num >= MAX_PEERS_PER_IFACE) - return NULL; - if ((peer = pool_get(&wg_peer_pool, PR_NOWAIT)) == NULL) return NULL; peer->p_id = peer_counter++; peer->p_sc = sc; - noise_remote_init(&peer->p_remote, public, &sc->sc_local); - noise_remote_set_psk(&peer->p_remote, psk); cookie_maker_init(&peer->p_cookie, public); wg_timers_init(&peer->p_timers); @@ -433,51 +389,28 @@ wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE], wg_queue_init(&peer->p_encap_serial); wg_queue_init(&peer->p_decap_serial); - SLIST_INIT(&peer->p_unused_index); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1], - i_unused_entry); - SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2], - i_unused_entry); - LIST_INIT(&peer->p_aip); peer->p_start_onlist = 0; - idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE); - idx &= sc->sc_peer_mask; + if ((peer->p_remote = noise_remote_alloc(sc->sc_local, peer, public, psk)) == NULL) { + pool_put(&wg_peer_pool, peer); + return NULL; + } - rw_enter_write(&sc->sc_peer_lock); - LIST_INSERT_HEAD(&sc->sc_peer[idx], peer, p_pubkey_entry); - TAILQ_INSERT_TAIL(&sc->sc_peer_seq, peer, p_seq_entry); + DPRINTF(sc, "Peer %llu created\n", peer->p_id); + TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry); sc->sc_peer_num++; - rw_exit_write(&sc->sc_peer_lock); - DPRINTF(sc, "Peer %llu created\n", peer->p_id); return peer; } -struct wg_peer * -wg_peer_lookup(struct wg_softc *sc, const uint8_t public[WG_KEY_SIZE]) +void +wg_peer_free(struct noise_remote *r) { - uint8_t peer_key[WG_KEY_SIZE]; - struct wg_peer *peer; - uint64_t idx; - - idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE); - idx &= sc->sc_peer_mask; - - rw_enter_read(&sc->sc_peer_lock); - LIST_FOREACH(peer, &sc->sc_peer[idx], p_pubkey_entry) { - noise_remote_keys(&peer->p_remote, peer_key, NULL); - if (timingsafe_bcmp(peer_key, public, WG_KEY_SIZE) == 0) - goto done; - } - peer = NULL; -done: - rw_exit_read(&sc->sc_peer_lock); - return peer; + struct wg_peer *peer; + peer = noise_remote_arg(r); + pool_put(&wg_peer_pool, peer); } void @@ -488,46 +421,16 @@ wg_peer_destroy(struct wg_peer *peer) rw_assert_wrlock(&sc->sc_lock); - /* - * Remove peer from the pubkey hashtable and disable all timeouts. - * After this, and flushing wg_handshake_taskq, then no more handshakes - * can be started. - */ - rw_enter_write(&sc->sc_peer_lock); - LIST_REMOVE(peer, p_pubkey_entry); - TAILQ_REMOVE(&sc->sc_peer_seq, peer, p_seq_entry); + TAILQ_REMOVE(&sc->sc_peers, peer, p_entry); sc->sc_peer_num--; - rw_exit_write(&sc->sc_peer_lock); - - wg_timers_disable(&peer->p_timers); - - taskq_barrier(wg_handshake_taskq); - /* - * Now we drop all allowed ips, to drop all outgoing packets to the - * peer. Then drop all the indexes to drop all incoming packets to the - * peer. Then we can flush if_snd, wg_crypt_taskq and then nettq to - * ensure no more references to the peer exist. - */ LIST_FOREACH_SAFE(aip, &peer->p_aip, a_entry, taip) wg_aip_remove(sc, peer, &aip->a_data); - noise_remote_clear(&peer->p_remote); - - NET_LOCK(); - while (!ifq_empty(&sc->sc_if.if_snd)) { - NET_UNLOCK(); - tsleep_nsec(sc, PWAIT, "wg_ifq", 1000); - NET_LOCK(); - } - NET_UNLOCK(); - - taskq_barrier(wg_crypt_taskq); - taskq_barrier(net_tq(sc->sc_if.if_index)); + wg_timers_disable(&peer->p_timers); + noise_remote_free(peer->p_remote, wg_peer_free); DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id); - explicit_bzero(peer, sizeof(*peer)); - pool_put(&wg_peer_pool, peer); } void @@ -917,8 +820,6 @@ wg_tag_get(struct mbuf *m) * tx: response, rx: response * wg_timers_event_want_initiation: * tx: data failed, old keys expiring - * wg_timers_event_reset_handshake_last_sent: - * anytime we may immediately want a new handshake */ void wg_timers_init(struct wg_timers *t) @@ -987,28 +888,6 @@ wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time) mtx_leave(&t->t_handshake_mtx); } -int -wg_timers_expired_handshake_last_sent(struct wg_timers *t) -{ - struct timespec uptime; - struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 }; - - getnanouptime(&uptime); - timespecadd(&t->t_handshake_last_sent, &expire, &expire); - return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; -} - -int -wg_timers_check_handshake_last_sent(struct wg_timers *t) -{ - int ret; - mtx_enter(&t->t_handshake_mtx); - if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT) - getnanouptime(&t->t_handshake_last_sent); - mtx_leave(&t->t_handshake_mtx); - return ret; -} - void wg_timers_event_data_sent(struct wg_timers *t) { @@ -1069,14 +948,6 @@ wg_timers_event_handshake_initiated(struct wg_timers *t) rw_exit_read(&t->t_lock); } -void -wg_timers_event_handshake_responded(struct wg_timers *t) -{ - mtx_enter(&t->t_handshake_mtx); - getnanouptime(&t->t_handshake_last_sent); - mtx_leave(&t->t_handshake_mtx); -} - void wg_timers_event_handshake_complete(struct wg_timers *t) { @@ -1110,14 +981,6 @@ wg_timers_event_want_initiation(struct wg_timers *t) rw_exit_read(&t->t_lock); } -void -wg_timers_event_reset_handshake_last_sent(struct wg_timers *t) -{ - mtx_enter(&t->t_handshake_mtx); - t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1); - mtx_leave(&t->t_handshake_mtx); -} - void wg_timers_run_send_initiation(void *_t, int is_retry) { @@ -1125,7 +988,7 @@ wg_timers_run_send_initiation(void *_t, int is_retry) struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers); if (!is_retry) t->t_handshake_retries = 0; - if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT) + if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT) task_add(wg_handshake_taskq, &peer->p_send_initiation); } @@ -1227,15 +1090,13 @@ wg_send_initiation(void *_peer) struct wg_peer *peer = _peer; struct wg_pkt_initiation pkt; - if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT) + if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, + pkt.ets) != 0) return; DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n", peer->p_id); - if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es, - pkt.ets) != 0) - return; pkt.t = WG_PKT_INITIATION; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); @@ -1251,16 +1112,14 @@ wg_send_response(struct wg_peer *peer) DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n", peer->p_id); - if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx, + if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx, pkt.ue, pkt.en) != 0) return; - if (noise_remote_begin_session(&peer->p_remote) != 0) - return; wg_timers_event_session_derived(&peer->p_timers); pkt.t = WG_PKT_RESPONSE; cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, sizeof(pkt)-sizeof(pkt.m)); - wg_timers_event_handshake_responded(&peer->p_timers); + wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt)); } @@ -1302,7 +1161,6 @@ wg_send_keepalive(void *_peer) } pkt->p_mbuf = m; - pkt->p_peer = peer; pkt->p_mtu = 0; m->m_pkthdr.ph_cookie = pkt; @@ -1318,7 +1176,7 @@ void wg_peer_clear_secrets(void *_peer) { struct wg_peer *peer = _peer; - noise_remote_clear(&peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); } void @@ -1330,7 +1188,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) struct wg_endpoint *e; struct wg_peer *peer; struct mbuf *m; - struct noise_remote *remote; + struct noise_keypair *keypair; + struct noise_remote *remote = NULL; int res, underload = 0; static struct timeval wg_last_underload; /* microuptime */ @@ -1371,13 +1230,13 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) panic("unexpected response: %d\n", res); } - if (noise_consume_initiation(&sc->sc_local, &remote, + if (noise_consume_initiation(sc->sc_local, &remote, init->s_idx, init->ue, init->es, init->ets) != 0) { DPRINTF(sc, "Invalid handshake initiation\n"); goto error; } - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", peer->p_id); @@ -1405,45 +1264,39 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) panic("unexpected response: %d\n", res); } - if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) { - DPRINTF(sc, "Unknown handshake response\n"); - goto error; - } - - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - - if (noise_consume_response(remote, resp->s_idx, resp->r_idx, - resp->ue, resp->en) != 0) { + if (noise_consume_response(sc->sc_local, &remote, + resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) { DPRINTF(sc, "Invalid handshake response\n"); goto error; } - DPRINTF(sc, "Receiving handshake response from peer %llu\n", - peer->p_id); + peer = noise_remote_arg(remote); + DPRINTF(sc, "Receiving handshake response from peer %llu\n", peer->p_id); wg_peer_set_endpoint(peer, e); - if (noise_remote_begin_session(&peer->p_remote) == 0) { - wg_timers_event_session_derived(&peer->p_timers); - wg_timers_event_handshake_complete(&peer->p_timers); - } + wg_timers_event_session_derived(&peer->p_timers); + wg_timers_event_handshake_complete(&peer->p_timers); break; case WG_PKT_COOKIE: cook = mtod(m, struct wg_pkt_cookie *); - if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) { - DPRINTF(sc, "Unknown cookie index\n"); - goto error; + if ((remote = noise_remote_index_lookup(sc->sc_local, cook->r_idx)) == NULL) { + if ((keypair = noise_keypair_lookup(sc->sc_local, cook->r_idx)) == NULL) { + DPRINTF(sc, "Unknown cookie index\n"); + goto error; + } + remote = noise_keypair_remote(keypair); + noise_keypair_put(keypair); } - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); + peer = noise_remote_arg(remote); if (cookie_maker_consume_payload(&peer->p_cookie, - cook->nonce, cook->ec) != 0) { + cook->nonce, cook->ec) == 0) + DPRINTF(sc, "Receiving cookie response\n"); + else DPRINTF(sc, "Could not decrypt cookie response\n"); - goto error; - } - DPRINTF(sc, "Receiving cookie response\n"); goto error; default: panic("invalid packet in handshake queue"); @@ -1453,6 +1306,8 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); wg_peer_counters_add(peer, 0, m->m_pkthdr.len); error: + if (remote != NULL) + noise_remote_put(remote); m_freem(m); pool_put(&wg_packet_pool, pkt); } @@ -1484,12 +1339,16 @@ wg_handshake_worker(void *_sc) void wg_encap(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_data data; + struct noise_remote *remote; + struct wg_pkt_data *data; struct wg_peer *peer; struct mbuf *m, *ms; - int res, pad, off, len; + uint32_t idx; + int pad, off, len; - peer = pkt->p_peer; + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + noise_remote_put(remote); m = pkt->p_mbuf; /* Calculate what padding we need to add then limit it to the mtu of @@ -1506,7 +1365,7 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) bzero(mtod(ms, uint8_t *) + off, pad); } - /* TODO teach noise_remote_encrypt about mbufs. Currently we have to + /* TODO teach noise_keypair_encrypt about mbufs. Currently we have to * resort to m_pullup to create an encryptable buffer. */ len = m->m_pkthdr.len; if (m_makespace(m, len, NOISE_AUTHTAG_LEN, &off) == NULL) @@ -1515,30 +1374,20 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) goto error; /* Do encryption */ - res = noise_remote_encrypt(&peer->p_remote, &data.r_idx, &data.nonce, - mtod(m, uint8_t *), len); - - if (__predict_false(res == EINVAL)) { - goto error_free; - } else if (__predict_false(res == ESTALE)) { - wg_timers_event_want_initiation(&peer->p_timers); - } else if (__predict_false(res != 0)) { - panic("unexpected result: %d\n", res); - } + noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, mtod(m, uint8_t *), len); /* A packet with length 0 is a keepalive packet */ if (__predict_false(len == 0)) - DPRINTF(sc, "Sending keepalive packet to peer %llu\n", - peer->p_id); + DPRINTF(sc, "Sending keepalive packet to peer %llu\n", peer->p_id); /* Put header into packet */ if ((m = m_prepend(m, sizeof(struct wg_pkt_data), M_NOWAIT)) == NULL) goto error; - data.t = WG_PKT_DATA; - data.nonce = htole64(data.nonce); - memcpy(mtod(m, void *), &data, sizeof(struct wg_pkt_data)); - + data = mtod(m, struct wg_pkt_data *); + data->t = WG_PKT_DATA; + data->r_idx = idx; + data->nonce = htole64(pkt->p_nonce); /* * We would count ifc_opackets, ifc_obytes of m here, except if_snd * already does that for us, so no need to worry about it. @@ -1564,6 +1413,7 @@ error: void wg_decap(struct wg_softc *sc, struct wg_packet *pkt) { + struct noise_remote *remote; struct wg_pkt_data data; struct wg_peer *peer, *allowed_peer; struct mbuf *m; @@ -1571,7 +1421,9 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) struct ip6_hdr *ip6; int res, len; - peer = pkt->p_peer; + remote = noise_keypair_remote(pkt->p_keypair); + peer = noise_remote_arg(remote); + noise_remote_put(remote); m = pkt->p_mbuf; len = m->m_pkthdr.len; @@ -1579,28 +1431,25 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) memcpy(&data, mtod(m, void *), sizeof(struct wg_pkt_data)); m_adj(m, sizeof(struct wg_pkt_data)); - /* TODO teach noise_remote_decrypt about mbufs. Currently we have to + /* TODO teach noise_keypair_decrypt about mbufs. Currently we have to * resort to m_pullup to create an decryptable buffer. */ if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) { goto error; } - res = noise_remote_decrypt(&peer->p_remote, data.r_idx, - le64toh(data.nonce), mtod(m, void *), m->m_pkthdr.len); + pkt->p_nonce = letoh64(data.nonce); + res = noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, mtod(m, void *), m->m_pkthdr.len); if (__predict_false(res == EINVAL)) { goto error_free; } else if (__predict_false(res == ECONNRESET)) { wg_timers_event_handshake_complete(&peer->p_timers); - } else if (__predict_false(res == ESTALE)) { - wg_timers_event_want_initiation(&peer->p_timers); } else if (__predict_false(res != 0)) { panic("unexpected response: %d\n", res); } m_adj(m, -NOISE_AUTHTAG_LEN); - wg_peer_set_endpoint(peer, &pkt->p_endpoint); wg_peer_counters_add(peer, 0, len); counters_pkt(sc->sc_if.if_counters, ifc_ipackets, ifc_ibytes, @@ -1735,8 +1584,11 @@ wg_deliver_out(void *_peer) m_freem(m); } + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); } + if (noise_keep_key_fresh_send(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } void @@ -1750,14 +1602,20 @@ wg_deliver_in(void *_peer) while ((pkt = wg_queue_serial_dequeue(&peer->p_decap_serial)) != NULL) { m = pkt->p_mbuf; if (pkt->p_state == WG_PACKET_CRYPTED) { + if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0) { + m_freem(m); + goto put; + } + wg_timers_event_any_authenticated_packet_received( &peer->p_timers); wg_timers_event_any_authenticated_packet_traversal( &peer->p_timers); + wg_peer_set_endpoint(peer, &pkt->p_endpoint); if (m->m_pkthdr.len == 0) { m_freem(m); - continue; + goto put; } #if NBPFILTER > 0 @@ -1778,13 +1636,15 @@ wg_deliver_in(void *_peer) NET_UNLOCK(); wg_timers_event_data_received(&peer->p_timers); - } else { m_freem(m); } - +put: + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); } + if (noise_keep_key_fresh_recv(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } void @@ -1808,6 +1668,7 @@ wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_pack } else { mtx_leave(&serial->q_mtx); m_freem(pkt->p_mbuf); + noise_keypair_put(pkt->p_keypair); pool_put(&wg_packet_pool, pkt); return ENOBUFS; } @@ -1837,11 +1698,12 @@ wg_queue_in(struct wg_softc *sc, struct wg_peer *peer, struct wg_packet *pkt) void wg_queue_out(struct wg_softc *sc, struct wg_peer *peer) { + struct noise_keypair *keypair; struct wg_packet *pkt; struct mbuf_list ml; struct mbuf *m; - if (noise_remote_ready(&peer->p_remote) != 0) { + if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) { wg_timers_event_want_initiation(&peer->p_timers); return; } @@ -1850,11 +1712,21 @@ wg_queue_out(struct wg_softc *sc, struct wg_peer *peer) while ((m = ml_dequeue(&ml)) != NULL) { pkt = m->m_pkthdr.ph_cookie; + pkt->p_keypair = noise_keypair_ref(keypair); + + if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) { + ml_purge(&ml); + pool_put(&wg_packet_pool, pkt); + m_freem(m); + break; + } if (wg_queue_both(&sc->sc_encap_parallel, &peer->p_encap_serial, pkt) != 0) counters_inc(sc->sc_if.if_counters, ifc_oqdrops); } + noise_keypair_put(keypair); + task_add(wg_crypt_taskq, &sc->sc_encap); } @@ -1886,91 +1758,6 @@ wg_queue_parallel_dequeue(struct wg_queue *parallel) return pkt; } -struct noise_remote * -wg_remote_get(void *_sc, uint8_t public[NOISE_PUBLIC_KEY_LEN]) -{ - struct wg_peer *peer; - struct wg_softc *sc = _sc; - if ((peer = wg_peer_lookup(sc, public)) == NULL) - return NULL; - return &peer->p_remote; -} - -uint32_t -wg_index_set(void *_sc, struct noise_remote *remote) -{ - struct wg_peer *peer; - struct wg_softc *sc = _sc; - struct wg_index *index, *iter; - uint32_t key; - - /* - * We can modify this without a lock as wg_index_set, wg_index_drop are - * guaranteed to be serialised (per remote). - */ - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - index = SLIST_FIRST(&peer->p_unused_index); - KASSERT(index != NULL); - SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry); - - index->i_value = remote; - - mtx_enter(&sc->sc_index_mtx); -assign_id: - key = index->i_key = arc4random(); - key &= sc->sc_index_mask; - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == index->i_key) - goto assign_id; - - LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry); - - mtx_leave(&sc->sc_index_mtx); - - /* Likewise, no need to lock for index here. */ - return index->i_key; -} - -struct noise_remote * -wg_index_get(void *_sc, uint32_t key0) -{ - struct wg_softc *sc = _sc; - struct wg_index *iter; - struct noise_remote *remote = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - mtx_enter(&sc->sc_index_mtx); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - remote = iter->i_value; - break; - } - mtx_leave(&sc->sc_index_mtx); - return remote; -} - -void -wg_index_drop(void *_sc, uint32_t key0) -{ - struct wg_softc *sc = _sc; - struct wg_index *iter; - struct wg_peer *peer = NULL; - uint32_t key = key0 & sc->sc_index_mask; - - mtx_enter(&sc->sc_index_mtx); - LIST_FOREACH(iter, &sc->sc_index[key], i_entry) - if (iter->i_key == key0) { - LIST_REMOVE(iter, i_entry); - break; - } - mtx_leave(&sc->sc_index_mtx); - - /* We expect an index match */ - KASSERT(iter != NULL); - peer = CONTAINER_OF(iter->i_value, struct wg_peer, p_remote); - SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); -} - struct mbuf * wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, void *_uh, int hlen) @@ -1978,7 +1765,6 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, struct wg_pkt_data *data; struct noise_remote *remote; struct wg_packet *pkt; - struct wg_peer *peer; struct wg_softc *sc = _sc; struct udphdr *uh = _uh; @@ -2046,13 +1832,13 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, } data = mtod(m, struct wg_pkt_data *); - if ((remote = wg_index_get(sc, data->r_idx)) == NULL) + if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL) goto error_mbuf; - peer = CONTAINER_OF(remote, struct wg_peer, p_remote); - pkt->p_peer = peer; + remote = noise_keypair_remote(pkt->p_keypair); pkt->p_mbuf = m; - wg_queue_in(sc, peer, pkt); + wg_queue_in(sc, noise_remote_arg(remote), pkt); + noise_remote_put(remote); } else { counters_inc(sc->sc_if.if_counters, ifc_ierrors); goto error_mbuf; @@ -2090,7 +1876,6 @@ wg_qstart(struct ifqueue *ifq) peer = t->t_peer; pkt->p_mbuf = m; - pkt->p_peer = peer; pkt->p_mtu = t->t_mtu; m->m_pkthdr.ph_cookie = pkt; @@ -2205,12 +1990,14 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) struct wg_peer *peer, *tpeer; struct wg_aip *aip, *taip; + struct noise_remote *remote; + in_port_t port; int rtable; uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE]; size_t i, j; - int ret, has_identity; + int ret; if ((ret = suser(curproc)) != 0) return ret; @@ -2222,27 +2009,23 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) goto error; if (iface_o.i_flags & WG_INTERFACE_REPLACE_PEERS) - TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer) + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) wg_peer_destroy(peer); if (iface_o.i_flags & WG_INTERFACE_HAS_PRIVATE && - (noise_local_keys(&sc->sc_local, NULL, private) || + (noise_local_keys(sc->sc_local, NULL, private) || timingsafe_bcmp(private, iface_o.i_private, WG_KEY_SIZE))) { if (curve25519_generate_public(public, iface_o.i_private)) { - if ((peer = wg_peer_lookup(sc, public)) != NULL) - wg_peer_destroy(peer); - } - noise_local_lock_identity(&sc->sc_local); - has_identity = noise_local_set_private(&sc->sc_local, - iface_o.i_private); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { - noise_remote_precompute(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent(&peer->p_timers); - noise_remote_expire_current(&peer->p_remote); + if ((remote = noise_remote_lookup(sc->sc_local, public)) != NULL) { + wg_peer_destroy(noise_remote_arg(remote)); + noise_remote_put(remote); + } } - cookie_checker_update(&sc->sc_cookie, - has_identity == 0 ? public : NULL); - noise_local_unlock_identity(&sc->sc_local); + noise_local_private(sc->sc_local, iface_o.i_private); + if (noise_local_keys(sc->sc_local, public, NULL) == 0) + cookie_checker_update(&sc->sc_cookie, public); + else + cookie_checker_update(&sc->sc_cookie, NULL); } if (iface_o.i_flags & WG_INTERFACE_HAS_PORT) @@ -2256,7 +2039,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) rtable = sc->sc_udp_rtable; if (port != sc->sc_udp_port || rtable != sc->sc_udp_rtable) { - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) wg_peer_clear_src(peer); if (sc->sc_if.if_flags & IFF_RUNNING) @@ -2287,12 +2070,12 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) } /* Get local public and check that peer key doesn't match */ - if (noise_local_keys(&sc->sc_local, public, NULL) == 0 && + if (noise_local_keys(sc->sc_local, public, NULL) == 0 && bcmp(public, peer_o.p_public, WG_KEY_SIZE) == 0) goto next_peer; /* Lookup peer, or create if it doesn't exist */ - if ((peer = wg_peer_lookup(sc, peer_o.p_public)) == NULL) { + if ((remote = noise_remote_lookup(sc->sc_local, peer_o.p_public)) == NULL) { /* If we want to delete, no need creating a new one. * Also, don't create a new one if we only want to * update. */ @@ -2307,6 +2090,9 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) ret = ENOMEM; goto error; } + } else { + peer = noise_remote_arg(remote); + noise_remote_put(remote); } /* Remove peer and continue if specified */ @@ -2319,7 +2105,7 @@ wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data) wg_peer_set_sockaddr(peer, &peer_o.p_sa); if (peer_o.p_flags & WG_PEER_HAS_PSK) - noise_remote_set_psk(&peer->p_remote, peer_o.p_psk); + noise_remote_set_psk(peer->p_remote, peer_o.p_psk); if (peer_o.p_flags & WG_PEER_HAS_PKA) wg_timers_set_persistent_keepalive(&peer->p_timers, @@ -2394,7 +2180,7 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data) if (!is_suser) goto copy_out_iface; - if (noise_local_keys(&sc->sc_local, iface_o.i_public, + if (noise_local_keys(sc->sc_local, iface_o.i_public, iface_o.i_private) == 0) { iface_o.i_flags |= WG_INTERFACE_HAS_PUBLIC; iface_o.i_flags |= WG_INTERFACE_HAS_PRIVATE; @@ -2407,12 +2193,12 @@ wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data) peer_count = 0; peer_p = &iface_p->i_peers[0]; - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { bzero(&peer_o, sizeof(peer_o)); peer_o.p_flags = WG_PEER_HAS_PUBLIC; peer_o.p_protocol_version = 1; - if (noise_remote_keys(&peer->p_remote, peer_o.p_public, + if (noise_remote_keys(peer->p_remote, peer_o.p_public, peer_o.p_psk) == 0) peer_o.p_flags |= WG_PEER_HAS_PSK; @@ -2530,7 +2316,7 @@ wg_up(struct wg_softc *sc) */ ret = wg_bind(sc, &sc->sc_udp_port, &sc->sc_udp_rtable); if (ret == 0) { - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) wg_timers_enable(&peer->p_timers); } rw_exit_write(&sc->sc_lock); @@ -2558,15 +2344,15 @@ wg_down(struct wg_softc *sc) * that isn't granularly locked. */ rw_enter_read(&sc->sc_lock); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { mq_purge(&peer->p_stage_queue); wg_timers_disable(&peer->p_timers); } taskq_barrier(wg_handshake_taskq); - TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) { - noise_remote_clear(&peer->p_remote); - wg_timers_event_reset_handshake_last_sent(&peer->p_timers); + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + noise_remote_handshake_clear(peer->p_remote); + noise_remote_keypairs_clear(peer->p_remote); } wg_unbind(sc); @@ -2579,7 +2365,6 @@ wg_clone_create(struct if_clone *ifc, int unit) { struct ifnet *ifp; struct wg_softc *sc; - struct noise_upcall local_upcall; KERNEL_ASSERT_LOCKED(); @@ -2604,20 +2389,15 @@ wg_clone_create(struct if_clone *ifc, int unit) if ((sc = malloc(sizeof(*sc), M_DEVBUF, M_NOWAIT | M_ZERO)) == NULL) goto ret_00; - local_upcall.u_arg = sc; - local_upcall.u_remote_get = wg_remote_get; - local_upcall.u_index_set = wg_index_set; - local_upcall.u_index_drop = wg_index_drop; - - /* sc_if is initialised after everything else */ - arc4random_buf(&sc->sc_secret, sizeof(sc->sc_secret)); - rw_init(&sc->sc_lock, "wg"); - noise_local_init(&sc->sc_local, &local_upcall); - if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0) + if ((sc->sc_local = noise_local_alloc(sc)) == NULL) goto ret_01; + if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0) + goto ret_02; sc->sc_udp_port = 0; sc->sc_udp_rtable = 0; + TAILQ_INIT(&sc->sc_peers); + sc->sc_peer_num = 0; rw_init(&sc->sc_so_lock, "wg_so"); sc->sc_so4 = NULL; @@ -2627,24 +2407,11 @@ wg_clone_create(struct if_clone *ifc, int unit) sc->sc_aip_num = 0; if ((sc->sc_aip4 = art_alloc(0, 32, 0)) == NULL) - goto ret_02; + goto ret_03; #ifdef INET6 if ((sc->sc_aip6 = art_alloc(0, 128, 0)) == NULL) - goto ret_03; -#endif - - rw_init(&sc->sc_peer_lock, "wg_peer"); - sc->sc_peer_num = 0; - if ((sc->sc_peer = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF, - M_NOWAIT, &sc->sc_peer_mask)) == NULL) goto ret_04; - - TAILQ_INIT(&sc->sc_peer_seq); - - mtx_init(&sc->sc_index_mtx, IPL_NET); - if ((sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, - M_NOWAIT, &sc->sc_index_mask)) == NULL) - goto ret_05; +#endif task_set(&sc->sc_handshake, wg_handshake_worker, sc); mq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES, IPL_NET); @@ -2684,22 +2451,42 @@ wg_clone_create(struct if_clone *ifc, int unit) DPRINTF(sc, "Interface created\n"); return 0; -ret_05: - hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF); -ret_04: + #ifdef INET6 free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); -ret_03: +ret_04: #endif free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); -ret_02: +ret_03: cookie_checker_deinit(&sc->sc_cookie); +ret_02: + noise_local_put(sc->sc_local); ret_01: free(sc, M_DEVBUF, sizeof(*sc)); ret_00: return ENOBUFS; } +void +wg_clone_free(struct noise_local *l) +{ + struct wg_softc *sc = noise_local_arg(l); + wg_counter--; + if (wg_counter == 0) { + KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL); + taskq_destroy(wg_handshake_taskq); + taskq_destroy(wg_crypt_taskq); + wg_handshake_taskq = NULL; + wg_crypt_taskq = NULL; + } +#ifdef INET6 + free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); +#endif + free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); + cookie_checker_deinit(&sc->sc_cookie); + free(sc, M_DEVBUF, sizeof(*sc)); +} + int wg_clone_destroy(struct ifnet *ifp) { @@ -2709,33 +2496,16 @@ wg_clone_destroy(struct ifnet *ifp) KERNEL_ASSERT_LOCKED(); rw_enter_write(&sc->sc_lock); - TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer) + TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer) wg_peer_destroy(peer); + rw_exit_write(&sc->sc_lock); wg_unbind(sc); if_detach(ifp); - wg_counter--; - if (wg_counter == 0) { - KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL); - taskq_destroy(wg_handshake_taskq); - taskq_destroy(wg_crypt_taskq); - wg_handshake_taskq = NULL; - wg_crypt_taskq = NULL; - } - + noise_local_free(sc->sc_local, wg_clone_free); DPRINTF(sc, "Destroyed interface\n"); - - hashfree(sc->sc_index, HASHTABLE_INDEX_SIZE, M_DEVBUF); - hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF); -#ifdef INET6 - free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6)); -#endif - free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4)); - cookie_checker_deinit(&sc->sc_cookie); - noise_local_deinit(&sc->sc_local); - free(sc, M_DEVBUF, sizeof(*sc)); return 0; } diff --git a/sys/net/wg_noise.c b/sys/net/wg_noise.c index b87a46b18a1..ef62056cc01 100644 --- a/sys/net/wg_noise.c +++ b/sys/net/wg_noise.c @@ -1,7 +1,7 @@ /* $OpenBSD: wg_noise.c,v 1.5 2021/03/21 18:13:59 sthen Exp $ */ /* * Copyright (C) 2015-2020 Jason A. Donenfeld . All Rights Reserved. - * Copyright (C) 2019-2020 Matt Dunwoodie + * Copyright (C) 2019-2021 Matt Dunwoodie * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above @@ -21,152 +21,467 @@ #include #include #include +#include +#include #include #include #include +#include #include -/* Private functions */ -static struct noise_keypair * - noise_remote_keypair_allocate(struct noise_remote *); -static void - noise_remote_keypair_free(struct noise_remote *, - struct noise_keypair *); -static uint32_t noise_remote_handshake_index_get(struct noise_remote *); -static void noise_remote_handshake_index_drop(struct noise_remote *); - -static uint64_t noise_counter_send(struct noise_counter *); -static int noise_counter_recv(struct noise_counter *, uint64_t); +/* Protocol string constants */ +#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" +#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com" + +/* Constants for the counter */ +#define COUNTER_BITS_TOTAL 8192 +#define COUNTER_BITS (sizeof(unsigned long) * 8) +#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS) +#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS) + +/* Constants for the keypair */ +#define REKEY_AFTER_MESSAGES (1ull << 60) +#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1) +#define REKEY_AFTER_TIME 120 +#define REKEY_AFTER_TIME_RECV 165 +#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */ +/* 24 = floor(log2(REJECT_INTERVAL)) */ +#define REJECT_INTERVAL_MASK (~((1ull<<24)-1)) +#define TIMER_RESET (struct timespec){ -(REKEY_TIMEOUT+1), 0 } + +#define HT_INDEX_SIZE (1 << 13) +#define HT_INDEX_MASK (HT_INDEX_SIZE - 1) +#define HT_REMOTE_SIZE (1 << 11) +#define HT_REMOTE_MASK (HT_REMOTE_SIZE - 1) +#define MAX_REMOTE_PER_LOCAL (1 << 20) + +struct noise_index { + SMR_LIST_ENTRY(noise_index) i_entry; + uint32_t i_local_index; + uint32_t i_remote_index; + int i_is_keypair; +}; + +struct noise_keypair { + struct noise_index kp_index; + struct refcnt kp_refcnt; + int kp_can_send; + int kp_is_initiator; + struct timespec kp_birthdate; /* nanouptime */ + struct noise_remote *kp_remote; + + uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN]; + + /* Counter elements */ + struct rwlock kp_nonce_lock; + uint64_t kp_nonce_send; + uint64_t kp_nonce_recv; + unsigned long kp_backtrack[COUNTER_NUM]; + + struct smr_entry kp_smr; +}; + +struct noise_handshake { + uint8_t hs_e[NOISE_PUBLIC_KEY_LEN]; + uint8_t hs_hash[NOISE_HASH_LEN]; + uint8_t hs_ck[NOISE_HASH_LEN]; +}; + +struct noise_remote { + struct noise_index r_index; + + SMR_LIST_ENTRY(noise_remote) r_entry; + uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; + + struct rwlock r_handshake_lock; + struct noise_handshake r_handshake; + int r_handshake_alive; + int r_handshake_initiator; + struct timespec r_last_sent; /* nanouptime */ + struct timespec r_last_init_recv; /* nanouptime */ + uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; + uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; + + struct refcnt r_refcnt; + struct noise_local *r_local; + void *r_arg; + + struct rwlock r_keypair_lock; + struct noise_keypair *r_next, *r_current, *r_previous; + + struct smr_entry r_smr; + void (*r_cleanup)(struct noise_remote *); +}; + +struct noise_local { + struct rwlock l_identity_lock; + int l_has_identity; + uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; + + struct refcnt l_refcnt; + SIPHASH_KEY l_hash_key; + void *l_arg; + void (*l_cleanup)(struct noise_local *); + + struct rwlock l_remote_lock; + size_t l_remote_num; + SMR_LIST_HEAD(,noise_remote) l_remote_hash[HT_REMOTE_SIZE]; + + struct rwlock l_index_lock; + SMR_LIST_HEAD(,noise_index) l_index_hash[HT_INDEX_SIZE]; +}; + +static void noise_precompute_ss(struct noise_local *, struct noise_remote *); + +static void noise_remote_index_insert(struct noise_local *, struct noise_remote *); +static int noise_remote_index_remove(struct noise_local *, struct noise_remote *); +static void noise_remote_expire_current(struct noise_remote *); + + +static void noise_add_new_keypair(struct noise_local *, struct noise_remote *, struct noise_keypair *); +static int noise_received_with(struct noise_keypair *); +static int noise_begin_session(struct noise_remote *); +static void noise_keypair_drop(struct noise_keypair *); static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *, - size_t, size_t, size_t, size_t, - const uint8_t [NOISE_HASH_LEN]); -static int noise_mix_dh( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN]); -static int noise_mix_ss( - uint8_t ck[NOISE_HASH_LEN], - uint8_t key[NOISE_SYMMETRIC_KEY_LEN], - const uint8_t ss[NOISE_PUBLIC_KEY_LEN]); -static void noise_mix_hash( - uint8_t [NOISE_HASH_LEN], - const uint8_t *, - size_t); -static void noise_mix_psk( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); -static void noise_param_init( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN]); - + size_t, size_t, size_t, size_t, + const uint8_t [NOISE_HASH_LEN]); +static int noise_mix_dh(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static int noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static void noise_mix_hash(uint8_t [NOISE_HASH_LEN], const uint8_t *, size_t); +static void noise_mix_psk(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_SYMMETRIC_KEY_LEN], const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); +static void noise_param_init(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t, - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - uint8_t [NOISE_HASH_LEN]); + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t, - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - uint8_t [NOISE_HASH_LEN]); -static void noise_msg_ephemeral( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - const uint8_t src[NOISE_PUBLIC_KEY_LEN]); - + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); +static void noise_msg_ephemeral(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]); static int noise_timer_expired(struct timespec *, time_t, long); -/* Set/Get noise parameters */ -void -noise_local_init(struct noise_local *l, struct noise_upcall *upcall) +/* Make rwlock spin locks */ +#define rw_enter_write_spin(l) while (rw_enter(l, RW_WRITE | RW_NOSLEEP) != 0) +#define rw_enter_read_spin(l) while (rw_enter(l, RW_READ | RW_NOSLEEP) != 0) + +/* Local configuration */ +struct noise_local * +noise_local_alloc(void *arg) { - bzero(l, sizeof(*l)); - rw_init(&l->l_identity_lock, "noise_local_identity"); - l->l_upcall = *upcall; + struct noise_local *l; + size_t i; + + if ((l = malloc(sizeof(*l), M_DEVBUF, M_NOWAIT)) == NULL) + return NULL; + + rw_init(&l->l_identity_lock, "noise_identity"); + l->l_has_identity = 0; + bzero(l->l_public, NOISE_PUBLIC_KEY_LEN); + bzero(l->l_private, NOISE_PUBLIC_KEY_LEN); + + refcnt_init(&l->l_refcnt); + arc4random_buf(&l->l_hash_key, sizeof(l->l_hash_key)); + l->l_arg = arg; + l->l_cleanup = NULL; + + rw_init(&l->l_remote_lock, "noise_remote"); + l->l_remote_num = 0; + for (i = 0; i < HT_REMOTE_SIZE; i++) + SMR_LIST_INIT(&l->l_remote_hash[i]); + + rw_init(&l->l_index_lock, "noise_index"); + for (i = 0; i < HT_INDEX_SIZE; i++) + SMR_LIST_INIT(&l->l_index_hash[i]); + + return l; } -void -noise_local_deinit(struct noise_local *l) +struct noise_local * +noise_local_ref(struct noise_local *l) { - l->l_has_identity = 0; - explicit_bzero(&l->l_public, sizeof(l->l_public)); - explicit_bzero(&l->l_private, sizeof(l->l_private)); + refcnt_take(&l->l_refcnt); + return l; } void -noise_local_lock_identity(struct noise_local *l) +noise_local_put(struct noise_local *l) { - rw_enter_write(&l->l_identity_lock); + if (refcnt_rele(&l->l_refcnt)) { + if (l->l_cleanup != NULL) + l->l_cleanup(l); + explicit_bzero(l, sizeof(*l)); + free(l, M_DEVBUF, sizeof(*l)); + } } void -noise_local_unlock_identity(struct noise_local *l) +noise_local_free(struct noise_local *l, void (*cleanup)(struct noise_local *)) { - rw_exit_write(&l->l_identity_lock); + l->l_cleanup = cleanup; + noise_local_put(l); } -int -noise_local_set_private(struct noise_local *l, - uint8_t private[NOISE_PUBLIC_KEY_LEN]) +void * +noise_local_arg(struct noise_local *l) +{ + return l->l_arg; +} + +void +noise_local_private(struct noise_local *l, const uint8_t private[NOISE_PUBLIC_KEY_LEN]) { - rw_assert_wrlock(&l->l_identity_lock); + struct noise_remote *r; + size_t i; + rw_enter_write_spin(&l->l_identity_lock); memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN); curve25519_clamp_secret(l->l_private); - l->l_has_identity = curve25519_generate_public(l->l_public, private); + l->l_has_identity = curve25519_generate_public(l->l_public, l->l_private); - return l->l_has_identity ? 0 : ENXIO; + smr_read_enter(); + for (i = 0; i < HT_REMOTE_SIZE; i++) { + SMR_LIST_FOREACH(r, &l->l_remote_hash[i], r_entry) { + noise_precompute_ss(l, r); + noise_remote_expire_current(r); + } + } + smr_read_leave(); + rw_exit_write(&l->l_identity_lock); } int noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN], uint8_t private[NOISE_PUBLIC_KEY_LEN]) { - int ret = 0; - rw_enter_read(&l->l_identity_lock); - if (l->l_has_identity) { + int has_identity; + rw_enter_read_spin(&l->l_identity_lock); + if ((has_identity = l->l_has_identity)) { if (public != NULL) memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN); if (private != NULL) memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN); - } else { - ret = ENXIO; } rw_exit_read(&l->l_identity_lock); - return ret; + return has_identity ? 0 : ENXIO; } -void -noise_remote_init(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN], - struct noise_local *l) +static void +noise_precompute_ss(struct noise_local *l, struct noise_remote *r) +{ + rw_enter_write_spin(&r->r_handshake_lock); + if (!l->l_has_identity || + !curve25519(r->r_ss, l->l_private, r->r_public)) + bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + rw_exit_write(&r->r_handshake_lock); +} + +/* Remote configuration */ +struct noise_remote * +noise_remote_alloc(struct noise_local *l, void *arg, + const uint8_t public[NOISE_PUBLIC_KEY_LEN], + const uint8_t psk[NOISE_PUBLIC_KEY_LEN]) { - bzero(r, sizeof(*r)); + struct noise_remote *r, *ri; + uint64_t idx; + + if ((r = malloc(sizeof(*r), M_DEVBUF, M_NOWAIT)) == NULL) + return NULL; + + r->r_index.i_is_keypair = 0; + memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN); + rw_init(&r->r_handshake_lock, "noise_handshake"); + bzero(&r->r_handshake, sizeof(r->r_handshake)); + r->r_handshake_alive = 0; + r->r_handshake_initiator = 0; + r->r_last_sent = TIMER_RESET; + r->r_last_init_recv = TIMER_RESET; + bzero(r->r_timestamp, NOISE_TIMESTAMP_LEN); + noise_remote_set_psk(r, psk); + noise_precompute_ss(l, r); + + refcnt_init(&r->r_refcnt); + r->r_local = noise_local_ref(l); + r->r_arg = arg; + rw_init(&r->r_keypair_lock, "noise_keypair"); + r->r_next = r->r_current = r->r_previous = NULL; - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[0], kp_entry); - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[1], kp_entry); - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[2], kp_entry); + smr_init(&r->r_smr); - KASSERT(l != NULL); - r->r_local = l; + /* Insert to hashtable */ + idx = SipHash24(&l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; - rw_enter_write(&l->l_identity_lock); - noise_remote_precompute(r); - rw_exit_write(&l->l_identity_lock); + rw_enter_write_spin(&l->l_remote_lock); + SMR_LIST_FOREACH_LOCKED(ri, &l->l_remote_hash[idx], r_entry) + if (timingsafe_bcmp(ri->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) + goto free; + if (l->l_remote_num < MAX_REMOTE_PER_LOCAL) { + l->l_remote_num++; + SMR_LIST_INSERT_HEAD_LOCKED(&l->l_remote_hash[idx], r, r_entry); + } else { +free: + free(r, M_DEVBUF, sizeof(*r)); + noise_local_put(l); + r = NULL; + } + rw_exit_write(&l->l_remote_lock); + + return r; +} + +struct noise_remote * +noise_remote_lookup(struct noise_local *l, const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + struct noise_remote *r, *ret = NULL; + uint64_t idx; + + idx = SipHash24(&l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; + + smr_read_enter(); + SMR_LIST_FOREACH(r, &l->l_remote_hash[idx], r_entry) { + if (timingsafe_bcmp(r->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) { + if (refcnt_take_if_gt(&r->r_refcnt, 0)) + ret = r; + break; + } + } + smr_read_leave(); + return ret; +} + +static void +noise_remote_index_insert(struct noise_local *l, struct noise_remote *r) +{ + struct noise_index *i, *r_i = &r->r_index; + uint32_t idx; + + noise_remote_index_remove(l, r); + + rw_enter_write_spin(&l->l_index_lock); +assign_id: + r_i->i_local_index = arc4random(); + idx = r_i->i_local_index & HT_INDEX_MASK; + SMR_LIST_FOREACH_LOCKED(i, &l->l_index_hash[idx], i_entry) + if (i->i_local_index == r_i->i_local_index) + goto assign_id; + + SMR_LIST_INSERT_HEAD_LOCKED(&l->l_index_hash[idx], r_i, i_entry); + rw_exit_write(&l->l_index_lock); + + r->r_handshake_alive = 1; +} + +struct noise_remote * +noise_remote_index_lookup(struct noise_local *l, uint32_t idx0) +{ + struct noise_index *i; + struct noise_remote *r, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + smr_read_enter(); + SMR_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0 && !i->i_is_keypair) { + r = (struct noise_remote *) i; + if (refcnt_take_if_gt(&r->r_refcnt, 0)) + ret = r; + break; + } + } + smr_read_leave(); + return ret; +} + +static int +noise_remote_index_remove(struct noise_local *l, struct noise_remote *r) +{ + rw_assert_wrlock(&r->r_handshake_lock); + if (r->r_handshake_alive) { + rw_enter_write_spin(&l->l_index_lock); + SMR_LIST_REMOVE_LOCKED(&r->r_index, i_entry); + rw_exit_write(&l->l_index_lock); + r->r_handshake_alive = 0; + return 1; + } + return 0; +} + +struct noise_remote * +noise_remote_ref(struct noise_remote *r) +{ + refcnt_take(&r->r_refcnt); + return r; +} + +static void +noise_remote_smr_free(void *_r) +{ + struct noise_remote *r = _r; + if (r->r_cleanup != NULL) + r->r_cleanup(r); + noise_local_put(r->r_local); + explicit_bzero(r, sizeof(*r)); + free(r, M_DEVBUF, sizeof(*r)); +} + +void +noise_remote_put(struct noise_remote *r) +{ + if (refcnt_rele(&r->r_refcnt)) + smr_call(&r->r_smr, noise_remote_smr_free, r); +} + +void +noise_remote_free(struct noise_remote *r, void (*cleanup)(struct noise_remote *)) +{ + struct noise_local *l = r->r_local; + + r->r_cleanup = cleanup; + + /* remove from hashtable */ + rw_enter_write_spin(&l->l_remote_lock); + SMR_LIST_REMOVE_LOCKED(r, r_entry); + l->l_remote_num--; + rw_exit_write(&l->l_remote_lock); + + /* now clear all keypairs and handshakes, then put this reference */ + noise_remote_handshake_clear(r); + noise_remote_keypairs_clear(r); + noise_remote_put(r); +} + +struct noise_local * +noise_remote_local(struct noise_remote *r) +{ + return noise_local_ref(r->r_local); +} + +void * +noise_remote_arg(struct noise_remote *r) +{ + return r->r_arg; } void noise_remote_set_psk(struct noise_remote *r, - uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) + const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) { - rw_enter_write(&r->r_handshake_lock); - memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); + rw_enter_write_spin(&r->r_handshake_lock); + if (psk == NULL) + bzero(r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + else + memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); rw_exit_write(&r->r_handshake_lock); } @@ -180,35 +495,391 @@ noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN], if (public != NULL) memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN); - rw_enter_read(&r->r_handshake_lock); + rw_enter_read_spin(&r->r_handshake_lock); if (psk != NULL) memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN); rw_exit_read(&r->r_handshake_lock); - /* If r_psk != null_psk return 0, else ENOENT (no psk) */ - return ret ? 0 : ENOENT; + return ret ? 0 : ENOENT; +} + +int +noise_remote_initiation_expired(struct noise_remote *r) +{ + int expired; + rw_enter_read_spin(&r->r_handshake_lock); + expired = noise_timer_expired(&r->r_last_sent, REKEY_TIMEOUT, 0); + rw_exit_read(&r->r_handshake_lock); + return expired; +} + +void +noise_remote_handshake_clear(struct noise_remote *r) +{ + rw_enter_write_spin(&r->r_handshake_lock); + if (noise_remote_index_remove(r->r_local, r)) + bzero(&r->r_handshake, sizeof(r->r_handshake)); + r->r_last_sent = TIMER_RESET; + rw_exit_write(&r->r_handshake_lock); +} + +void +noise_remote_keypairs_clear(struct noise_remote *r) +{ + struct noise_keypair *kp; + + rw_enter_write_spin(&r->r_keypair_lock); + kp = SMR_PTR_GET_LOCKED(&r->r_next); + SMR_PTR_SET_LOCKED(&r->r_next, NULL); + noise_keypair_drop(kp); + + kp = SMR_PTR_GET_LOCKED(&r->r_current); + SMR_PTR_SET_LOCKED(&r->r_current, NULL); + noise_keypair_drop(kp); + + kp = SMR_PTR_GET_LOCKED(&r->r_previous); + SMR_PTR_SET_LOCKED(&r->r_previous, NULL); + noise_keypair_drop(kp); + rw_exit_write(&r->r_keypair_lock); +} + +static void +noise_remote_expire_current(struct noise_remote *r) +{ + struct noise_keypair *kp; + + noise_remote_handshake_clear(r); + + smr_read_enter(); + kp = SMR_PTR_GET(&r->r_next); + if (kp != NULL) WRITE_ONCE(kp->kp_can_send, 0); + kp = SMR_PTR_GET(&r->r_current); + if (kp != NULL) WRITE_ONCE(kp->kp_can_send, 0); + smr_read_leave(); +} + +/* Keypair functions */ +static void +noise_add_new_keypair(struct noise_local *l, struct noise_remote *r, + struct noise_keypair *kp) +{ + struct noise_keypair *next, *current, *previous; + struct noise_index *r_i = &r->r_index; + + /* Insert into the keypair table */ + rw_enter_write_spin(&r->r_keypair_lock); + next = SMR_PTR_GET_LOCKED(&r->r_next); + current = SMR_PTR_GET_LOCKED(&r->r_current); + previous = SMR_PTR_GET_LOCKED(&r->r_previous); + + if (kp->kp_is_initiator) { + if (next != NULL) { + SMR_PTR_SET_LOCKED(&r->r_next, NULL); + SMR_PTR_SET_LOCKED(&r->r_previous, next); + noise_keypair_drop(current); + } else { + SMR_PTR_SET_LOCKED(&r->r_previous, current); + } + noise_keypair_drop(previous); + SMR_PTR_SET_LOCKED(&r->r_current, kp); + } else { + SMR_PTR_SET_LOCKED(&r->r_next, kp); + noise_keypair_drop(next); + SMR_PTR_SET_LOCKED(&r->r_previous, NULL); + noise_keypair_drop(previous); + + } + rw_exit_write(&r->r_keypair_lock); + + /* Insert into index table */ + rw_assert_wrlock(&r->r_handshake_lock); + + kp->kp_index.i_is_keypair = 1; + kp->kp_index.i_local_index = r_i->i_local_index; + kp->kp_index.i_remote_index = r_i->i_remote_index; + + rw_enter_write_spin(&l->l_index_lock); + SMR_LIST_INSERT_BEFORE_LOCKED(r_i, &kp->kp_index, i_entry); + SMR_LIST_REMOVE_LOCKED(r_i, i_entry); + rw_exit_write(&l->l_index_lock); + + explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); +} + +static int +noise_received_with(struct noise_keypair *kp) +{ + struct noise_keypair *old; + struct noise_remote *r = kp->kp_remote; + + smr_read_enter(); + if (kp != SMR_PTR_GET(&r->r_next)) { + smr_read_leave(); + return 0; + } + smr_read_leave(); + + rw_enter_write_spin(&r->r_keypair_lock); + if (kp != SMR_PTR_GET_LOCKED(&r->r_next)) { + rw_exit_write(&r->r_keypair_lock); + return 0; + } + + old = SMR_PTR_GET_LOCKED(&r->r_previous); + SMR_PTR_SET_LOCKED(&r->r_previous, SMR_PTR_GET_LOCKED(&r->r_current)); + noise_keypair_drop(old); + SMR_PTR_SET_LOCKED(&r->r_current, kp); + SMR_PTR_SET_LOCKED(&r->r_next, NULL); + rw_exit_write(&r->r_keypair_lock); + + return ECONNRESET; +} + +static int +noise_begin_session(struct noise_remote *r) +{ + struct noise_keypair *kp; + + rw_assert_wrlock(&r->r_handshake_lock); + + if ((kp = malloc(sizeof(*kp), M_DEVBUF, M_NOWAIT)) == NULL) + return ENOSPC; + + refcnt_init(&kp->kp_refcnt); + kp->kp_can_send = 1; + kp->kp_is_initiator = r->r_handshake_initiator; + getnanouptime(&kp->kp_birthdate); + kp->kp_remote = noise_remote_ref(r); + + if (kp->kp_is_initiator) + noise_kdf(kp->kp_send, kp->kp_recv, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + else + noise_kdf(kp->kp_recv, kp->kp_send, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + + rw_init(&kp->kp_nonce_lock, "noise_nonce"); + kp->kp_nonce_send = 0; + kp->kp_nonce_recv = 0; + bzero(kp->kp_backtrack, sizeof(kp->kp_backtrack)); + smr_init(&kp->kp_smr); + + noise_add_new_keypair(r->r_local, r, kp); + return 0; +} + +struct noise_keypair * +noise_keypair_lookup(struct noise_local *l, uint32_t idx0) +{ + struct noise_index *i; + struct noise_keypair *kp, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + smr_read_enter(); + SMR_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0 && i->i_is_keypair) { + kp = (struct noise_keypair *) i; + if (refcnt_take_if_gt(&kp->kp_refcnt, 0)) + ret = kp; + break; + } + } + smr_read_leave(); + return ret; +} + +struct noise_keypair * +noise_keypair_current(struct noise_remote *r) +{ + struct noise_keypair *kp, *ret = NULL; + + smr_read_enter(); + kp = SMR_PTR_GET(&r->r_current); + if (kp != NULL && READ_ONCE(kp->kp_can_send)) { + if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + WRITE_ONCE(kp->kp_can_send, 0); + else if (refcnt_take_if_gt(&kp->kp_refcnt, 0)) + ret = kp; + } + smr_read_leave(); + return ret; +} + +struct noise_keypair * +noise_keypair_ref(struct noise_keypair *kp) +{ + refcnt_take(&kp->kp_refcnt); + return kp; +} + +static void +noise_keypair_smr_free(void *_kp) +{ + struct noise_keypair *kp = _kp; + noise_remote_put(kp->kp_remote); + explicit_bzero(kp, sizeof(*kp)); + free(kp, M_DEVBUF, sizeof(*kp)); +} + + +void +noise_keypair_put(struct noise_keypair *kp) +{ + if (refcnt_rele(&kp->kp_refcnt)) + smr_call(&kp->kp_smr, noise_keypair_smr_free, kp); +} + +static void +noise_keypair_drop(struct noise_keypair *kp) +{ + struct noise_remote *r; + struct noise_local *l; + + if (kp == NULL) + return; + + r = kp->kp_remote; + l = r->r_local; + + rw_enter_write_spin(&l->l_index_lock); + SMR_LIST_REMOVE_LOCKED(&kp->kp_index, i_entry); + rw_exit_write(&l->l_index_lock); + + noise_keypair_put(kp); +} + +struct noise_remote * +noise_keypair_remote(struct noise_keypair *kp) +{ + return noise_remote_ref(kp->kp_remote); +} + +int +noise_keypair_nonce_next(struct noise_keypair *kp, uint64_t *send) +{ +#ifdef __LP64__ + *send = atomic_inc_long_nv((u_long *)&kp->kp_nonce_send) - 1; +#else + rw_enter_write_spin(&kp->kp_nonce_lock); + *send = ctr->c_send++; + rw_exit_write(&kp->kp_nonce_lock); +#endif + if (*send < REJECT_AFTER_MESSAGES) + return 0; + WRITE_ONCE(kp->kp_can_send, 0); + return EINVAL; +} + +int +noise_keypair_nonce_check(struct noise_keypair *kp, uint64_t recv) +{ + uint64_t i, top, index_recv, index_ctr; + unsigned long bit; + int ret = EEXIST; + + rw_enter_write_spin(&kp->kp_nonce_lock); + + /* Check that the recv counter is valid */ + if (kp->kp_nonce_recv >= REJECT_AFTER_MESSAGES || + recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* If the packet is out of the window, invalid */ + if (recv + COUNTER_WINDOW_SIZE < kp->kp_nonce_recv) + goto error; + + /* If the new counter is ahead of the current counter, we'll need to + * zero out the bitmap that has previously been used */ + index_recv = recv / COUNTER_BITS; + index_ctr = kp->kp_nonce_recv / COUNTER_BITS; + + if (recv > kp->kp_nonce_recv) { + top = MIN(index_recv - index_ctr, COUNTER_NUM); + for (i = 1; i <= top; i++) + kp->kp_backtrack[ + (i + index_ctr) & (COUNTER_NUM - 1)] = 0; + WRITE_ONCE(kp->kp_nonce_recv, recv); + } + + index_recv %= COUNTER_NUM; + bit = 1ul << (recv % COUNTER_BITS); + + if (kp->kp_backtrack[index_recv] & bit) + goto error; + + kp->kp_backtrack[index_recv] |= bit; + + ret = 0; +error: + rw_exit_write(&kp->kp_nonce_lock); + return ret; +} + +int +noise_keep_key_fresh_send(struct noise_remote *r) +{ + struct noise_keypair *current; + int keep_key_fresh; + + smr_read_enter(); + current = SMR_PTR_GET(&r->r_current); + keep_key_fresh = current != NULL && READ_ONCE(current->kp_can_send) && ( + READ_ONCE(current->kp_nonce_send) > REKEY_AFTER_MESSAGES || + (current->kp_is_initiator && noise_timer_expired(¤t->kp_birthdate, REKEY_AFTER_TIME, 0))); + smr_read_leave(); + + return keep_key_fresh ? ESTALE : 0; +} + +int +noise_keep_key_fresh_recv(struct noise_remote *r) +{ + struct noise_keypair *current; + int keep_key_fresh; + + smr_read_enter(); + current = SMR_PTR_GET(&r->r_current); + keep_key_fresh = current != NULL && READ_ONCE(current->kp_can_send) && + current->kp_is_initiator && noise_timer_expired(¤t->kp_birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT, 0); + smr_read_leave(); + + return keep_key_fresh ? ESTALE : 0; } void -noise_remote_precompute(struct noise_remote *r) +noise_keypair_encrypt(struct noise_keypair *kp, uint32_t *r_idx, uint64_t nonce, + uint8_t *buf, size_t buflen) { - struct noise_local *l = r->r_local; - rw_assert_wrlock(&l->l_identity_lock); - if (!l->l_has_identity) - bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); - else if (!curve25519(r->r_ss, l->l_private, r->r_public)) - bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + chacha20poly1305_encrypt(buf, buf, buflen, NULL, 0, nonce, kp->kp_send); + *r_idx = kp->kp_index.i_remote_index; +} - rw_enter_write(&r->r_handshake_lock); - noise_remote_handshake_index_drop(r); - explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); +int +noise_keypair_decrypt(struct noise_keypair *kp, uint64_t nonce, uint8_t *buf, + size_t buflen) +{ + if (READ_ONCE(kp->kp_nonce_recv) >= REJECT_AFTER_MESSAGES || + noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + return EINVAL; + + if (chacha20poly1305_decrypt(buf, buf, buflen, NULL, 0, nonce, kp->kp_recv) == 0) + return EINVAL; + + if (noise_received_with(kp) != 0) + return ECONNRESET; + + return 0; } + /* Handshake functions */ int -noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, +noise_create_initiation(struct noise_remote *r, + uint32_t *s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) @@ -218,10 +889,12 @@ noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); - rw_enter_write(&r->r_handshake_lock); + rw_enter_read_spin(&l->l_identity_lock); + rw_enter_write_spin(&r->r_handshake_lock); if (!l->l_has_identity) goto error; + if (!noise_timer_expired(&r->r_last_sent, REKEY_TIMEOUT, 0)) + goto error; noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public); /* e */ @@ -247,10 +920,10 @@ noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, noise_msg_encrypt(ets, ets, NOISE_TIMESTAMP_LEN, key, hs->hs_hash); - noise_remote_handshake_index_drop(r); - hs->hs_state = CREATED_INITIATION; - hs->hs_local_index = noise_remote_handshake_index_get(r); - *s_idx = hs->hs_local_index; + noise_remote_index_insert(l, r); + getnanouptime(&r->r_last_sent); + *s_idx = r->r_index.i_local_index; + r->r_handshake_initiator = 1; ret = 0; error: rw_exit_write(&r->r_handshake_lock); @@ -261,7 +934,8 @@ error: int noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, - uint32_t s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint32_t s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) { @@ -272,7 +946,7 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, uint8_t timestamp[NOISE_TIMESTAMP_LEN]; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); + rw_enter_read_spin(&l->l_identity_lock); if (!l->l_has_identity) goto error; noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public); @@ -290,23 +964,23 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, goto error; /* Lookup the remote we received from */ - if ((r = l->l_upcall.u_remote_get(l->l_upcall.u_arg, r_public)) == NULL) + if ((r = noise_remote_lookup(l, r_public)) == NULL) goto error; /* ss */ if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0) - goto error; + goto error_put; /* {t} */ if (noise_msg_decrypt(timestamp, ets, NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) - goto error; + goto error_put; memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN); /* We have successfully computed the same results, now we ensure that * this is not an initiation replay, or a flood attack */ - rw_enter_write(&r->r_handshake_lock); + rw_enter_write_spin(&r->r_handshake_lock); /* Replay */ if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0) @@ -314,21 +988,22 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, else goto error_set; /* Flood attack */ - if (noise_timer_expired(&r->r_last_init, 0, REJECT_INTERVAL)) - getnanouptime(&r->r_last_init); + if (noise_timer_expired(&r->r_last_init_recv, 0, REJECT_INTERVAL)) + getnanouptime(&r->r_last_init_recv); else goto error_set; /* Ok, we're happy to accept this initiation now */ - noise_remote_handshake_index_drop(r); - hs.hs_state = CONSUMED_INITIATION; - hs.hs_local_index = noise_remote_handshake_index_get(r); - hs.hs_remote_index = s_idx; + noise_remote_index_insert(l, r); + r->r_index.i_remote_index = s_idx; + r->r_handshake_initiator = 0; r->r_handshake = hs; - *rp = r; + *rp = noise_remote_ref(r); ret = 0; error_set: rw_exit_write(&r->r_handshake_lock); +error_put: + noise_remote_put(r); error: rw_exit_read(&l->l_identity_lock); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); @@ -337,18 +1012,21 @@ error: } int -noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx, - uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]) +noise_create_response(struct noise_remote *r, + uint32_t *s_idx, uint32_t *r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) { struct noise_handshake *hs = &r->r_handshake; + struct noise_local *l = r->r_local; uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; uint8_t e[NOISE_PUBLIC_KEY_LEN]; int ret = EINVAL; - rw_enter_read(&r->r_local->l_identity_lock); - rw_enter_write(&r->r_handshake_lock); + rw_enter_read_spin(&l->l_identity_lock); + rw_enter_write_spin(&r->r_handshake_lock); - if (hs->hs_state != CONSUMED_INITIATION) + if (!r->r_handshake_alive || r->r_handshake_initiator) goto error; /* e */ @@ -371,51 +1049,57 @@ noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx, /* {} */ noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash); - hs->hs_state = CREATED_RESPONSE; - *r_idx = hs->hs_remote_index; - *s_idx = hs->hs_local_index; - ret = 0; + if ((ret = noise_begin_session(r)) == 0) { + getnanouptime(&r->r_last_sent); + *s_idx = r->r_index.i_local_index; + *r_idx = r->r_index.i_remote_index; + } error: rw_exit_write(&r->r_handshake_lock); - rw_exit_read(&r->r_local->l_identity_lock); + rw_exit_read(&l->l_identity_lock); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); explicit_bzero(e, NOISE_PUBLIC_KEY_LEN); return ret; } int -noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx, - uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]) +noise_consume_response(struct noise_local *l, struct noise_remote **rp, + uint32_t s_idx, uint32_t r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) { - struct noise_local *l = r->r_local; - struct noise_handshake hs; + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN]; uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; - uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN]; + struct noise_handshake hs; + struct noise_remote *r = NULL; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); + if ((r = noise_remote_index_lookup(l, r_idx)) == NULL) + return ret; + + rw_enter_read_spin(&l->l_identity_lock); if (!l->l_has_identity) goto error; - rw_enter_read(&r->r_handshake_lock); - hs = r->r_handshake; + rw_enter_read_spin(&r->r_handshake_lock); + if (!r->r_handshake_alive || !r->r_handshake_initiator) { + rw_exit_read(&r->r_handshake_lock); + goto error; + } memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + hs = r->r_handshake; rw_exit_read(&r->r_handshake_lock); - if (hs.hs_state != CREATED_INITIATION || - hs.hs_local_index != r_idx) - goto error; - /* e */ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue); /* ee */ if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0) - goto error; + goto error_zero; /* se */ if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0) - goto error; + goto error_zero; /* psk */ noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key); @@ -423,369 +1107,28 @@ noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx, /* {} */ if (noise_msg_decrypt(NULL, en, 0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) - goto error; - - hs.hs_remote_index = s_idx; + goto error_zero; - rw_enter_write(&r->r_handshake_lock); - if (r->r_handshake.hs_state == hs.hs_state && - r->r_handshake.hs_local_index == hs.hs_local_index) { + rw_enter_write_spin(&r->r_handshake_lock); + if (r->r_handshake_alive && r->r_handshake_initiator && + r->r_index.i_local_index == r_idx) { r->r_handshake = hs; - r->r_handshake.hs_state = CONSUMED_RESPONSE; - ret = 0; + r->r_index.i_remote_index = s_idx; + ret = noise_begin_session(r); + *rp = noise_remote_ref(r); } rw_exit_write(&r->r_handshake_lock); -error: - rw_exit_read(&l->l_identity_lock); - explicit_bzero(&hs, sizeof(hs)); +error_zero: + explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); - return ret; -} - -int -noise_remote_begin_session(struct noise_remote *r) -{ - struct noise_handshake *hs = &r->r_handshake; - struct noise_keypair kp, *next, *current, *previous; - - rw_enter_write(&r->r_handshake_lock); - - /* We now derive the keypair from the handshake */ - if (hs->hs_state == CONSUMED_RESPONSE) { - kp.kp_is_initiator = 1; - noise_kdf(kp.kp_send, kp.kp_recv, NULL, NULL, - NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, - hs->hs_ck); - } else if (hs->hs_state == CREATED_RESPONSE) { - kp.kp_is_initiator = 0; - noise_kdf(kp.kp_recv, kp.kp_send, NULL, NULL, - NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, - hs->hs_ck); - } else { - rw_exit_write(&r->r_handshake_lock); - return EINVAL; - } - - kp.kp_valid = 1; - kp.kp_local_index = hs->hs_local_index; - kp.kp_remote_index = hs->hs_remote_index; - getnanouptime(&kp.kp_birthdate); - bzero(&kp.kp_ctr, sizeof(kp.kp_ctr)); - rw_init(&kp.kp_ctr.c_lock, "noise_counter"); - - /* Now we need to add_new_keypair */ - rw_enter_write(&r->r_keypair_lock); - next = r->r_next; - current = r->r_current; - previous = r->r_previous; - - if (kp.kp_is_initiator) { - if (next != NULL) { - r->r_next = NULL; - r->r_previous = next; - noise_remote_keypair_free(r, current); - } else { - r->r_previous = current; - } - - noise_remote_keypair_free(r, previous); - - r->r_current = noise_remote_keypair_allocate(r); - *r->r_current = kp; - } else { - noise_remote_keypair_free(r, next); - r->r_previous = NULL; - noise_remote_keypair_free(r, previous); - - r->r_next = noise_remote_keypair_allocate(r); - *r->r_next = kp; - } - rw_exit_write(&r->r_keypair_lock); - - explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); - - explicit_bzero(&kp, sizeof(kp)); - return 0; -} - -void -noise_remote_clear(struct noise_remote *r) -{ - rw_enter_write(&r->r_handshake_lock); - noise_remote_handshake_index_drop(r); - explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); - - rw_enter_write(&r->r_keypair_lock); - noise_remote_keypair_free(r, r->r_next); - noise_remote_keypair_free(r, r->r_current); - noise_remote_keypair_free(r, r->r_previous); - r->r_next = NULL; - r->r_current = NULL; - r->r_previous = NULL; - rw_exit_write(&r->r_keypair_lock); -} - -void -noise_remote_expire_current(struct noise_remote *r) -{ - rw_enter_write(&r->r_keypair_lock); - if (r->r_next != NULL) - r->r_next->kp_valid = 0; - if (r->r_current != NULL) - r->r_current->kp_valid = 0; - rw_exit_write(&r->r_keypair_lock); -} - -int -noise_remote_ready(struct noise_remote *r) -{ - struct noise_keypair *kp; - int ret; - - rw_enter_read(&r->r_keypair_lock); - /* kp_ctr isn't locked here, we're happy to accept a racy read. */ - if ((kp = r->r_current) == NULL || - !kp->kp_valid || - noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || - kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES) - ret = EINVAL; - else - ret = 0; - rw_exit_read(&r->r_keypair_lock); - return ret; -} - -int -noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce, - uint8_t *buf, size_t buflen) -{ - struct noise_keypair *kp; - int ret = EINVAL; - - rw_enter_read(&r->r_keypair_lock); - if ((kp = r->r_current) == NULL) - goto error; - - /* We confirm that our values are within our tolerances. We want: - * - a valid keypair - * - our keypair to be less than REJECT_AFTER_TIME seconds old - * - our receive counter to be less than REJECT_AFTER_MESSAGES - * - our send counter to be less than REJECT_AFTER_MESSAGES - * - * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (!kp->kp_valid || - noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || - ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) - goto error; - - /* We encrypt into the same buffer, so the caller must ensure that buf - * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index - * are passed back out to the caller through the provided data pointer. */ - *r_idx = kp->kp_remote_index; - chacha20poly1305_encrypt(buf, buf, buflen, - NULL, 0, *nonce, kp->kp_send); - - /* If our values are still within tolerances, but we are approaching - * the tolerances, we notify the caller with ESTALE that they should - * establish a new keypair. The current keypair can continue to be used - * until the tolerances are hit. We notify if: - * - our send counter is valid and not less than REKEY_AFTER_MESSAGES - * - we're the initiator and our keypair is older than - * REKEY_AFTER_TIME seconds */ - ret = ESTALE; - if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || - (kp->kp_is_initiator && - noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0))) - goto error; - - ret = 0; -error: - rw_exit_read(&r->r_keypair_lock); - return ret; -} - -int -noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce, - uint8_t *buf, size_t buflen) -{ - struct noise_keypair *kp; - int ret = EINVAL; - - /* We retrieve the keypair corresponding to the provided index. We - * attempt the current keypair first as that is most likely. We also - * want to make sure that the keypair is valid as it would be - * catastrophic to decrypt against a zero'ed keypair. */ - rw_enter_read(&r->r_keypair_lock); - - if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) { - kp = r->r_current; - } else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) { - kp = r->r_previous; - } else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) { - kp = r->r_next; - } else { - goto error; - } - - /* We confirm that our values are within our tolerances. These values - * are the same as the encrypt routine. - * - * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* Decrypt, then validate the counter. We don't want to validate the - * counter before decrypting as we do not know the message is authentic - * prior to decryption. */ - if (chacha20poly1305_decrypt(buf, buf, buflen, - NULL, 0, nonce, kp->kp_recv) == 0) - goto error; - - if (noise_counter_recv(&kp->kp_ctr, nonce) != 0) - goto error; - - /* If we've received the handshake confirming data packet then move the - * next keypair into current. If we do slide the next keypair in, then - * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a - * data packet can't confirm a session that we are an INITIATOR of. */ - if (kp == r->r_next) { - rw_exit_read(&r->r_keypair_lock); - rw_enter_write(&r->r_keypair_lock); - if (kp == r->r_next && kp->kp_local_index == r_idx) { - noise_remote_keypair_free(r, r->r_previous); - r->r_previous = r->r_current; - r->r_current = r->r_next; - r->r_next = NULL; - - ret = ECONNRESET; - goto error; - } - rw_enter(&r->r_keypair_lock, RW_DOWNGRADE); - } - - /* Similar to when we encrypt, we want to notify the caller when we - * are approaching our tolerances. We notify if: - * - we're the initiator and the current keypair is older than - * REKEY_AFTER_TIME_RECV seconds. */ - ret = ESTALE; - kp = r->r_current; - if (kp != NULL && - kp->kp_valid && - kp->kp_is_initiator && - noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME_RECV, 0)) - goto error; - - ret = 0; - -error: - rw_exit(&r->r_keypair_lock); - return ret; -} - -/* Private functions - these should not be called outside this file under any - * circumstances. */ -static struct noise_keypair * -noise_remote_keypair_allocate(struct noise_remote *r) -{ - struct noise_keypair *kp; - kp = SLIST_FIRST(&r->r_unused_keypairs); - SLIST_REMOVE_HEAD(&r->r_unused_keypairs, kp_entry); - return kp; -} - -static void -noise_remote_keypair_free(struct noise_remote *r, struct noise_keypair *kp) -{ - struct noise_upcall *u = &r->r_local->l_upcall; - if (kp != NULL) { - SLIST_INSERT_HEAD(&r->r_unused_keypairs, kp, kp_entry); - u->u_index_drop(u->u_arg, kp->kp_local_index); - bzero(kp->kp_send, sizeof(kp->kp_send)); - bzero(kp->kp_recv, sizeof(kp->kp_recv)); - } -} - -static uint32_t -noise_remote_handshake_index_get(struct noise_remote *r) -{ - struct noise_upcall *u = &r->r_local->l_upcall; - return u->u_index_set(u->u_arg, r); -} - -static void -noise_remote_handshake_index_drop(struct noise_remote *r) -{ - struct noise_handshake *hs = &r->r_handshake; - struct noise_upcall *u = &r->r_local->l_upcall; - rw_assert_wrlock(&r->r_handshake_lock); - if (hs->hs_state != HS_ZEROED) - u->u_index_drop(u->u_arg, hs->hs_local_index); -} - -static uint64_t -noise_counter_send(struct noise_counter *ctr) -{ -#ifdef __LP64__ - return atomic_inc_long_nv((u_long *)&ctr->c_send) - 1; -#else - uint64_t ret; - rw_enter_write(&ctr->c_lock); - ret = ctr->c_send++; - rw_exit_write(&ctr->c_lock); - return ret; -#endif -} - -static int -noise_counter_recv(struct noise_counter *ctr, uint64_t recv) -{ - uint64_t i, top, index_recv, index_ctr; - unsigned long bit; - int ret = EEXIST; - - rw_enter_write(&ctr->c_lock); - - /* Check that the recv counter is valid */ - if (ctr->c_recv >= REJECT_AFTER_MESSAGES || - recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* If the packet is out of the window, invalid */ - if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) - goto error; - - /* If the new counter is ahead of the current counter, we'll need to - * zero out the bitmap that has previously been used */ - index_recv = recv / COUNTER_BITS; - index_ctr = ctr->c_recv / COUNTER_BITS; - - if (recv > ctr->c_recv) { - top = MIN(index_recv - index_ctr, COUNTER_NUM); - for (i = 1; i <= top; i++) - ctr->c_backtrack[ - (i + index_ctr) & (COUNTER_NUM - 1)] = 0; - ctr->c_recv = recv; - } - - index_recv %= COUNTER_NUM; - bit = 1ul << (recv % COUNTER_BITS); - - if (ctr->c_backtrack[index_recv] & bit) - goto error; - - ctr->c_backtrack[index_recv] |= bit; - - ret = 0; + explicit_bzero(&hs, sizeof(hs)); error: - rw_exit_write(&ctr->c_lock); + rw_exit_read(&l->l_identity_lock); + noise_remote_put(r); return ret; } +/* Handshake helper functions */ static void noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, size_t a_len, size_t b_len, size_t c_len, size_t x_len, @@ -794,13 +1137,6 @@ noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, uint8_t out[BLAKE2S_HASH_SIZE + 1]; uint8_t sec[BLAKE2S_HASH_SIZE]; -#ifdef DIAGNOSTIC - KASSERT(a_len <= BLAKE2S_HASH_SIZE && b_len <= BLAKE2S_HASH_SIZE && - c_len <= BLAKE2S_HASH_SIZE); - KASSERT(!(b || b_len || c || c_len) || (a && a_len)); - KASSERT(!(c || c_len) || (b && b_len)); -#endif - /* Extract entropy from "x" into sec */ blake2s_hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN); @@ -964,10 +1300,6 @@ noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec) struct timespec uptime; struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec }; - /* We don't really worry about a zeroed birthdate, to avoid the extra - * check on every encrypt/decrypt. This does mean that r_last_init - * check may fail if getnanouptime is < REJECT_INTERVAL from 0. */ - getnanouptime(&uptime); timespecadd(birthdate, &expire, &expire); return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; @@ -975,59 +1307,23 @@ noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec) #ifdef WGTEST -#define MESSAGE_LEN 64 -#define LARGE_MESSAGE_LEN 1420 - #define T_LIM (COUNTER_WINDOW_SIZE + 1) #define T_INIT do { \ - bzero(&ctr, sizeof(ctr)); \ - rw_init(&ctr.c_lock, "counter"); \ + bzero(&kp, sizeof(kp)); \ + rw_init(&kp.kp_nonce_lock, "counter"); \ } while (0) #define T(num, v, e) do { \ - if (noise_counter_recv(&ctr, v) != e) { \ + if (noise_keypair_nonce_check(&kp, v) != e) { \ printf("%s, test %d: failed.\n", __func__, num); \ return; \ } \ } while (0) -#define T_FAILED(test) do { \ - printf("%s %s: failed\n", __func__, test); \ - return; \ -} while (0) #define T_PASSED printf("%s: passed.\n", __func__) -static struct noise_local al, bl; -static struct noise_remote ar, br; - -static struct noise_initiation { - uint32_t s_idx; - uint8_t ue[NOISE_PUBLIC_KEY_LEN]; - uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN]; - uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]; -} init; - -static struct noise_response { - uint32_t s_idx; - uint32_t r_idx; - uint8_t ue[NOISE_PUBLIC_KEY_LEN]; - uint8_t en[0 + NOISE_AUTHTAG_LEN]; -} resp; - -static uint64_t nonce; -static uint32_t index; -static uint8_t data[MESSAGE_LEN + NOISE_AUTHTAG_LEN]; -static uint8_t largedata[LARGE_MESSAGE_LEN + NOISE_AUTHTAG_LEN]; - -static struct noise_remote * -upcall_get(void *x0, uint8_t *x1) { return x0; } -static uint32_t -upcall_set(void *x0, struct noise_remote *x1) { return 5; } -static void -upcall_drop(void *x0, uint32_t x1) { } - -static void -noise_counter_test() +void +noise_test() { - struct noise_counter ctr; + struct noise_keypair kp; int i; T_INIT; @@ -1102,246 +1398,4 @@ noise_counter_test() T_PASSED; } - -static void -noise_handshake_init(struct noise_local *al, struct noise_remote *ar, - struct noise_local *bl, struct noise_remote *br) -{ - uint8_t apriv[NOISE_PUBLIC_KEY_LEN], bpriv[NOISE_PUBLIC_KEY_LEN]; - uint8_t apub[NOISE_PUBLIC_KEY_LEN], bpub[NOISE_PUBLIC_KEY_LEN]; - uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]; - - struct noise_upcall upcall = { - .u_arg = NULL, - .u_remote_get = upcall_get, - .u_index_set = upcall_set, - .u_index_drop = upcall_drop, - }; - - upcall.u_arg = ar; - noise_local_init(al, &upcall); - upcall.u_arg = br; - noise_local_init(bl, &upcall); - - arc4random_buf(apriv, NOISE_PUBLIC_KEY_LEN); - arc4random_buf(bpriv, NOISE_PUBLIC_KEY_LEN); - - noise_local_lock_identity(al); - noise_local_set_private(al, apriv); - noise_local_unlock_identity(al); - - noise_local_lock_identity(bl); - noise_local_set_private(bl, bpriv); - noise_local_unlock_identity(bl); - - noise_local_keys(al, apub, NULL); - noise_local_keys(bl, bpub, NULL); - - noise_remote_init(ar, bpub, al); - noise_remote_init(br, apub, bl); - - arc4random_buf(psk, NOISE_SYMMETRIC_KEY_LEN); - noise_remote_set_psk(ar, psk); - noise_remote_set_psk(br, psk); -} - -static void -noise_handshake_test() -{ - struct noise_remote *r; - int i; - - noise_handshake_init(&al, &ar, &bl, &br); - - /* Create initiation */ - if (noise_create_initiation(&ar, &init.s_idx, - init.ue, init.es, init.ets) != 0) - T_FAILED("create_initiation"); - - /* Check encrypted (es) validation */ - for (i = 0; i < sizeof(init.es); i++) { - init.es[i] = ~init.es[i]; - if (noise_consume_initiation(&bl, &r, init.s_idx, - init.ue, init.es, init.ets) != EINVAL) - T_FAILED("consume_initiation_es"); - init.es[i] = ~init.es[i]; - } - - /* Check encrypted (ets) validation */ - for (i = 0; i < sizeof(init.ets); i++) { - init.ets[i] = ~init.ets[i]; - if (noise_consume_initiation(&bl, &r, init.s_idx, - init.ue, init.es, init.ets) != EINVAL) - T_FAILED("consume_initiation_ets"); - init.ets[i] = ~init.ets[i]; - } - - /* Consume initiation properly */ - if (noise_consume_initiation(&bl, &r, init.s_idx, - init.ue, init.es, init.ets) != 0) - T_FAILED("consume_initiation"); - if (r != &br) - T_FAILED("remote_lookup"); - - /* Replay initiation */ - if (noise_consume_initiation(&bl, &r, init.s_idx, - init.ue, init.es, init.ets) != EINVAL) - T_FAILED("consume_initiation_replay"); - if (r != &br) - T_FAILED("remote_lookup_r_unchanged"); - - /* Create response */ - if (noise_create_response(&br, &resp.s_idx, - &resp.r_idx, resp.ue, resp.en) != 0) - T_FAILED("create_response"); - - /* Check encrypted (en) validation */ - for (i = 0; i < sizeof(resp.en); i++) { - resp.en[i] = ~resp.en[i]; - if (noise_consume_response(&ar, resp.s_idx, - resp.r_idx, resp.ue, resp.en) != EINVAL) - T_FAILED("consume_response_en"); - resp.en[i] = ~resp.en[i]; - } - - /* Consume response properly */ - if (noise_consume_response(&ar, resp.s_idx, - resp.r_idx, resp.ue, resp.en) != 0) - T_FAILED("consume_response"); - - /* Derive keys on both sides */ - if (noise_remote_begin_session(&ar) != 0) - T_FAILED("promote_ar"); - if (noise_remote_begin_session(&br) != 0) - T_FAILED("promote_br"); - - for (i = 0; i < MESSAGE_LEN; i++) - data[i] = i; - - /* Since bob is responder, he must not encrypt until confirmed */ - if (noise_remote_encrypt(&br, &index, &nonce, - data, MESSAGE_LEN) != EINVAL) - T_FAILED("encrypt_kci_wait"); - - /* Alice now encrypt and gets bob to decrypt */ - if (noise_remote_encrypt(&ar, &index, &nonce, - data, MESSAGE_LEN) != 0) - T_FAILED("encrypt_akp"); - if (noise_remote_decrypt(&br, index, nonce, - data, MESSAGE_LEN + NOISE_AUTHTAG_LEN) != ECONNRESET) - T_FAILED("decrypt_bkp"); - - for (i = 0; i < MESSAGE_LEN; i++) - if (data[i] != i) - T_FAILED("decrypt_message_akp_bkp"); - - /* Now bob has received confirmation, he can encrypt */ - if (noise_remote_encrypt(&br, &index, &nonce, - data, MESSAGE_LEN) != 0) - T_FAILED("encrypt_kci_ready"); - if (noise_remote_decrypt(&ar, index, nonce, - data, MESSAGE_LEN + NOISE_AUTHTAG_LEN) != 0) - T_FAILED("decrypt_akp"); - - for (i = 0; i < MESSAGE_LEN; i++) - if (data[i] != i) - T_FAILED("decrypt_message_bkp_akp"); - - T_PASSED; -} - -static void -noise_speed_test() -{ -#define SPEED_ITER (1<<16) - struct timespec start, end; - struct noise_remote *r; - int nsec, i; - -#define NSEC 1000000000 -#define T_TIME_START(iter, size) do { \ - printf("%s %d %d byte encryptions\n", __func__, iter, size); \ - nanouptime(&start); \ -} while (0) -#define T_TIME_END(iter, size) do { \ - nanouptime(&end); \ - timespecsub(&end, &start, &end); \ - nsec = (end.tv_sec * NSEC + end.tv_nsec) / iter; \ - printf("%s %d nsec/iter, %d iter/sec, %d byte/sec\n", \ - __func__, nsec, NSEC / nsec, NSEC / nsec * size); \ -} while (0) -#define T_TIME_START_SINGLE(name) do { \ - printf("%s %s\n", __func__, name); \ - nanouptime(&start); \ -} while (0) -#define T_TIME_END_SINGLE() do { \ - nanouptime(&end); \ - timespecsub(&end, &start, &end); \ - nsec = (end.tv_sec * NSEC + end.tv_nsec); \ - printf("%s %d nsec/iter, %d iter/sec\n", \ - __func__, nsec, NSEC / nsec); \ -} while (0) - - noise_handshake_init(&al, &ar, &bl, &br); - - T_TIME_START_SINGLE("create_initiation"); - if (noise_create_initiation(&ar, &init.s_idx, - init.ue, init.es, init.ets) != 0) - T_FAILED("create_initiation"); - T_TIME_END_SINGLE(); - - T_TIME_START_SINGLE("consume_initiation"); - if (noise_consume_initiation(&bl, &r, init.s_idx, - init.ue, init.es, init.ets) != 0) - T_FAILED("consume_initiation"); - T_TIME_END_SINGLE(); - - T_TIME_START_SINGLE("create_response"); - if (noise_create_response(&br, &resp.s_idx, - &resp.r_idx, resp.ue, resp.en) != 0) - T_FAILED("create_response"); - T_TIME_END_SINGLE(); - - T_TIME_START_SINGLE("consume_response"); - if (noise_consume_response(&ar, resp.s_idx, - resp.r_idx, resp.ue, resp.en) != 0) - T_FAILED("consume_response"); - T_TIME_END_SINGLE(); - - /* Derive keys on both sides */ - T_TIME_START_SINGLE("derive_keys"); - if (noise_remote_begin_session(&ar) != 0) - T_FAILED("begin_ar"); - T_TIME_END_SINGLE(); - if (noise_remote_begin_session(&br) != 0) - T_FAILED("begin_br"); - - /* Small data encryptions */ - T_TIME_START(SPEED_ITER, MESSAGE_LEN); - for (i = 0; i < SPEED_ITER; i++) { - if (noise_remote_encrypt(&ar, &index, &nonce, - data, MESSAGE_LEN) != 0) - T_FAILED("encrypt_akp"); - } - T_TIME_END(SPEED_ITER, MESSAGE_LEN); - - - /* Large data encryptions */ - T_TIME_START(SPEED_ITER, LARGE_MESSAGE_LEN); - for (i = 0; i < SPEED_ITER; i++) { - if (noise_remote_encrypt(&ar, &index, &nonce, - largedata, LARGE_MESSAGE_LEN) != 0) - T_FAILED("encrypt_akp"); - } - T_TIME_END(SPEED_ITER, LARGE_MESSAGE_LEN); -} - -void -noise_test() -{ - noise_counter_test(); - noise_handshake_test(); - noise_speed_test(); -} - #endif /* WGTEST */ diff --git a/sys/net/wg_noise.h b/sys/net/wg_noise.h index a90ed617ba1..ccee48c19a4 100644 --- a/sys/net/wg_noise.h +++ b/sys/net/wg_noise.h @@ -33,115 +33,85 @@ #define NOISE_AUTHTAG_LEN CHACHA20POLY1305_AUTHTAG_SIZE #define NOISE_HASH_LEN BLAKE2S_HASH_SIZE -/* Protocol string constants */ -#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" -#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com" - -/* Constants for the counter */ -#define COUNTER_BITS_TOTAL 8192 -#define COUNTER_BITS (sizeof(unsigned long) * 8) -#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS) -#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS) - -/* Constants for the keypair */ -#define REKEY_AFTER_MESSAGES (1ull << 60) -#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1) -#define REKEY_AFTER_TIME 120 -#define REKEY_AFTER_TIME_RECV 165 #define REJECT_AFTER_TIME 180 -#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */ -/* 24 = floor(log2(REJECT_INTERVAL)) */ -#define REJECT_INTERVAL_MASK (~((1ull<<24)-1)) - -enum noise_state_hs { - HS_ZEROED = 0, - CREATED_INITIATION, - CONSUMED_INITIATION, - CREATED_RESPONSE, - CONSUMED_RESPONSE, -}; - -struct noise_handshake { - enum noise_state_hs hs_state; - uint32_t hs_local_index; - uint32_t hs_remote_index; - uint8_t hs_e[NOISE_PUBLIC_KEY_LEN]; - uint8_t hs_hash[NOISE_HASH_LEN]; - uint8_t hs_ck[NOISE_HASH_LEN]; -}; - -struct noise_counter { - struct rwlock c_lock; - uint64_t c_send; - uint64_t c_recv; - unsigned long c_backtrack[COUNTER_NUM]; -}; - -struct noise_keypair { - SLIST_ENTRY(noise_keypair) kp_entry; - int kp_valid; - int kp_is_initiator; - uint32_t kp_local_index; - uint32_t kp_remote_index; - uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN]; - uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN]; - struct timespec kp_birthdate; /* nanouptime */ - struct noise_counter kp_ctr; -}; - -struct noise_remote { - uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; - struct noise_local *r_local; - uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; - - struct rwlock r_handshake_lock; - struct noise_handshake r_handshake; - uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN]; - uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; - struct timespec r_last_init; /* nanouptime */ - - struct rwlock r_keypair_lock; - SLIST_HEAD(,noise_keypair) r_unused_keypairs; - struct noise_keypair *r_next, *r_current, *r_previous; - struct noise_keypair r_keypair[3]; /* 3: next, current, previous. */ - -}; - -struct noise_local { - struct rwlock l_identity_lock; - int l_has_identity; - uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; - uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; - - struct noise_upcall { - void *u_arg; - struct noise_remote * - (*u_remote_get)(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]); - uint32_t - (*u_index_set)(void *, struct noise_remote *); - void (*u_index_drop)(void *, uint32_t); - } l_upcall; -}; - -/* Set/Get noise parameters */ -void noise_local_init(struct noise_local *, struct noise_upcall *); -void noise_local_deinit(struct noise_local *); -void noise_local_lock_identity(struct noise_local *); -void noise_local_unlock_identity(struct noise_local *); -int noise_local_set_private(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN]); -int noise_local_keys(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN], +#define REKEY_TIMEOUT 5 +#define KEEPALIVE_TIMEOUT 10 + +struct noise_local; +struct noise_remote; +struct noise_keypair; + +/* Local configuration */ +struct noise_local * + noise_local_alloc(void *); +struct noise_local * + noise_local_ref(struct noise_local *); +void noise_local_put(struct noise_local *); +void noise_local_free(struct noise_local *, void (*)(struct noise_local *)); +void * noise_local_arg(struct noise_local *); + +void noise_local_private(struct noise_local *, + const uint8_t[NOISE_PUBLIC_KEY_LEN]); +int noise_local_keys(struct noise_local *, + uint8_t[NOISE_PUBLIC_KEY_LEN], uint8_t[NOISE_PUBLIC_KEY_LEN]); -void noise_remote_init(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN], - struct noise_local *); -void noise_remote_set_psk(struct noise_remote *, uint8_t[NOISE_SYMMETRIC_KEY_LEN]); -int noise_remote_keys(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN], +/* Remote configuration */ +struct noise_remote * + noise_remote_alloc(struct noise_local *, void *, + const uint8_t[NOISE_PUBLIC_KEY_LEN], + const uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +struct noise_remote * + noise_remote_lookup(struct noise_local *, const uint8_t[NOISE_PUBLIC_KEY_LEN]); +struct noise_remote * + noise_remote_index_lookup(struct noise_local *, uint32_t); +struct noise_remote * + noise_remote_ref(struct noise_remote *); +void noise_remote_put(struct noise_remote *); +void noise_remote_free(struct noise_remote *, void (*)(struct noise_remote *)); +struct noise_local * + noise_remote_local(struct noise_remote *); +void * noise_remote_arg(struct noise_remote *); + +void noise_remote_set_psk(struct noise_remote *, + const uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +int noise_remote_keys(struct noise_remote *, + uint8_t[NOISE_PUBLIC_KEY_LEN], uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +int noise_remote_initiation_expired(struct noise_remote *); +void noise_remote_handshake_clear(struct noise_remote *); +void noise_remote_keypairs_clear(struct noise_remote *); + +/* Keypair functions */ +struct noise_keypair * + noise_keypair_lookup(struct noise_local *, uint32_t); +struct noise_keypair * + noise_keypair_current(struct noise_remote *); +struct noise_keypair * + noise_keypair_ref(struct noise_keypair *); +void noise_keypair_put(struct noise_keypair *); + +struct noise_remote * + noise_keypair_remote(struct noise_keypair *); + +int noise_keypair_nonce_next(struct noise_keypair *, uint64_t *); +int noise_keypair_nonce_check(struct noise_keypair *, uint64_t); + +int noise_keep_key_fresh_send(struct noise_remote *); +int noise_keep_key_fresh_recv(struct noise_remote *); +void noise_keypair_encrypt( + struct noise_keypair *, + uint32_t *r_idx, + uint64_t nonce, + uint8_t *buf, + size_t buflen); +int noise_keypair_decrypt( + struct noise_keypair *, + uint64_t nonce, + uint8_t *buf, + size_t buflen); -/* Should be called anytime noise_local_set_private is called */ -void noise_remote_precompute(struct noise_remote *); - -/* Cryptographic functions */ +/* Handshake functions */ int noise_create_initiation( struct noise_remote *, uint32_t *s_idx, @@ -165,31 +135,13 @@ int noise_create_response( uint8_t en[0 + NOISE_AUTHTAG_LEN]); int noise_consume_response( - struct noise_remote *, + struct noise_local *, + struct noise_remote **, uint32_t s_idx, uint32_t r_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]); -int noise_remote_begin_session(struct noise_remote *); -void noise_remote_clear(struct noise_remote *); -void noise_remote_expire_current(struct noise_remote *); - -int noise_remote_ready(struct noise_remote *); - -int noise_remote_encrypt( - struct noise_remote *, - uint32_t *r_idx, - uint64_t *nonce, - uint8_t *buf, - size_t buflen); -int noise_remote_decrypt( - struct noise_remote *, - uint32_t r_idx, - uint64_t nonce, - uint8_t *buf, - size_t buflen); - #ifdef WGTEST void noise_test(); #endif /* WGTEST */ -- cgit v1.2.3-59-g8ed1b