author     dlg <dlg@openbsd.org>    2020-06-21 12:11:26 +0000
committer  dlg <dlg@openbsd.org>    2020-06-21 12:11:26 +0000
commit     58360b13e88a8f23faf2a04ee8f459832fd7e8d8 (patch)
tree       cbdc0608dc07ff800586401895afccdfff2a2f43
parent     wireguard is taking over the gif mbuf tag. (diff)
download   wireguard-openbsd-58360b13e88a8f23faf2a04ee8f459832fd7e8d8.tar.xz
           wireguard-openbsd-58360b13e88a8f23faf2a04ee8f459832fd7e8d8.zip
add wg(4), an in-kernel driver for WireGuard vpn communication.
thanks to Matt Dunwoodie and Jason A. Donenfeld for their effort. it's at least as functional as the go implementation, and maybe more so since this one works on more architectures. i'm sure there's further development that can be done, but you can say that about anything and everything that's in the tree. ok deraadt@
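
configuration goes through the new SIOCSWG/SIOCGWG ioctls carrying a struct wg_data_io (defined in sys/net/if_wg.h below). here is a minimal userland sketch of the get side, assuming only the wgd_name/wgd_size/wgd_interface layout from this commit — the sizing pattern and error handling are illustrative, not copied from ifconfig(8):

	#include <sys/types.h>
	#include <sys/socket.h>
	#include <sys/ioctl.h>

	#include <net/if.h>
	#include <net/if_wg.h>

	#include <err.h>
	#include <stdio.h>
	#include <stdlib.h>
	#include <string.h>

	int
	main(void)
	{
		struct wg_data_io wgdata;
		int s;

		if ((s = socket(AF_INET, SOCK_DGRAM, 0)) == -1)
			err(1, "socket");

		memset(&wgdata, 0, sizeof(wgdata));
		strlcpy(wgdata.wgd_name, "wg0", sizeof(wgdata.wgd_name));

		/* first pass with no buffer: the kernel reports the
		 * size it wants in wgd_size */
		if (ioctl(s, SIOCGWG, &wgdata) == -1)
			err(1, "SIOCGWG size");

		if ((wgdata.wgd_interface = malloc(wgdata.wgd_size)) == NULL)
			err(1, "malloc");

		/* second pass fills the buffer; a robust caller would loop
		 * in case peers were added between the two calls */
		if (ioctl(s, SIOCGWG, &wgdata) == -1)
			err(1, "SIOCGWG data");

		printf("%s: %zu bytes of interface description\n",
		    wgdata.wgd_name, wgdata.wgd_size);

		free(wgdata.wgd_interface);
		return 0;
	}
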
-rw-r--r--  sys/net/if.c          10
-rw-r--r--  sys/net/if_wg.c     2735
-rw-r--r--  sys/net/if_wg.h      107
-rw-r--r--  sys/net/wg_cookie.c  697
-rw-r--r--  sys/net/wg_cookie.h  131
-rw-r--r--  sys/net/wg_noise.c  1344
-rw-r--r--  sys/net/wg_noise.h   195
7 files changed, 5218 insertions(+), 1 deletion(-)
diff --git a/sys/net/if.c b/sys/net/if.c
index 1cd42e48bc5..8ba5cda6494 100644
--- a/sys/net/if.c
+++ b/sys/net/if.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: if.c,v 1.607 2020/06/17 06:45:22 dlg Exp $ */
+/* $OpenBSD: if.c,v 1.608 2020/06/21 12:11:26 dlg Exp $ */
/* $NetBSD: if.c,v 1.35 1996/05/07 05:26:04 thorpej Exp $ */
/*
@@ -70,6 +70,7 @@
#include "ppp.h"
#include "pppoe.h"
#include "switch.h"
+#include "if_wg.h"
#include <sys/param.h>
#include <sys/systm.h>
@@ -2228,6 +2229,13 @@ ifioctl(struct socket *so, u_long cmd, caddr_t data, struct proc *p)
/* don't take NET_LOCK because i2c reads take a long time */
error = ((*ifp->if_ioctl)(ifp, cmd, data));
break;
+ case SIOCSWG:
+ case SIOCGWG:
+ /* Don't take NET_LOCK to allow wg(4) to continue to send and
+ * receive packets while we're loading a large number of
+ * peers. wg(4) uses its own lock to serialise access. */
+ error = ((*ifp->if_ioctl)(ifp, cmd, data));
+ break;
case SIOCSETKALIVE:
case SIOCDIFPHYADDR:
diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c
new file mode 100644
index 00000000000..06ae1a05d73
--- /dev/null
+++ b/sys/net/if_wg.c
@@ -0,0 +1,2735 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#include "bpfilter.h"
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/param.h>
+#include <sys/pool.h>
+
+#include <sys/socket.h>
+#include <sys/socketvar.h>
+#include <sys/percpu.h>
+#include <sys/ioctl.h>
+#include <sys/mbuf.h>
+#include <sys/protosw.h>
+
+#include <net/if.h>
+#include <net/if_var.h>
+#include <net/if_types.h>
+#include <net/if_wg.h>
+
+#include <net/wg_noise.h>
+#include <net/wg_cookie.h>
+
+#include <net/pfvar.h>
+#include <net/route.h>
+#include <net/bpf.h>
+
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/udp.h>
+#include <netinet/in_pcb.h>
+
+#include <crypto/siphash.h>
+
+#define DEFAULT_MTU 1420
+
+#define MAX_STAGED_PKT 128
+#define MAX_QUEUED_PKT 1024
+#define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
+
+#define MAX_QUEUED_HANDSHAKES 4096
+
+#define HASHTABLE_PEER_SIZE (1 << 11)
+#define HASHTABLE_INDEX_SIZE (1 << 13)
+#define MAX_PEERS_PER_IFACE (1 << 20)
+
+#define REKEY_TIMEOUT 5
+#define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */
+#define KEEPALIVE_TIMEOUT 10
+#define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
+#define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
+#define UNDERLOAD_TIMEOUT 1
+
+#define DPRINTF(sc, str, ...) do { if (ISSET((sc)->sc_if.if_flags, IFF_DEBUG))\
+ printf("%s: " str, (sc)->sc_if.if_xname, ##__VA_ARGS__); } while (0)
+
+#define CONTAINER_OF(ptr, type, member) ({ \
+ const __typeof( ((type *)0)->member ) *__mptr = (ptr); \
+ (type *)( (char *)__mptr - offsetof(type,member) );})
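+/*
+ * e.g. struct wg_timers is embedded in struct wg_peer as p_timers, so a
+ * timeout callback handed the timers can recover its peer with
+ * CONTAINER_OF(t, struct wg_peer, p_timers).
+ */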
+
+/* First byte indicating packet type on the wire */
+#define WG_PKT_INITIATION htole32(1)
+#define WG_PKT_RESPONSE htole32(2)
+#define WG_PKT_COOKIE htole32(3)
+#define WG_PKT_DATA htole32(4)
+
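+/* Pad plaintext to a multiple of 16 bytes before encryption, e.g.
+ * WG_PKT_WITH_PADDING(0) == 0, (1) == 16, (16) == 16, (17) == 32. */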
+#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1)))
+#define WG_KEY_SIZE WG_KEY_LEN
+
+struct wg_pkt_initiation {
+ uint32_t t;
+ uint32_t s_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_response {
+ uint32_t t;
+ uint32_t s_idx;
+ uint32_t r_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t en[0 + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_cookie {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[COOKIE_NONCE_SIZE];
+ uint8_t ec[COOKIE_ENCRYPTED_SIZE];
+};
+
+struct wg_pkt_data {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[sizeof(uint64_t)];
+ uint8_t buf[];
+};
+
+struct wg_endpoint {
+ union {
+ struct sockaddr r_sa;
+ struct sockaddr_in r_sin;
+#ifdef INET6
+ struct sockaddr_in6 r_sin6;
+#endif
+ } e_remote;
+ union {
+ struct in_addr l_in;
+#ifdef INET6
+ struct in6_pktinfo l_pktinfo6;
+#define l_in6 l_pktinfo6.ipi6_addr
+#endif
+ } e_local;
+};
+
+struct wg_tag {
+ struct wg_endpoint t_endpoint;
+ struct wg_peer *t_peer;
+ struct mbuf *t_mbuf;
+ int t_done;
+ int t_mtu;
+};
+
+struct wg_index {
+ LIST_ENTRY(wg_index) i_entry;
+ SLIST_ENTRY(wg_index) i_unused_entry;
+ uint32_t i_key;
+ struct noise_remote *i_value;
+};
+
+struct wg_timers {
+ /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */
+ struct rwlock t_lock;
+
+ int t_disabled;
+ int t_need_another_keepalive;
+ uint16_t t_persistent_keepalive_interval;
+ struct timeout t_new_handshake;
+ struct timeout t_send_keepalive;
+ struct timeout t_retry_handshake;
+ struct timeout t_zero_key_material;
+ struct timeout t_persistent_keepalive;
+
+ struct mutex t_handshake_mtx;
+ struct timespec t_handshake_last_sent; /* nanouptime */
+ struct timespec t_handshake_complete; /* nanotime */
+ int t_handshake_retries;
+};
+
+struct wg_aip {
+ struct art_node a_node;
+ LIST_ENTRY(wg_aip) a_entry;
+ struct wg_peer *a_peer;
+ struct wg_aip_io a_data;
+};
+
+struct wg_queue {
+ struct mutex q_mtx;
+ struct mbuf_list q_list;
+};
+
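+/*
+ * r_head and r_tail are free-running; slots are addressed as
+ * index & MAX_QUEUED_PKT_MASK, so MAX_QUEUED_PKT must be a power of two.
+ * e.g. r_head == 1022, r_tail == 1025 means three mbufs are queued in
+ * slots 1022, 1023 and 0.
+ */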
+struct wg_ring {
+ struct mutex r_mtx;
+ uint32_t r_head;
+ uint32_t r_tail;
+ struct mbuf *r_buf[MAX_QUEUED_PKT];
+};
+
+struct wg_peer {
+ LIST_ENTRY(wg_peer) p_pubkey_entry;
+ TAILQ_ENTRY(wg_peer) p_seq_entry;
+ uint64_t p_id;
+ struct wg_softc *p_sc;
+
+ struct noise_remote p_remote;
+ struct cookie_maker p_cookie;
+ struct wg_timers p_timers;
+
+ struct mutex p_counters_mtx;
+ uint64_t p_counters_tx;
+ uint64_t p_counters_rx;
+
+ struct mutex p_endpoint_mtx;
+ struct wg_endpoint p_endpoint;
+
+ struct task p_send_initiation;
+ struct task p_send_keepalive;
+ struct task p_clear_secrets;
+ struct task p_deliver_out;
+ struct task p_deliver_in;
+
+ struct mbuf_queue p_stage_queue;
+ struct wg_queue p_encap_queue;
+ struct wg_queue p_decap_queue;
+
+ SLIST_HEAD(,wg_index) p_unused_index;
+ struct wg_index p_index[3];
+
+ LIST_HEAD(,wg_aip) p_aip;
+
+ SLIST_ENTRY(wg_peer) p_start_list;
+ int p_start_onlist;
+};
+
+struct wg_softc {
+ struct ifnet sc_if;
+ SIPHASH_KEY sc_secret;
+
+ struct rwlock sc_lock;
+ struct noise_local sc_local;
+ struct cookie_checker sc_cookie;
+ in_port_t sc_udp_port;
+ int sc_udp_rtable;
+
+ struct rwlock sc_so_lock;
+ struct socket *sc_so4;
+#ifdef INET6
+ struct socket *sc_so6;
+#endif
+
+ size_t sc_aip_num;
+ struct art_root *sc_aip4;
+#ifdef INET6
+ struct art_root *sc_aip6;
+#endif
+
+ struct rwlock sc_peer_lock;
+ size_t sc_peer_num;
+ LIST_HEAD(,wg_peer) *sc_peer;
+ TAILQ_HEAD(,wg_peer) sc_peer_seq;
+ u_long sc_peer_mask;
+
+ struct mutex sc_index_mtx;
+ LIST_HEAD(,wg_index) *sc_index;
+ u_long sc_index_mask;
+
+ struct task sc_handshake;
+ struct mbuf_queue sc_handshake_queue;
+
+ struct task sc_encap;
+ struct task sc_decap;
+ struct wg_ring sc_encap_ring;
+ struct wg_ring sc_decap_ring;
+};
+
+struct wg_peer *
+ wg_peer_create(struct wg_softc *, uint8_t[WG_KEY_SIZE]);
+struct wg_peer *
+ wg_peer_lookup(struct wg_softc *, const uint8_t[WG_KEY_SIZE]);
+void wg_peer_destroy(struct wg_peer *);
+void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *);
+void wg_peer_set_sockaddr(struct wg_peer *, struct sockaddr *);
+int wg_peer_get_sockaddr(struct wg_peer *, struct sockaddr *);
+void wg_peer_clear_src(struct wg_peer *);
+void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
+void wg_peer_counters_add(struct wg_peer *, uint64_t, uint64_t);
+
+int wg_aip_add(struct wg_softc *, struct wg_peer *, struct wg_aip_io *);
+struct wg_peer *
+ wg_aip_lookup(struct art_root *, void *);
+int wg_aip_remove(struct wg_softc *, struct wg_peer *,
+ struct wg_aip_io *);
+
+int wg_socket_open(struct socket **, int, in_port_t *, int *, void *);
+void wg_socket_close(struct socket **);
+int wg_bind(struct wg_softc *, in_port_t *, int *);
+void wg_unbind(struct wg_softc *);
+int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
+void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *,
+ size_t);
+
+struct wg_tag *
+ wg_tag_get(struct mbuf *);
+
+void wg_timers_init(struct wg_timers *);
+void wg_timers_enable(struct wg_timers *);
+void wg_timers_disable(struct wg_timers *);
+void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t);
+int wg_timers_get_persistent_keepalive(struct wg_timers *, uint16_t *);
+void wg_timers_get_last_handshake(struct wg_timers *, struct timespec *);
+int wg_timers_expired_handshake_last_sent(struct wg_timers *);
+int wg_timers_check_handshake_last_sent(struct wg_timers *);
+
+void wg_timers_event_data_sent(struct wg_timers *);
+void wg_timers_event_data_received(struct wg_timers *);
+void wg_timers_event_any_authenticated_packet_sent(struct wg_timers *);
+void wg_timers_event_any_authenticated_packet_received(struct wg_timers *);
+void wg_timers_event_handshake_initiated(struct wg_timers *);
+void wg_timers_event_handshake_responded(struct wg_timers *);
+void wg_timers_event_handshake_complete(struct wg_timers *);
+void wg_timers_event_session_derived(struct wg_timers *);
+void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *);
+void wg_timers_event_want_initiation(struct wg_timers *);
+void wg_timers_event_reset_handshake_last_sent(struct wg_timers *);
+
+void wg_timers_run_send_initiation(void *, int);
+void wg_timers_run_retry_handshake(void *);
+void wg_timers_run_send_keepalive(void *);
+void wg_timers_run_new_handshake(void *);
+void wg_timers_run_zero_key_material(void *);
+void wg_timers_run_persistent_keepalive(void *);
+
+void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
+void wg_send_initiation(void *);
+void wg_send_response(struct wg_peer *);
+void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t,
+ struct wg_endpoint *e);
+void wg_send_keepalive(void *);
+void wg_peer_clear_secrets(void *);
+void wg_handshake(struct wg_softc *, struct mbuf *);
+void wg_handshake_worker(void *);
+
+void wg_encap(struct wg_softc *, struct mbuf *);
+void wg_decap(struct wg_softc *, struct mbuf *);
+void wg_encap_worker(void *);
+void wg_decap_worker(void *);
+void wg_deliver_out(void *);
+void wg_deliver_in(void *);
+
+int wg_queue_in(struct wg_softc *, struct wg_peer *, struct mbuf *);
+void wg_queue_out(struct wg_softc *, struct wg_peer *);
+struct mbuf *
+ wg_ring_dequeue(struct wg_ring *);
+struct mbuf *
+ wg_queue_dequeue(struct wg_queue *, struct wg_tag **);
+size_t wg_queue_len(struct wg_queue *);
+
+struct noise_remote *
+ wg_remote_get(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]);
+uint32_t
+ wg_index_set(void *, struct noise_remote *);
+struct noise_remote *
+ wg_index_get(void *, uint32_t);
+void wg_index_drop(void *, uint32_t);
+
+struct mbuf *
+ wg_input(void *, struct mbuf *, struct ip *, struct ip6_hdr *, void *,
+ int);
+int wg_output(struct ifnet *, struct mbuf *, struct sockaddr *,
+ struct rtentry *);
+int wg_ioctl_set(struct wg_softc *, struct wg_data_io *);
+int wg_ioctl_get(struct wg_softc *, struct wg_data_io *);
+int wg_ioctl(struct ifnet *, u_long, caddr_t);
+int wg_up(struct wg_softc *);
+void wg_down(struct wg_softc *);
+
+int wg_clone_create(struct if_clone *, int);
+int wg_clone_destroy(struct ifnet *);
+void wgattach(int);
+
+uint64_t peer_counter = 0;
+uint64_t keypair_counter = 0;
+struct pool wg_aip_pool;
+struct pool wg_peer_pool;
+struct pool wg_ratelimit_pool;
+struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 };
+
+size_t wg_counter = 0;
+struct taskq *wg_handshake_taskq;
+struct taskq *wg_crypt_taskq;
+
+struct if_clone wg_cloner =
+ IF_CLONE_INITIALIZER("wg", wg_clone_create, wg_clone_destroy);
+
+struct wg_peer *
+wg_peer_create(struct wg_softc *sc, uint8_t public[WG_KEY_SIZE])
+{
+ struct wg_peer *peer;
+ uint64_t idx;
+
+ rw_assert_wrlock(&sc->sc_lock);
+
+ if (sc->sc_peer_num >= MAX_PEERS_PER_IFACE)
+ return NULL;
+
+ if ((peer = pool_get(&wg_peer_pool, PR_NOWAIT)) == NULL)
+ return NULL;
+
+ peer->p_id = peer_counter++;
+ peer->p_sc = sc;
+
+ noise_remote_init(&peer->p_remote, public, &sc->sc_local);
+ cookie_maker_init(&peer->p_cookie, public);
+ wg_timers_init(&peer->p_timers);
+
+ mtx_init(&peer->p_counters_mtx, IPL_NET);
+ peer->p_counters_tx = 0;
+ peer->p_counters_rx = 0;
+
+ mtx_init(&peer->p_endpoint_mtx, IPL_NET);
+ bzero(&peer->p_endpoint, sizeof(peer->p_endpoint));
+
+ task_set(&peer->p_send_initiation, wg_send_initiation, peer);
+ task_set(&peer->p_send_keepalive, wg_send_keepalive, peer);
+ task_set(&peer->p_clear_secrets, wg_peer_clear_secrets, peer);
+ task_set(&peer->p_deliver_out, wg_deliver_out, peer);
+ task_set(&peer->p_deliver_in, wg_deliver_in, peer);
+
+ mq_init(&peer->p_stage_queue, MAX_STAGED_PKT, IPL_NET);
+ mtx_init(&peer->p_encap_queue.q_mtx, IPL_NET);
+ ml_init(&peer->p_encap_queue.q_list);
+ mtx_init(&peer->p_decap_queue.q_mtx, IPL_NET);
+ ml_init(&peer->p_decap_queue.q_list);
+
+ SLIST_INIT(&peer->p_unused_index);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2],
+ i_unused_entry);
+
+ LIST_INIT(&peer->p_aip);
+
+ peer->p_start_onlist = 0;
+
+ idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE);
+ idx &= sc->sc_peer_mask;
+
+ rw_enter_write(&sc->sc_peer_lock);
+ LIST_INSERT_HEAD(&sc->sc_peer[idx], peer, p_pubkey_entry);
+ TAILQ_INSERT_TAIL(&sc->sc_peer_seq, peer, p_seq_entry);
+ sc->sc_peer_num++;
+ rw_exit_write(&sc->sc_peer_lock);
+
+ DPRINTF(sc, "Peer %llu created\n", peer->p_id);
+ return peer;
+}
+
+struct wg_peer *
+wg_peer_lookup(struct wg_softc *sc, const uint8_t public[WG_KEY_SIZE])
+{
+ uint8_t peer_key[WG_KEY_SIZE];
+ struct wg_peer *peer;
+ uint64_t idx;
+
+ idx = SipHash24(&sc->sc_secret, public, WG_KEY_SIZE);
+ idx &= sc->sc_peer_mask;
+
+ rw_enter_read(&sc->sc_peer_lock);
+ LIST_FOREACH(peer, &sc->sc_peer[idx], p_pubkey_entry) {
+ noise_remote_keys(&peer->p_remote, peer_key, NULL);
+ if (timingsafe_bcmp(peer_key, public, WG_KEY_SIZE) == 0)
+ goto done;
+ }
+ peer = NULL;
+done:
+ rw_exit_read(&sc->sc_peer_lock);
+ return peer;
+}
+
+void
+wg_peer_destroy(struct wg_peer *peer)
+{
+ struct wg_softc *sc = peer->p_sc;
+ struct wg_aip *aip, *taip;
+
+ rw_assert_wrlock(&sc->sc_lock);
+
+ /*
+ * Remove peer from the pubkey hashtable and disable all timeouts.
+ * After this, and flushing wg_handshake_taskq, then no more handshakes
+ * can be started.
+ */
+ rw_enter_write(&sc->sc_peer_lock);
+ LIST_REMOVE(peer, p_pubkey_entry);
+ TAILQ_REMOVE(&sc->sc_peer_seq, peer, p_seq_entry);
+ sc->sc_peer_num--;
+ rw_exit_write(&sc->sc_peer_lock);
+
+ wg_timers_disable(&peer->p_timers);
+
+ taskq_barrier(wg_handshake_taskq);
+
+ /*
+ * Now we drop all allowed ips, to drop all outgoing packets to the
+ * peer. Then drop all the indexes to drop all incoming packets to the
+ * peer. Then we can flush if_snd, wg_crypt_taskq and then nettq to
+ * ensure no more references to the peer exist.
+ */
+ LIST_FOREACH_SAFE(aip, &peer->p_aip, a_entry, taip)
+ wg_aip_remove(sc, peer, &aip->a_data);
+
+ noise_remote_clear(&peer->p_remote);
+
+ NET_LOCK();
+ while (!ifq_empty(&sc->sc_if.if_snd)) {
+ NET_UNLOCK();
+ tsleep_nsec(sc, PWAIT, "wg_ifq", 1000);
+ NET_LOCK();
+ }
+ NET_UNLOCK();
+
+ taskq_barrier(wg_crypt_taskq);
+ taskq_barrier(net_tq(sc->sc_if.if_index));
+
+ DPRINTF(sc, "Peer %llu destroyed\n", peer->p_id);
+ explicit_bzero(peer, sizeof(*peer));
+ pool_put(&wg_peer_pool, peer);
+}
+
+void
+wg_peer_set_endpoint_from_tag(struct wg_peer *peer, struct wg_tag *t)
+{
+ if (memcmp(&t->t_endpoint, &peer->p_endpoint,
+ sizeof(t->t_endpoint)) == 0)
+ return;
+
+ mtx_enter(&peer->p_endpoint_mtx);
+ peer->p_endpoint = t->t_endpoint;
+ mtx_leave(&peer->p_endpoint_mtx);
+}
+
+void
+wg_peer_set_sockaddr(struct wg_peer *peer, struct sockaddr *remote)
+{
+ mtx_enter(&peer->p_endpoint_mtx);
+ memcpy(&peer->p_endpoint.e_remote, remote,
+ sizeof(peer->p_endpoint.e_remote));
+ bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
+ mtx_leave(&peer->p_endpoint_mtx);
+}
+
+int
+wg_peer_get_sockaddr(struct wg_peer *peer, struct sockaddr *remote)
+{
+ int ret = 0;
+
+ mtx_enter(&peer->p_endpoint_mtx);
+ if (peer->p_endpoint.e_remote.r_sa.sa_family != AF_UNSPEC)
+ memcpy(remote, &peer->p_endpoint.e_remote,
+ sizeof(peer->p_endpoint.e_remote));
+ else
+ ret = ENOENT;
+ mtx_leave(&peer->p_endpoint_mtx);
+ return ret;
+}
+
+void
+wg_peer_clear_src(struct wg_peer *peer)
+{
+ mtx_enter(&peer->p_endpoint_mtx);
+ bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
+ mtx_leave(&peer->p_endpoint_mtx);
+}
+
+void
+wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *endpoint)
+{
+ mtx_enter(&peer->p_endpoint_mtx);
+ memcpy(endpoint, &peer->p_endpoint, sizeof(*endpoint));
+ mtx_leave(&peer->p_endpoint_mtx);
+}
+
+void
+wg_peer_counters_add(struct wg_peer *peer, uint64_t tx, uint64_t rx)
+{
+ mtx_enter(&peer->p_counters_mtx);
+ peer->p_counters_tx += tx;
+ peer->p_counters_rx += rx;
+ mtx_leave(&peer->p_counters_mtx);
+}
+
+int
+wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, struct wg_aip_io *d)
+{
+ struct art_root *root;
+ struct art_node *node;
+ struct wg_aip *aip;
+ int ret = 0;
+
+ switch (d->a_af) {
+ case AF_INET: root = sc->sc_aip4; break;
+#ifdef INET6
+ case AF_INET6: root = sc->sc_aip6; break;
+#endif
+ default: return EAFNOSUPPORT;
+ }
+
+ if ((aip = pool_get(&wg_aip_pool, PR_NOWAIT)) == NULL)
+ return ENOBUFS;
+ bzero(aip, sizeof(*aip));
+
+ rw_enter_write(&root->ar_lock);
+ node = art_insert(root, &aip->a_node, &d->a_addr, d->a_cidr);
+
+ if (node == &aip->a_node) {
+ aip->a_peer = peer;
+ aip->a_data = *d;
+ LIST_INSERT_HEAD(&peer->p_aip, aip, a_entry);
+ sc->sc_aip_num++;
+ } else {
+ pool_put(&wg_aip_pool, aip);
+ aip = (struct wg_aip *) node;
+ if (aip->a_peer != peer) {
+ LIST_REMOVE(aip, a_entry);
+ LIST_INSERT_HEAD(&peer->p_aip, aip, a_entry);
+ aip->a_peer = peer;
+ }
+ }
+ rw_exit_write(&root->ar_lock);
+ return ret;
+}
+
+struct wg_peer *
+wg_aip_lookup(struct art_root *root, void *addr)
+{
+ struct srp_ref sr;
+ struct art_node *node;
+
+ node = art_match(root, addr, &sr);
+ srp_leave(&sr);
+
+ return node == NULL ? NULL : ((struct wg_aip *) node)->a_peer;
+}
+
+int
+wg_aip_remove(struct wg_softc *sc, struct wg_peer *peer, struct wg_aip_io *d)
+{
+ struct srp_ref sr;
+ struct art_root *root;
+ struct art_node *node;
+ struct wg_aip *aip;
+ int ret = 0;
+
+ switch (d->a_af) {
+ case AF_INET: root = sc->sc_aip4; break;
+#ifdef INET6
+ case AF_INET6: root = sc->sc_aip6; break;
+#endif
+ default: return EAFNOSUPPORT;
+ }
+
+ rw_enter_write(&root->ar_lock);
+ if ((node = art_lookup(root, &d->a_addr, d->a_cidr, &sr)) == NULL) {
+ ret = ENOENT;
+ } else if (((struct wg_aip *) node)->a_peer != peer) {
+ ret = EXDEV;
+ } else {
+ aip = (struct wg_aip *)node;
+ if (art_delete(root, node, &d->a_addr, d->a_cidr) == NULL)
+ panic("art_delete failed to delete node %p", node);
+
+ sc->sc_aip_num--;
+ LIST_REMOVE(aip, a_entry);
+ pool_put(&wg_aip_pool, aip);
+ }
+
+ srp_leave(&sr);
+ rw_exit_write(&root->ar_lock);
+ return ret;
+}
+
+int
+wg_socket_open(struct socket **so, int af, in_port_t *port,
+ int *rtable, void *upcall_arg)
+{
+ struct mbuf mhostnam, mrtable;
+#ifdef INET6
+ struct sockaddr_in6 *sin6;
+#endif
+ struct sockaddr_in *sin;
+ int ret, s;
+
+ m_inithdr(&mhostnam);
+ m_inithdr(&mrtable);
+
+ bzero(mtod(&mrtable, u_int *), sizeof(u_int));
+ *mtod(&mrtable, u_int *) = *rtable;
+ mrtable.m_len = sizeof(u_int);
+
+ if (af == AF_INET) {
+ sin = mtod(&mhostnam, struct sockaddr_in *);
+ bzero(sin, sizeof(*sin));
+ sin->sin_len = sizeof(*sin);
+ sin->sin_family = AF_INET;
+ sin->sin_port = *port;
+ sin->sin_addr.s_addr = INADDR_ANY;
+ mhostnam.m_len = sin->sin_len;
+#ifdef INET6
+ } else if (af == AF_INET6) {
+ sin6 = mtod(&mhostnam, struct sockaddr_in6 *);
+ bzero(sin6, sizeof(*sin6));
+ sin6->sin6_len = sizeof(*sin6);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_port = *port;
+ sin6->sin6_addr = (struct in6_addr) { .s6_addr = { 0 } };
+ mhostnam.m_len = sin6->sin6_len;
+#endif
+ } else {
+ return EAFNOSUPPORT;
+ }
+
+ if ((ret = socreate(af, so, SOCK_DGRAM, 0)) != 0)
+ return ret;
+
+ s = solock(*so);
+ sotoinpcb(*so)->inp_upcall = wg_input;
+ sotoinpcb(*so)->inp_upcall_arg = upcall_arg;
+
+ if ((ret = sosetopt(*so, SOL_SOCKET, SO_RTABLE, &mrtable)) == 0) {
+ if ((ret = sobind(*so, &mhostnam, curproc)) == 0) {
+ *port = sotoinpcb(*so)->inp_lport;
+ *rtable = sotoinpcb(*so)->inp_rtableid;
+ }
+ }
+ sounlock(*so, s);
+
+ if (ret != 0)
+ wg_socket_close(so);
+
+ return ret;
+}
+
+void
+wg_socket_close(struct socket **so)
+{
+ if (*so != NULL && soclose(*so, 0) != 0)
+ panic("Unable to close wg socket");
+ *so = NULL;
+}
+
+int
+wg_bind(struct wg_softc *sc, in_port_t *portp, int *rtablep)
+{
+ int ret = 0, rtable = *rtablep;
+ in_port_t port = *portp;
+ struct socket *so4;
+#ifdef INET6
+ struct socket *so6;
+ int retries = 0;
+retry:
+#endif
+ if ((ret = wg_socket_open(&so4, AF_INET, &port, &rtable, sc)) != 0)
+ return ret;
+
+#ifdef INET6
+ if ((ret = wg_socket_open(&so6, AF_INET6, &port, &rtable, sc)) != 0) {
+ if (ret == EADDRINUSE && *portp == 0 && retries++ < 100)
+ goto retry;
+ wg_socket_close(&so4);
+ return ret;
+ }
+#endif
+
+ rw_enter_write(&sc->sc_so_lock);
+ wg_socket_close(&sc->sc_so4);
+ sc->sc_so4 = so4;
+#ifdef INET6
+ wg_socket_close(&sc->sc_so6);
+ sc->sc_so6 = so6;
+#endif
+ rw_exit_write(&sc->sc_so_lock);
+
+ *portp = port;
+ *rtablep = rtable;
+ return 0;
+}
+
+void
+wg_unbind(struct wg_softc *sc)
+{
+ rw_enter_write(&sc->sc_so_lock);
+ wg_socket_close(&sc->sc_so4);
+#ifdef INET6
+ wg_socket_close(&sc->sc_so6);
+#endif
+ rw_exit_write(&sc->sc_so_lock);
+}
+
+int
+wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
+{
+ struct mbuf peernam, *control = NULL;
+ int ret;
+
+ /* Get local control address before locking */
+ if (e->e_remote.r_sa.sa_family == AF_INET) {
+ if (e->e_local.l_in.s_addr != INADDR_ANY)
+ control = sbcreatecontrol(&e->e_local.l_in,
+ sizeof(struct in_addr), IP_SENDSRCADDR,
+ IPPROTO_IP);
+#ifdef INET6
+ } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
+ if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
+ control = sbcreatecontrol(&e->e_local.l_pktinfo6,
+ sizeof(struct in6_pktinfo), IPV6_PKTINFO,
+ IPPROTO_IPV6);
+#endif
+ } else {
+ return EAFNOSUPPORT;
+ }
+
+ /* Get remote address */
+ peernam.m_type = MT_SONAME;
+ peernam.m_next = NULL;
+ peernam.m_nextpkt = NULL;
+ peernam.m_data = (void *)&e->e_remote.r_sa;
+ peernam.m_len = e->e_remote.r_sa.sa_len;
+ peernam.m_flags = 0;
+
+ rw_enter_read(&sc->sc_so_lock);
+ if (e->e_remote.r_sa.sa_family == AF_INET && sc->sc_so4 != NULL)
+ ret = sosend(sc->sc_so4, &peernam, NULL, m, control, 0);
+#ifdef INET6
+ else if (e->e_remote.r_sa.sa_family == AF_INET6 && sc->sc_so6 != NULL)
+ ret = sosend(sc->sc_so6, &peernam, NULL, m, control, 0);
+#endif
+ else {
+ ret = ENOTCONN;
+ m_freem(control);
+ m_freem(m);
+ }
+ rw_exit_read(&sc->sc_so_lock);
+
+ return ret;
+}
+
+void
+wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf,
+ size_t len)
+{
+ struct mbuf *m;
+ int ret = 0;
+
+retry:
+ m = m_gethdr(M_WAIT, MT_DATA);
+ m->m_len = 0;
+ m_copyback(m, 0, len, buf, M_WAIT);
+
+ /* As we're sending a handshake packet here, we want high priority */
+ m->m_pkthdr.pf.prio = IFQ_MAXPRIO;
+
+ if (ret == 0) {
+ ret = wg_send(sc, e, m);
+ /* Retry if we couldn't bind to e->e_local */
+ if (ret == EADDRNOTAVAIL) {
+ bzero(&e->e_local, sizeof(e->e_local));
+ goto retry;
+ }
+ } else {
+ ret = wg_send(sc, e, m);
+ if (ret != 0)
+ DPRINTF(sc, "Unable to send packet\n");
+ }
+}
+
+struct wg_tag *
+wg_tag_get(struct mbuf *m)
+{
+ struct m_tag *mtag;
+
+ if ((mtag = m_tag_find(m, PACKET_TAG_WIREGUARD, NULL)) == NULL) {
+ mtag = m_tag_get(PACKET_TAG_WIREGUARD, sizeof(struct wg_tag),
+ M_NOWAIT);
+ if (mtag == NULL)
+ return (NULL);
+ bzero(mtag + 1, sizeof(struct wg_tag));
+ m_tag_prepend(m, mtag);
+ }
+ return ((struct wg_tag *)(mtag + 1));
+}
+
+/*
+ * The following section handles the timeout callbacks for a WireGuard session.
+ * These functions provide an "event based" model for controlling wg(4) session
+ * timers. All function calls occur after the specified event below.
+ *
+ * wg_timers_event_data_sent:
+ * tx: data
+ * wg_timers_event_data_received:
+ * rx: data
+ * wg_timers_event_any_authenticated_packet_sent:
+ * tx: keepalive, data, handshake
+ * wg_timers_event_any_authenticated_packet_received:
+ * rx: keepalive, data, handshake
+ * wg_timers_event_any_authenticated_packet_traversal:
+ * tx, rx: keepalive, data, handshake
+ * wg_timers_event_handshake_initiated:
+ * tx: initiation
+ * wg_timers_event_handshake_responded:
+ * tx: response
+ * wg_timers_event_handshake_complete:
+ * rx: response, confirmation data
+ * wg_timers_event_session_derived:
+ * tx: response, rx: response
+ * wg_timers_event_want_initiation:
+ * tx: data failed, old keys expiring
+ * wg_timers_event_reset_handshake_last_sent:
+ * anytime we may immediately want a new handshake
+ */
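+/*
+ * An illustrative (not exhaustive) outbound sequence through this model:
+ *   wg_start stages data, noise_remote_ready fails
+ *     -> wg_timers_event_want_initiation -> wg_send_initiation
+ *     -> wg_timers_event_handshake_initiated arms t_retry_handshake
+ *   the peer's response is consumed -> wg_timers_event_handshake_complete
+ *     -> t_retry_handshake is cancelled and a keepalive confirms the keys.
+ */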
+void
+wg_timers_init(struct wg_timers *t)
+{
+ bzero(t, sizeof(*t));
+ rw_init(&t->t_lock, "wg_timers");
+ mtx_init(&t->t_handshake_mtx, IPL_NET);
+
+ timeout_set(&t->t_new_handshake, wg_timers_run_new_handshake, t);
+ timeout_set(&t->t_send_keepalive, wg_timers_run_send_keepalive, t);
+ timeout_set(&t->t_retry_handshake, wg_timers_run_retry_handshake, t);
+ timeout_set(&t->t_persistent_keepalive,
+ wg_timers_run_persistent_keepalive, t);
+ timeout_set(&t->t_zero_key_material,
+ wg_timers_run_zero_key_material, t);
+}
+
+void
+wg_timers_enable(struct wg_timers *t)
+{
+ rw_enter_write(&t->t_lock);
+ t->t_disabled = 0;
+ rw_exit_write(&t->t_lock);
+ wg_timers_run_persistent_keepalive(t);
+}
+
+void
+wg_timers_disable(struct wg_timers *t)
+{
+ rw_enter_write(&t->t_lock);
+ t->t_disabled = 1;
+ t->t_need_another_keepalive = 0;
+ rw_exit_write(&t->t_lock);
+
+ timeout_del_barrier(&t->t_new_handshake);
+ timeout_del_barrier(&t->t_send_keepalive);
+ timeout_del_barrier(&t->t_retry_handshake);
+ timeout_del_barrier(&t->t_persistent_keepalive);
+ timeout_del_barrier(&t->t_zero_key_material);
+}
+
+void
+wg_timers_set_persistent_keepalive(struct wg_timers *t, uint16_t interval)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled) {
+ t->t_persistent_keepalive_interval = interval;
+ wg_timers_run_persistent_keepalive(t);
+ }
+ rw_exit_read(&t->t_lock);
+}
+
+int
+wg_timers_get_persistent_keepalive(struct wg_timers *t, uint16_t *interval)
+{
+ *interval = t->t_persistent_keepalive_interval;
+ return *interval > 0 ? 0 : ENOENT;
+}
+
+void
+wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time)
+{
+ mtx_enter(&t->t_handshake_mtx);
+ *time = t->t_handshake_complete;
+ mtx_leave(&t->t_handshake_mtx);
+}
+
+int
+wg_timers_expired_handshake_last_sent(struct wg_timers *t)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 };
+
+ getnanouptime(&uptime);
+ timespecadd(&t->t_handshake_last_sent, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+int
+wg_timers_check_handshake_last_sent(struct wg_timers *t)
+{
+ int ret;
+ mtx_enter(&t->t_handshake_mtx);
+ if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT)
+ getnanouptime(&t->t_handshake_last_sent);
+ mtx_leave(&t->t_handshake_mtx);
+ return ret;
+}
+
+void
+wg_timers_event_data_sent(struct wg_timers *t)
+{
+ int msecs = NEW_HANDSHAKE_TIMEOUT * 1000;
+ msecs += arc4random_uniform(REKEY_TIMEOUT_JITTER);
+
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled && !timeout_pending(&t->t_new_handshake))
+ timeout_add_msec(&t->t_new_handshake, msecs);
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_data_received(struct wg_timers *t)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled) {
+ if (!timeout_pending(&t->t_send_keepalive))
+ timeout_add_sec(&t->t_send_keepalive,
+ KEEPALIVE_TIMEOUT);
+ else
+ t->t_need_another_keepalive = 1;
+ }
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_any_authenticated_packet_sent(struct wg_timers *t)
+{
+ timeout_del(&t->t_send_keepalive);
+}
+
+void
+wg_timers_event_any_authenticated_packet_received(struct wg_timers *t)
+{
+ timeout_del(&t->t_new_handshake);
+}
+
+void
+wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *t)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled && t->t_persistent_keepalive_interval > 0)
+ timeout_add_sec(&t->t_persistent_keepalive,
+ t->t_persistent_keepalive_interval);
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_handshake_initiated(struct wg_timers *t)
+{
+ int msecs = REKEY_TIMEOUT * 1000;
+ msecs += arc4random_uniform(REKEY_TIMEOUT_JITTER);
+
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled)
+ timeout_add_msec(&t->t_retry_handshake, msecs);
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_handshake_responded(struct wg_timers *t)
+{
+ mtx_enter(&t->t_handshake_mtx);
+ getnanouptime(&t->t_handshake_last_sent);
+ mtx_leave(&t->t_handshake_mtx);
+}
+
+void
+wg_timers_event_handshake_complete(struct wg_timers *t)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled) {
+ mtx_enter(&t->t_handshake_mtx);
+ timeout_del(&t->t_retry_handshake);
+ t->t_handshake_retries = 0;
+ getnanotime(&t->t_handshake_complete);
+ mtx_leave(&t->t_handshake_mtx);
+ wg_timers_run_send_keepalive(t);
+ }
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_session_derived(struct wg_timers *t)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled)
+ timeout_add_sec(&t->t_zero_key_material, REJECT_AFTER_TIME * 3);
+ rw_exit_read(&t->t_lock);
+}
+
+void
+wg_timers_event_want_initiation(struct wg_timers *t)
+{
+ rw_enter_read(&t->t_lock);
+ if (!t->t_disabled)
+ wg_timers_run_send_initiation(t, 0);
+ rw_exit_read(&t->t_lock);
+}
+
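+/*
+ * Backdating t_handshake_last_sent by REKEY_TIMEOUT + 1 makes
+ * wg_timers_expired_handshake_last_sent return ETIMEDOUT immediately, so
+ * the next wg_timers_run_send_initiation proceeds without waiting.
+ */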
+void
+wg_timers_event_reset_handshake_last_sent(struct wg_timers *t)
+{
+ mtx_enter(&t->t_handshake_mtx);
+ t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1);
+ mtx_leave(&t->t_handshake_mtx);
+}
+
+void
+wg_timers_run_send_initiation(void *_t, int is_retry)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+ if (!is_retry)
+ t->t_handshake_retries = 0;
+ if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT)
+ task_add(wg_handshake_taskq, &peer->p_send_initiation);
+}
+
+void
+wg_timers_run_retry_handshake(void *_t)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+
+ mtx_enter(&t->t_handshake_mtx);
+ if (t->t_handshake_retries <= MAX_TIMER_HANDSHAKES) {
+ t->t_handshake_retries++;
+ mtx_leave(&t->t_handshake_mtx);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d seconds, retrying (try %d)\n", peer->p_id,
+ REKEY_TIMEOUT, t->t_handshake_retries + 1);
+ wg_peer_clear_src(peer);
+ wg_timers_run_send_initiation(t, 1);
+ } else {
+ mtx_leave(&t->t_handshake_mtx);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d retries, giving up\n", peer->p_id,
+ MAX_TIMER_HANDSHAKES + 2);
+
+ timeout_del(&t->t_send_keepalive);
+ mq_purge(&peer->p_stage_queue);
+ if (!timeout_pending(&t->t_zero_key_material))
+ timeout_add_sec(&t->t_zero_key_material,
+ REJECT_AFTER_TIME * 3);
+ }
+}
+
+void
+wg_timers_run_send_keepalive(void *_t)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+
+ task_add(wg_crypt_taskq, &peer->p_send_keepalive);
+ if (t->t_need_another_keepalive) {
+ t->t_need_another_keepalive = 0;
+ timeout_add_sec(&t->t_send_keepalive, KEEPALIVE_TIMEOUT);
+ }
+}
+
+void
+wg_timers_run_new_handshake(void *_t)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Retrying handshake with peer %llu because we "
+ "stopped hearing back after %d seconds\n",
+ peer->p_id, NEW_HANDSHAKE_TIMEOUT);
+ wg_peer_clear_src(peer);
+
+ wg_timers_run_send_initiation(t, 0);
+}
+
+void
+wg_timers_run_zero_key_material(void *_t)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Zeroing out keys for peer %llu\n", peer->p_id);
+ task_add(wg_handshake_taskq, &peer->p_clear_secrets);
+}
+
+void
+wg_timers_run_persistent_keepalive(void *_t)
+{
+ struct wg_timers *t = _t;
+ struct wg_peer *peer = CONTAINER_OF(t, struct wg_peer, p_timers);
+ if (t->t_persistent_keepalive_interval != 0)
+ task_add(wg_crypt_taskq, &peer->p_send_keepalive);
+}
+
+/* The following functions handle handshakes */
+void
+wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
+{
+ struct wg_endpoint endpoint;
+
+ wg_peer_counters_add(peer, len, 0);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(&peer->p_timers);
+ wg_peer_get_endpoint(peer, &endpoint);
+ wg_send_buf(peer->p_sc, &endpoint, buf, len);
+}
+
+void
+wg_send_initiation(void *_peer)
+{
+ struct wg_peer *peer = _peer;
+ struct wg_pkt_initiation pkt;
+
+ if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT)
+ return;
+
+ DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n",
+ peer->p_id);
+
+ if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, pkt.es,
+ pkt.ets) != 0)
+ return;
+ pkt.t = WG_PKT_INITIATION;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
+ wg_timers_event_handshake_initiated(&peer->p_timers);
+}
+
+void
+wg_send_response(struct wg_peer *peer)
+{
+ struct wg_pkt_response pkt;
+
+ DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n",
+ peer->p_id);
+
+ if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx,
+ pkt.ue, pkt.en) != 0)
+ return;
+ if (noise_remote_begin_session(&peer->p_remote) != 0)
+ return;
+ wg_timers_event_session_derived(&peer->p_timers);
+ pkt.t = WG_PKT_RESPONSE;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_timers_event_handshake_responded(&peer->p_timers);
+ wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
+}
+
+void
+wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
+ struct wg_endpoint *e)
+{
+ struct wg_pkt_cookie pkt;
+
+ DPRINTF(sc, "Sending cookie response for denied handshake message\n");
+
+ pkt.t = WG_PKT_COOKIE;
+ pkt.r_idx = idx;
+
+ cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
+ pkt.ec, &e->e_remote.r_sa);
+
+ wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
+}
+
+void
+wg_send_keepalive(void *_peer)
+{
+ struct wg_peer *peer = _peer;
+ struct wg_softc *sc = peer->p_sc;
+ struct wg_tag *t;
+ struct mbuf *m;
+
+ if (!mq_empty(&peer->p_stage_queue))
+ goto send;
+
+ if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
+ return;
+
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ return;
+ }
+
+ m->m_len = 0;
+ m_calchdrlen(m);
+
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = 0; /* MTU == 0 OK for keepalive */
+
+ mq_push(&peer->p_stage_queue, m);
+send:
+ if (noise_remote_ready(&peer->p_remote) == 0) {
+ wg_queue_out(sc, peer);
+ task_add(wg_crypt_taskq, &sc->sc_encap);
+ } else {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ }
+}
+
+void
+wg_peer_clear_secrets(void *_peer)
+{
+ struct wg_peer *peer = _peer;
+ noise_remote_clear(&peer->p_remote);
+}
+
+void
+wg_handshake(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_tag *t;
+ struct wg_pkt_initiation *init;
+ struct wg_pkt_response *resp;
+ struct wg_pkt_cookie *cook;
+ struct wg_peer *peer;
+ struct noise_remote *remote;
+ int res, underload = 0;
+ static struct timeval wg_last_underload; /* microuptime */
+
+ if (mq_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES/8) {
+ getmicrouptime(&wg_last_underload);
+ underload = 1;
+ } else if (wg_last_underload.tv_sec != 0) {
+ if (!ratecheck(&wg_last_underload, &underload_interval))
+ underload = 1;
+ else
+ bzero(&wg_last_underload, sizeof(wg_last_underload));
+ }
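+	/*
+	 * "underload" here means the queue is busy: while set,
+	 * cookie_checker_validate_macs demands a valid mac2 (cookie)
+	 * before a handshake is consumed, and it stays set until
+	 * UNDERLOAD_TIMEOUT seconds after the queue last exceeded the
+	 * 1/8 threshold above.
+	 */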
+
+ t = wg_tag_get(m);
+
+ switch (*mtod(m, uint32_t *)) {
+ case WG_PKT_INITIATION:
+ init = mtod(m, struct wg_pkt_initiation *);
+
+ res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
+ init, sizeof(*init) - sizeof(init->m),
+ underload, &t->t_endpoint.e_remote.r_sa);
+
+ if (res == EINVAL) {
+ DPRINTF(sc, "Invalid initiation MAC\n");
+ goto error;
+ } else if (res == ECONNREFUSED) {
+ DPRINTF(sc, "Handshake ratelimited\n");
+ goto error;
+ } else if (res == EAGAIN) {
+ wg_send_cookie(sc, &init->m, init->s_idx,
+ &t->t_endpoint);
+ goto error;
+ } else if (res != 0) {
+ panic("unexpected response: %d\n", res);
+ }
+
+ if (noise_consume_initiation(&sc->sc_local, &remote,
+ init->s_idx, init->ue, init->es, init->ets) != 0) {
+ DPRINTF(sc, "Invalid handshake initiation\n");
+ goto error;
+ }
+
+ peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+
+ DPRINTF(sc, "Receiving handshake initiation from peer %llu\n",
+ peer->p_id);
+
+ wg_peer_counters_add(peer, 0, sizeof(*init));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ wg_send_response(peer);
+ break;
+ case WG_PKT_RESPONSE:
+ resp = mtod(m, struct wg_pkt_response *);
+
+ res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
+ resp, sizeof(*resp) - sizeof(resp->m),
+ underload, &t->t_endpoint.e_remote.r_sa);
+
+ if (res == EINVAL) {
+ DPRINTF(sc, "Invalid response MAC\n");
+ goto error;
+ } else if (res == ECONNREFUSED) {
+ DPRINTF(sc, "Handshake ratelimited\n");
+ goto error;
+ } else if (res == EAGAIN) {
+ wg_send_cookie(sc, &resp->m, resp->s_idx,
+ &t->t_endpoint);
+ goto error;
+ } else if (res != 0) {
+ panic("unexpected response: %d\n", res);
+ }
+
+ if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown handshake response\n");
+ goto error;
+ }
+
+ peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+
+ if (noise_consume_response(remote, resp->s_idx, resp->r_idx,
+ resp->ue, resp->en) != 0) {
+ DPRINTF(sc, "Invalid handshake response\n");
+ goto error;
+ }
+
+ DPRINTF(sc, "Receiving handshake response from peer %llu\n",
+ peer->p_id);
+
+ wg_peer_counters_add(peer, 0, sizeof(*resp));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ if (noise_remote_begin_session(&peer->p_remote) == 0) {
+ wg_timers_event_session_derived(&peer->p_timers);
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ }
+ break;
+ case WG_PKT_COOKIE:
+ cook = mtod(m, struct wg_pkt_cookie *);
+
+ if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown cookie index\n");
+ goto error;
+ }
+
+ peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+
+ if (cookie_maker_consume_payload(&peer->p_cookie,
+ cook->nonce, cook->ec) != 0) {
+ DPRINTF(sc, "Could not decrypt cookie response\n");
+ goto error;
+ }
+
+ DPRINTF(sc, "Receiving cookie response\n");
+ goto error;
+ default:
+ panic("invalid packet in handshake queue");
+ }
+
+ wg_timers_event_any_authenticated_packet_received(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+error:
+ m_freem(m);
+}
+
+void
+wg_handshake_worker(void *_sc)
+{
+ struct mbuf *m;
+ struct wg_softc *sc = _sc;
+ while ((m = mq_dequeue(&sc->sc_handshake_queue)) != NULL)
+ wg_handshake(sc, m);
+}
+
+/*
+ * The following functions handle encapsulation (encryption) and
+ * decapsulation (decryption). The wg_{en,de}cap functions will run in the
+ * wg_crypt_taskq, while wg_deliver_{in,out} must be serialised and will run
+ * in nettq.
+ *
+ * The packets are tracked in two queues, a serial queue and a parallel queue.
+ * - The parallel queue is used to distribute the encryption across multiple
+ * threads.
+ * - The serial queue ensures that packets are not reordered and are
+ * delivered in sequence.
+ * The wg_tag attached to the packet contains two flags to help the two queues
+ * interact.
+ * - t_done: The parallel queue has finished with the packet, now the serial
+ * queue can do its work.
+ * - t_mbuf: Used to store the *crypted packet. In the case of encryption,
+ * this is a newly allocated packet, and in the case of decryption,
+ * it is a pointer to the same packet, that has been decrypted and
+ * truncated. If t_mbuf is NULL, then *cryption failed and this
+ * packet should not be passed.
+ * wg_{en,de}cap work on the parallel queue, while wg_deliver_{in,out} work
+ * on the serial queue.
+ */
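+/*
+ * A rough sketch of the flag protocol described above, per packet:
+ *   enqueued:  t_done = 0, t_mbuf = NULL, mbuf on both queues
+ *   parallel:  wg_{en,de}cap set t_mbuf to the result (or NULL on
+ *              failure), then t_done = 1
+ *   serial:    wg_deliver_{in,out} only dequeue once t_done is set and
+ *              drop the packet when t_mbuf is NULL
+ */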
+void
+wg_encap(struct wg_softc *sc, struct mbuf *m)
+{
+ int res = 0;
+ struct wg_pkt_data *data;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ struct mbuf *mc;
+ size_t padding_len, plaintext_len, out_len;
+ uint64_t nonce;
+
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ plaintext_len = min(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu);
+ padding_len = plaintext_len - m->m_pkthdr.len;
+ out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN;
+
+ /*
+ * For the time being we allocate a new packet with sufficient size to
+	 * hold the encrypted data and headers. Avoiding this allocation would
+	 * be difficult, as p_encap_queue (mbuf_list) holds a reference to the mbuf.
+ * If we m_makespace or similar, we risk corrupting that list.
+ * Additionally, we only pass a buf and buf length to
+ * noise_remote_encrypt. Technically it would be possible to teach
+ * noise_remote_encrypt about mbufs, but we would need to sort out the
+ * p_encap_queue situation first.
+ */
+ if ((mc = m_clget(NULL, M_NOWAIT, out_len)) == NULL)
+ goto error;
+
+ data = mtod(mc, struct wg_pkt_data *);
+ m_copydata(m, 0, m->m_pkthdr.len, data->buf);
+ bzero(data->buf + m->m_pkthdr.len, padding_len);
+ data->t = WG_PKT_DATA;
+
+ /*
+ * Copy the flow hash from the inner packet to the outer packet, so
+	 * that fq_codel can properly separate streams, rather than falling
+ * back to random buckets.
+ */
+ mc->m_pkthdr.ph_flowid = m->m_pkthdr.ph_flowid;
+
+ res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce,
+ data->buf, plaintext_len);
+ nonce = htole64(nonce); /* Wire format is little endian. */
+ memcpy(data->nonce, &nonce, sizeof(data->nonce));
+
+ if (__predict_false(res == EINVAL)) {
+ m_freem(mc);
+ goto error;
+ } else if (__predict_false(res == ESTALE)) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else if (__predict_false(res != 0)) {
+ panic("unexpected result: %d\n", res);
+ }
+
+ /* A packet with length 0 is a keepalive packet */
+ if (__predict_false(m->m_pkthdr.len == 0))
+ DPRINTF(sc, "Sending keepalive packet to peer %llu\n",
+ peer->p_id);
+
+ mc->m_pkthdr.ph_loopcnt = m->m_pkthdr.ph_loopcnt;
+ mc->m_flags &= ~(M_MCAST | M_BCAST);
+ mc->m_len = out_len;
+ m_calchdrlen(mc);
+
+ /*
+ * We would count ifc_opackets, ifc_obytes of m here, except if_snd
+ * already does that for us, so no need to worry about it.
+ counters_pkt(sc->sc_if.if_counters, ifc_opackets, ifc_obytes,
+ m->m_pkthdr.len);
+ */
+ wg_peer_counters_add(peer, mc->m_pkthdr.len, 0);
+
+ t->t_mbuf = mc;
+error:
+ t->t_done = 1;
+ task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_out);
+}
+
+void
+wg_decap(struct wg_softc *sc, struct mbuf *m)
+{
+ int res, len;
+ struct ip *ip;
+ struct ip6_hdr *ip6;
+ struct wg_pkt_data *data;
+ struct wg_peer *peer, *allowed_peer;
+ struct wg_tag *t;
+ size_t payload_len;
+ uint64_t nonce;
+
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ /*
+	 * As in wg_encap, we pass a buf and buf length to
+ * noise_remote_decrypt. Again, possible to teach it about mbufs
+ * but need to get over the p_decap_queue situation first. However,
+ * we do not need to allocate a new mbuf as the decrypted packet is
+ * strictly smaller than encrypted. We just set t_mbuf to m and
+ * wg_deliver_in knows how to deal with that.
+ */
+ data = mtod(m, struct wg_pkt_data *);
+ payload_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data);
+ memcpy(&nonce, data->nonce, sizeof(nonce));
+ nonce = le64toh(nonce); /* Wire format is little endian. */
+ res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce,
+ data->buf, payload_len);
+
+ if (__predict_false(res == EINVAL)) {
+ goto error;
+ } else if (__predict_false(res == ECONNRESET)) {
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ } else if (__predict_false(res == ESTALE)) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else if (__predict_false(res != 0)) {
+ panic("unexpected response: %d\n", res);
+ }
+
+ wg_peer_set_endpoint_from_tag(peer, t);
+
+ wg_peer_counters_add(peer, 0, m->m_pkthdr.len);
+
+ m_adj(m, sizeof(struct wg_pkt_data));
+ m_adj(m, -NOISE_AUTHTAG_LEN);
+
+ counters_pkt(sc->sc_if.if_counters, ifc_ipackets, ifc_ibytes,
+ m->m_pkthdr.len);
+
+ /* A packet with length 0 is a keepalive packet */
+ if (__predict_false(m->m_pkthdr.len == 0)) {
+ DPRINTF(sc, "Receiving keepalive packet from peer "
+ "%llu\n", peer->p_id);
+ goto done;
+ }
+
+ /*
+ * We can let the network stack handle the intricate validation of the
+ * IP header, we just worry about the sizeof and the version, so we can
+ * read the source address in wg_aip_lookup.
+ *
+	 * We also need to trim the packet, as it was likely padded before
+ * encryption. While we could drop it here, it will be more helpful to
+ * pass it to bpf_mtap and use the counters that people are expecting
+ * in ipv4_input and ipv6_input. We can rely on ipv4_input and
+ * ipv6_input to properly validate the headers.
+ */
+ ip = mtod(m, struct ip *);
+ ip6 = mtod(m, struct ip6_hdr *);
+
+ if (m->m_pkthdr.len >= sizeof(struct ip) && ip->ip_v == IPVERSION) {
+ m->m_pkthdr.ph_family = AF_INET;
+
+ len = ntohs(ip->ip_len);
+ if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
+ m_adj(m, len - m->m_pkthdr.len);
+
+ allowed_peer = wg_aip_lookup(sc->sc_aip4, &ip->ip_src);
+#ifdef INET6
+ } else if (m->m_pkthdr.len >= sizeof(struct ip6_hdr) &&
+ (ip6->ip6_vfc & IPV6_VERSION_MASK) == IPV6_VERSION) {
+ m->m_pkthdr.ph_family = AF_INET6;
+
+ len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
+ if (len < m->m_pkthdr.len)
+ m_adj(m, len - m->m_pkthdr.len);
+
+ allowed_peer = wg_aip_lookup(sc->sc_aip6, &ip6->ip6_src);
+#endif
+ } else {
+ DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from "
+ "peer %llu\n", peer->p_id);
+ goto error;
+ }
+
+ if (__predict_false(peer != allowed_peer)) {
+ DPRINTF(sc, "Packet has unallowed src IP from peer "
+ "%llu\n", peer->p_id);
+ goto error;
+ }
+
+ /*
+	 * We can mark the incoming packet csum OK. We mark all flags OK
+	 * irrespective of the packet type.
+ */
+ m->m_pkthdr.csum_flags |= (M_IPV4_CSUM_IN_OK | M_TCP_CSUM_IN_OK |
+ M_UDP_CSUM_IN_OK | M_ICMP_CSUM_IN_OK);
+ m->m_pkthdr.csum_flags &= ~(M_IPV4_CSUM_IN_BAD | M_TCP_CSUM_IN_BAD |
+ M_UDP_CSUM_IN_BAD | M_ICMP_CSUM_IN_BAD);
+
+ m->m_pkthdr.ph_ifidx = sc->sc_if.if_index;
+ m->m_pkthdr.ph_rtableid = sc->sc_if.if_rdomain;
+ m->m_flags &= ~(M_MCAST | M_BCAST);
+ pf_pkt_addr_changed(m);
+
+done:
+ t->t_mbuf = m;
+error:
+ t->t_done = 1;
+ task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_in);
+}
+
+void
+wg_encap_worker(void *_sc)
+{
+ struct mbuf *m;
+ struct wg_softc *sc = _sc;
+ while ((m = wg_ring_dequeue(&sc->sc_encap_ring)) != NULL)
+ wg_encap(sc, m);
+}
+
+void
+wg_decap_worker(void *_sc)
+{
+ struct mbuf *m;
+ struct wg_softc *sc = _sc;
+ while ((m = wg_ring_dequeue(&sc->sc_decap_ring)) != NULL)
+ wg_decap(sc, m);
+}
+
+void
+wg_deliver_out(void *_peer)
+{
+ struct wg_peer *peer = _peer;
+ struct wg_softc *sc = peer->p_sc;
+ struct wg_endpoint endpoint;
+ struct wg_tag *t;
+ struct mbuf *m;
+ int ret;
+
+ wg_peer_get_endpoint(peer, &endpoint);
+
+ while ((m = wg_queue_dequeue(&peer->p_encap_queue, &t)) != NULL) {
+ /* t_mbuf will contain the encrypted packet */
+		if (t->t_mbuf == NULL) {
+ counters_inc(sc->sc_if.if_counters, ifc_oerrors);
+ m_freem(m);
+ continue;
+ }
+
+ ret = wg_send(sc, &endpoint, t->t_mbuf);
+
+ if (ret == 0) {
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(
+ &peer->p_timers);
+
+ if (m->m_pkthdr.len != 0)
+ wg_timers_event_data_sent(&peer->p_timers);
+ } else if (ret == EADDRNOTAVAIL) {
+ wg_peer_clear_src(peer);
+ wg_peer_get_endpoint(peer, &endpoint);
+ }
+
+ m_freem(m);
+ }
+}
+
+void
+wg_deliver_in(void *_peer)
+{
+ struct wg_peer *peer = _peer;
+ struct wg_softc *sc = peer->p_sc;
+ struct wg_tag *t;
+ struct mbuf *m;
+
+ while ((m = wg_queue_dequeue(&peer->p_decap_queue, &t)) != NULL) {
+ /* t_mbuf will contain the decrypted packet */
+ if (t->t_mbuf == NULL) {
+ counters_inc(sc->sc_if.if_counters, ifc_ierrors);
+ m_freem(m);
+ continue;
+ }
+
+ /* From here on m == t->t_mbuf */
+ KASSERT(m == t->t_mbuf);
+
+ wg_timers_event_any_authenticated_packet_received(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+
+ if (m->m_pkthdr.len == 0) {
+ m_freem(m);
+ continue;
+ }
+
+#if NBPFILTER > 0
+ if (sc->sc_if.if_bpf != NULL)
+ bpf_mtap_af(sc->sc_if.if_bpf,
+ m->m_pkthdr.ph_family, m, BPF_DIRECTION_IN);
+#endif
+
+ NET_LOCK();
+ if (m->m_pkthdr.ph_family == AF_INET)
+ ipv4_input(&sc->sc_if, m);
+#ifdef INET6
+ else if (m->m_pkthdr.ph_family == AF_INET6)
+ ipv6_input(&sc->sc_if, m);
+#endif
+ else
+ panic("invalid ph_family");
+ NET_UNLOCK();
+
+ wg_timers_event_data_received(&peer->p_timers);
+ }
+}
+
+int
+wg_queue_in(struct wg_softc *sc, struct wg_peer *peer, struct mbuf *m)
+{
+ struct wg_ring *parallel = &sc->sc_decap_ring;
+ struct wg_queue *serial = &peer->p_decap_queue;
+ struct wg_tag *t;
+
+ mtx_enter(&serial->q_mtx);
+ if (serial->q_list.ml_len < MAX_QUEUED_PKT) {
+ ml_enqueue(&serial->q_list, m);
+ mtx_leave(&serial->q_mtx);
+ } else {
+ mtx_leave(&serial->q_mtx);
+ m_freem(m);
+ return ENOBUFS;
+ }
+
+ mtx_enter(&parallel->r_mtx);
+ if (parallel->r_tail - parallel->r_head < MAX_QUEUED_PKT) {
+ parallel->r_buf[parallel->r_tail & MAX_QUEUED_PKT_MASK] = m;
+ parallel->r_tail++;
+ mtx_leave(&parallel->r_mtx);
+ } else {
+ mtx_leave(&parallel->r_mtx);
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ return ENOBUFS;
+ }
+
+ return 0;
+}
+
+void
+wg_queue_out(struct wg_softc *sc, struct wg_peer *peer)
+{
+ struct wg_ring *parallel = &sc->sc_encap_ring;
+ struct wg_queue *serial = &peer->p_encap_queue;
+ struct mbuf_list ml, ml_free;
+ struct mbuf *m;
+ struct wg_tag *t;
+ int dropped;
+
+ /*
+ * We delist all staged packets and then add them to the queues. This
+	 * can race with wg_start when called from wg_send_keepalive; wg_start
+	 * cannot race with another instance of itself, as it is serialised.
+ */
+ mq_delist(&peer->p_stage_queue, &ml);
+ ml_init(&ml_free);
+
+ while ((m = ml_dequeue(&ml)) != NULL) {
+ mtx_enter(&serial->q_mtx);
+ if (serial->q_list.ml_len < MAX_QUEUED_PKT) {
+ ml_enqueue(&serial->q_list, m);
+ mtx_leave(&serial->q_mtx);
+ } else {
+ mtx_leave(&serial->q_mtx);
+ ml_enqueue(&ml_free, m);
+ continue;
+ }
+
+ mtx_enter(&parallel->r_mtx);
+ if (parallel->r_tail - parallel->r_head < MAX_QUEUED_PKT) {
+ parallel->r_buf[parallel->r_tail & MAX_QUEUED_PKT_MASK] = m;
+ parallel->r_tail++;
+ mtx_leave(&parallel->r_mtx);
+ } else {
+ mtx_leave(&parallel->r_mtx);
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ }
+ }
+
+ if ((dropped = ml_purge(&ml_free)) > 0)
+ counters_add(sc->sc_if.if_counters, ifc_oqdrops, dropped);
+}
+
+struct mbuf *
+wg_ring_dequeue(struct wg_ring *r)
+{
+ struct mbuf *m = NULL;
+ mtx_enter(&r->r_mtx);
+ if (r->r_head != r->r_tail) {
+ m = r->r_buf[r->r_head & MAX_QUEUED_PKT_MASK];
+ r->r_head++;
+ }
+ mtx_leave(&r->r_mtx);
+ return m;
+}
+
+struct mbuf *
+wg_queue_dequeue(struct wg_queue *q, struct wg_tag **t)
+{
+ struct mbuf *m;
+ mtx_enter(&q->q_mtx);
+ if ((m = q->q_list.ml_head) != NULL && (*t = wg_tag_get(m))->t_done)
+ ml_dequeue(&q->q_list);
+ else
+ m = NULL;
+ mtx_leave(&q->q_mtx);
+ return m;
+}
+
+size_t
+wg_queue_len(struct wg_queue *q)
+{
+ size_t len;
+ mtx_enter(&q->q_mtx);
+ len = q->q_list.ml_len;
+ mtx_leave(&q->q_mtx);
+ return len;
+}
+
+struct noise_remote *
+wg_remote_get(void *_sc, uint8_t public[NOISE_PUBLIC_KEY_LEN])
+{
+ struct wg_peer *peer;
+ struct wg_softc *sc = _sc;
+ if ((peer = wg_peer_lookup(sc, public)) == NULL)
+ return NULL;
+ return &peer->p_remote;
+}
+
+uint32_t
+wg_index_set(void *_sc, struct noise_remote *remote)
+{
+ struct wg_peer *peer;
+ struct wg_softc *sc = _sc;
+ struct wg_index *index, *iter;
+ uint32_t key;
+
+ /*
+ * We can modify this without a lock as wg_index_set, wg_index_drop are
+ * guaranteed to be serialised (per remote).
+ */
+ peer = CONTAINER_OF(remote, struct wg_peer, p_remote);
+ index = SLIST_FIRST(&peer->p_unused_index);
+ KASSERT(index != NULL);
+ SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry);
+
+ index->i_value = remote;
+
+ mtx_enter(&sc->sc_index_mtx);
+assign_id:
+ key = index->i_key = arc4random();
+ key &= sc->sc_index_mask;
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == index->i_key)
+ goto assign_id;
+
+ LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry);
+
+ mtx_leave(&sc->sc_index_mtx);
+
+ /* Likewise, no need to lock for index here. */
+ return index->i_key;
+}
+
+struct noise_remote *
+wg_index_get(void *_sc, uint32_t key0)
+{
+ struct wg_softc *sc = _sc;
+ struct wg_index *iter;
+ struct noise_remote *remote = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ mtx_enter(&sc->sc_index_mtx);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ remote = iter->i_value;
+ break;
+ }
+ mtx_leave(&sc->sc_index_mtx);
+ return remote;
+}
+
+void
+wg_index_drop(void *_sc, uint32_t key0)
+{
+ struct wg_softc *sc = _sc;
+ struct wg_index *iter;
+ struct wg_peer *peer = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ mtx_enter(&sc->sc_index_mtx);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ LIST_REMOVE(iter, i_entry);
+ break;
+ }
+ mtx_leave(&sc->sc_index_mtx);
+
+ /* We expect a peer */
+ peer = CONTAINER_OF(iter->i_value, struct wg_peer, p_remote);
+ KASSERT(peer != NULL);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry);
+}
+
+struct mbuf *
+wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6,
+ void *_uh, int hlen)
+{
+ struct wg_pkt_data *data;
+ struct noise_remote *remote;
+ struct wg_tag *t;
+ struct wg_softc *sc = _sc;
+ struct udphdr *uh = _uh;
+
+ NET_ASSERT_LOCKED();
+
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ return NULL;
+ }
+
+ if (ip != NULL) {
+ t->t_endpoint.e_remote.r_sa.sa_len = sizeof(struct sockaddr_in);
+ t->t_endpoint.e_remote.r_sa.sa_family = AF_INET;
+ t->t_endpoint.e_remote.r_sin.sin_port = uh->uh_sport;
+ t->t_endpoint.e_remote.r_sin.sin_addr = ip->ip_src;
+ t->t_endpoint.e_local.l_in = ip->ip_dst;
+#ifdef INET6
+ } else if (ip6 != NULL) {
+ t->t_endpoint.e_remote.r_sa.sa_len = sizeof(struct sockaddr_in6);
+ t->t_endpoint.e_remote.r_sa.sa_family = AF_INET6;
+ t->t_endpoint.e_remote.r_sin6.sin6_port = uh->uh_sport;
+ t->t_endpoint.e_remote.r_sin6.sin6_addr = ip6->ip6_src;
+ t->t_endpoint.e_local.l_in6 = ip6->ip6_dst;
+#endif
+ } else {
+ m_freem(m);
+ return NULL;
+ }
+
+	/* m has an IP/IPv6 header of hlen length, we don't need it anymore. */
+ m_adj(m, hlen);
+
+ if (m_defrag(m, M_NOWAIT) != 0)
+ return NULL;
+
+ if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
+ *mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
+ (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
+ *mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
+ (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
+ *mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
+
+ if (mq_enqueue(&sc->sc_handshake_queue, m) != 0)
+ DPRINTF(sc, "Dropping handshake packet\n");
+ task_add(wg_handshake_taskq, &sc->sc_handshake);
+
+ } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
+ NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
+
+ data = mtod(m, struct wg_pkt_data *);
+
+ if ((remote = wg_index_get(sc, data->r_idx)) != NULL) {
+ t->t_peer = CONTAINER_OF(remote, struct wg_peer,
+ p_remote);
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+
+ if (wg_queue_in(sc, t->t_peer, m) != 0)
+ counters_inc(sc->sc_if.if_counters,
+ ifc_iqdrops);
+ task_add(wg_crypt_taskq, &sc->sc_decap);
+ } else {
+ counters_inc(sc->sc_if.if_counters, ifc_ierrors);
+ m_freem(m);
+ }
+ } else {
+ counters_inc(sc->sc_if.if_counters, ifc_ierrors);
+ m_freem(m);
+ }
+
+ return NULL;
+}
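
For reference, the length checks above match the fixed WireGuard wire-format
sizes (the wg_pkt_* structures are defined earlier in if_wg.c):

    handshake initiation (WG_PKT_INITIATION)  148 bytes
    handshake response   (WG_PKT_RESPONSE)     92 bytes
    cookie reply         (WG_PKT_COOKIE)       64 bytes
    transport data       (WG_PKT_DATA)        >= 32 bytes
                         (16-byte header plus 16-byte auth tag;
                          a keepalive is exactly this minimum)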
+
+void
+wg_start(struct ifnet *ifp)
+{
+ struct wg_softc *sc = ifp->if_softc;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ struct mbuf *m;
+ SLIST_HEAD(,wg_peer) start_list;
+
+ SLIST_INIT(&start_list);
+
+	/*
+	 * We should be OK to modify p_start_list, p_start_onlist in this
+	 * function as the interface is not IFXF_MPSAFE, so there should
+	 * only be one instance of this function running at a time. These
+	 * values are not modified anywhere else.
+	 */
+ while ((m = ifq_dequeue(&ifp->if_snd)) != NULL) {
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+ if (mq_push(&peer->p_stage_queue, m) != 0)
+ counters_inc(ifp->if_counters, ifc_oqdrops);
+ if (!peer->p_start_onlist) {
+ SLIST_INSERT_HEAD(&start_list, peer, p_start_list);
+ peer->p_start_onlist = 1;
+ }
+ }
+ SLIST_FOREACH(peer, &start_list, p_start_list) {
+ if (noise_remote_ready(&peer->p_remote) == 0)
+ wg_queue_out(sc, peer);
+ else
+ wg_timers_event_want_initiation(&peer->p_timers);
+ peer->p_start_onlist = 0;
+ }
+ task_add(wg_crypt_taskq, &sc->sc_encap);
+}
+
+int
+wg_output(struct ifnet *ifp, struct mbuf *m, struct sockaddr *sa,
+ struct rtentry *rt)
+{
+ struct wg_softc *sc = ifp->if_softc;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ int af, ret = EINVAL;
+
+ NET_ASSERT_LOCKED();
+
+ if ((t = wg_tag_get(m)) == NULL) {
+ ret = ENOBUFS;
+ goto error;
+ }
+
+ m->m_pkthdr.ph_family = sa->sa_family;
+ if (sa->sa_family == AF_INET) {
+ peer = wg_aip_lookup(sc->sc_aip4,
+ &mtod(m, struct ip *)->ip_dst);
+#ifdef INET6
+ } else if (sa->sa_family == AF_INET6) {
+ peer = wg_aip_lookup(sc->sc_aip6,
+ &mtod(m, struct ip6_hdr *)->ip6_dst);
+#endif
+ } else {
+ ret = EAFNOSUPPORT;
+ goto error;
+ }
+
+#if NBPFILTER > 0
+ if (sc->sc_if.if_bpf)
+ bpf_mtap_af(sc->sc_if.if_bpf, sa->sa_family, m,
+ BPF_DIRECTION_OUT);
+#endif
+
+ if (peer == NULL) {
+ ret = ENETUNREACH;
+ goto error;
+ }
+
+ af = peer->p_endpoint.e_remote.r_sa.sa_family;
+ if (af != AF_INET && af != AF_INET6) {
+ DPRINTF(sc, "No valid endpoint has been configured or "
+ "discovered for peer %llu\n", peer->p_id);
+ ret = EDESTADDRREQ;
+ goto error;
+ }
+
+ if (m->m_pkthdr.ph_loopcnt++ > M_MAXLOOP) {
+		DPRINTF(sc, "Packet looped\n");
+ ret = ELOOP;
+ goto error;
+ }
+
+	/*
+	 * As we hold a reference to the peer in the mbuf, we can't handle a
+	 * delayed packet without doing some refcnting. If a peer is removed
+	 * while a delayed packet holds a reference, bad things will happen.
+	 * For the time being, delayed packets are unsupported. This may be
+	 * fixed with another aip_lookup in wg_start, or with refcnting as
+	 * mentioned before.
+	 */
+ if (m->m_pkthdr.pf.delay > 0) {
+		DPRINTF(sc, "PF delay unsupported\n");
+ ret = EOPNOTSUPP;
+ goto error;
+ }
+
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = ifp->if_mtu;
+
+	/*
+	 * We still have an issue with ifq: it counts a packet that later
+	 * gets dropped in wg_start or fails encryption. Such packets are
+	 * also counted as ofails or oqdrops, so they get counted twice.
+	 */
+ return if_enqueue(ifp, m);
+error:
+ counters_inc(ifp->if_counters, ifc_oerrors);
+ m_freem(m);
+ return ret;
+}
+
+int
+wg_ioctl_set(struct wg_softc *sc, struct wg_data_io *data)
+{
+ struct wg_interface_io *iface_p, iface_o;
+ struct wg_peer_io *peer_p, peer_o;
+ struct wg_aip_io *aip_p, aip_o;
+
+ struct wg_peer *peer, *tpeer;
+ struct wg_aip *aip, *taip;
+
+ in_port_t port;
+ int rtable;
+
+ uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
+ size_t i, j;
+ int ret, has_identity;
+
+ if ((ret = suser(curproc)) != 0)
+ return ret;
+
+ rw_enter_write(&sc->sc_lock);
+
+ iface_p = data->wgd_interface;
+ if ((ret = copyin(iface_p, &iface_o, sizeof(iface_o))) != 0)
+ goto error;
+
+ if (iface_o.i_flags & WG_INTERFACE_REPLACE_PEERS)
+ TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer)
+ wg_peer_destroy(peer);
+
+ if (iface_o.i_flags & WG_INTERFACE_HAS_PRIVATE &&
+ (noise_local_keys(&sc->sc_local, NULL, private) ||
+ timingsafe_bcmp(private, iface_o.i_private, WG_KEY_SIZE))) {
+ if (curve25519_generate_public(public, iface_o.i_private)) {
+ if ((peer = wg_peer_lookup(sc, public)) != NULL)
+ wg_peer_destroy(peer);
+ }
+ noise_local_lock_identity(&sc->sc_local);
+ has_identity = noise_local_set_private(&sc->sc_local,
+ iface_o.i_private);
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ noise_remote_precompute(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
+ noise_remote_expire_current(&peer->p_remote);
+ }
+ cookie_checker_update(&sc->sc_cookie,
+ has_identity == 0 ? public : NULL);
+ noise_local_unlock_identity(&sc->sc_local);
+ }
+
+ if (iface_o.i_flags & WG_INTERFACE_HAS_PORT)
+ port = htons(iface_o.i_port);
+ else
+ port = sc->sc_udp_port;
+
+ if (iface_o.i_flags & WG_INTERFACE_HAS_RTABLE)
+ rtable = iface_o.i_rtable;
+ else
+ rtable = sc->sc_udp_rtable;
+
+ if (port != sc->sc_udp_port || rtable != sc->sc_udp_rtable) {
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry)
+ wg_peer_clear_src(peer);
+
+ if (sc->sc_if.if_flags & IFF_RUNNING)
+ if ((ret = wg_bind(sc, &port, &rtable)) != 0)
+ goto error;
+
+ sc->sc_udp_port = port;
+ sc->sc_udp_rtable = rtable;
+ }
+
+ peer_p = &iface_p->i_peers[0];
+ for (i = 0; i < iface_o.i_peers_count; i++) {
+ if ((ret = copyin(peer_p, &peer_o, sizeof(peer_o))) != 0)
+ goto error;
+
+ /* Peer must have public key */
+ if (!(peer_o.p_flags & WG_PEER_HAS_PUBLIC))
+ continue;
+
+ /* 0 = latest protocol, 1 = this protocol */
+ if (peer_o.p_protocol_version != 0) {
+ if (peer_o.p_protocol_version > 1) {
+ ret = EPFNOSUPPORT;
+ goto error;
+ }
+ }
+
+ /* Get local public and check that peer key doesn't match */
+ if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
+ bcmp(public, peer_o.p_public, WG_KEY_SIZE) == 0)
+ continue;
+
+ /* Lookup peer, or create if it doesn't exist */
+ if ((peer = wg_peer_lookup(sc, peer_o.p_public)) == NULL) {
+			/* If we want to delete, there is no need to create a
+			 * new one. Likewise, don't create a new one if we
+			 * only want to update. */
+ if (peer_o.p_flags & (WG_PEER_REMOVE|WG_PEER_UPDATE))
+ continue;
+
+ if ((peer = wg_peer_create(sc,
+ peer_o.p_public)) == NULL) {
+ ret = ENOMEM;
+ goto error;
+ }
+ }
+
+ /* Remove peer and continue if specified */
+ if (peer_o.p_flags & WG_PEER_REMOVE) {
+ wg_peer_destroy(peer);
+ continue;
+ }
+
+ if (peer_o.p_flags & WG_PEER_HAS_ENDPOINT)
+ wg_peer_set_sockaddr(peer, &peer_o.p_sa);
+
+ if (peer_o.p_flags & WG_PEER_HAS_PSK)
+ noise_remote_set_psk(&peer->p_remote, peer_o.p_psk);
+
+ if (peer_o.p_flags & WG_PEER_HAS_PKA)
+ wg_timers_set_persistent_keepalive(&peer->p_timers,
+ peer_o.p_pka);
+
+ if (peer_o.p_flags & WG_PEER_REPLACE_AIPS) {
+ LIST_FOREACH_SAFE(aip, &peer->p_aip, a_entry, taip) {
+ wg_aip_remove(sc, peer, &aip->a_data);
+ }
+ }
+
+ aip_p = &peer_p->p_aips[0];
+ for (j = 0; j < peer_o.p_aips_count; j++) {
+ if ((ret = copyin(aip_p, &aip_o, sizeof(aip_o))) != 0)
+ goto error;
+ ret = wg_aip_add(sc, peer, &aip_o);
+ if (ret != 0)
+ goto error;
+ aip_p++;
+ }
+
+ peer_p = (struct wg_peer_io *)aip_p;
+ }
+
+error:
+ rw_exit_write(&sc->sc_lock);
+ explicit_bzero(&iface_o, sizeof(iface_o));
+ explicit_bzero(&peer_o, sizeof(peer_o));
+ explicit_bzero(&aip_o, sizeof(aip_o));
+ explicit_bzero(public, sizeof(public));
+ explicit_bzero(private, sizeof(private));
+ return ret;
+}
+
+int
+wg_ioctl_get(struct wg_softc *sc, struct wg_data_io *data)
+{
+ struct wg_interface_io *iface_p, iface_o;
+ struct wg_peer_io *peer_p, peer_o;
+ struct wg_aip_io *aip_p;
+
+ struct wg_peer *peer;
+ struct wg_aip *aip;
+
+ size_t size, peer_count, aip_count;
+ int ret = 0, is_suser = suser(curproc) == 0;
+
+ size = sizeof(struct wg_interface_io);
+ if (data->wgd_size < size && !is_suser)
+ goto ret_size;
+
+ iface_p = data->wgd_interface;
+ bzero(&iface_o, sizeof(iface_o));
+
+ rw_enter_read(&sc->sc_lock);
+
+ if (sc->sc_udp_port != 0) {
+ iface_o.i_port = ntohs(sc->sc_udp_port);
+ iface_o.i_flags |= WG_INTERFACE_HAS_PORT;
+ }
+
+ if (sc->sc_udp_rtable != 0) {
+ iface_o.i_rtable = sc->sc_udp_rtable;
+ iface_o.i_flags |= WG_INTERFACE_HAS_RTABLE;
+ }
+
+ if (!is_suser)
+ goto copy_out_iface;
+
+ if (noise_local_keys(&sc->sc_local, iface_o.i_public,
+ iface_o.i_private) == 0) {
+ iface_o.i_flags |= WG_INTERFACE_HAS_PUBLIC;
+ iface_o.i_flags |= WG_INTERFACE_HAS_PRIVATE;
+ }
+
+ size += sizeof(struct wg_peer_io) * sc->sc_peer_num;
+ size += sizeof(struct wg_aip_io) * sc->sc_aip_num;
+ if (data->wgd_size < size)
+ goto unlock_and_ret_size;
+
+ peer_count = 0;
+ peer_p = &iface_p->i_peers[0];
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ bzero(&peer_o, sizeof(peer_o));
+ peer_o.p_flags = WG_PEER_HAS_PUBLIC;
+ peer_o.p_protocol_version = 1;
+
+ if (noise_remote_keys(&peer->p_remote, peer_o.p_public,
+ peer_o.p_psk) == 0)
+ peer_o.p_flags |= WG_PEER_HAS_PSK;
+
+ if (wg_timers_get_persistent_keepalive(&peer->p_timers,
+ &peer_o.p_pka) == 0)
+ peer_o.p_flags |= WG_PEER_HAS_PKA;
+
+ if (wg_peer_get_sockaddr(peer, &peer_o.p_sa) == 0)
+ peer_o.p_flags |= WG_PEER_HAS_ENDPOINT;
+
+ mtx_enter(&peer->p_counters_mtx);
+ peer_o.p_txbytes = peer->p_counters_tx;
+ peer_o.p_rxbytes = peer->p_counters_rx;
+ mtx_leave(&peer->p_counters_mtx);
+
+ wg_timers_get_last_handshake(&peer->p_timers,
+ &peer_o.p_last_handshake);
+
+ aip_count = 0;
+ aip_p = &peer_p->p_aips[0];
+ LIST_FOREACH(aip, &peer->p_aip, a_entry) {
+ if ((ret = copyout(&aip->a_data, aip_p, sizeof(*aip_p))) != 0)
+ goto unlock_and_ret_size;
+ aip_p++;
+ aip_count++;
+ }
+ peer_o.p_aips_count = aip_count;
+
+ if ((ret = copyout(&peer_o, peer_p, sizeof(peer_o))) != 0)
+ goto unlock_and_ret_size;
+
+ peer_p = (struct wg_peer_io *)aip_p;
+ peer_count++;
+ }
+ iface_o.i_peers_count = peer_count;
+
+copy_out_iface:
+ ret = copyout(&iface_o, iface_p, sizeof(iface_o));
+unlock_and_ret_size:
+ rw_exit_read(&sc->sc_lock);
+ explicit_bzero(&iface_o, sizeof(iface_o));
+ explicit_bzero(&peer_o, sizeof(peer_o));
+ret_size:
+ data->wgd_size = size;
+ return ret;
+}
+
+int
+wg_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data)
+{
+ struct ifreq *ifr = (struct ifreq *) data;
+ struct wg_softc *sc = ifp->if_softc;
+ int ret = 0;
+
+ switch (cmd) {
+ case SIOCSWG:
+ ret = wg_ioctl_set(sc, (struct wg_data_io *) data);
+ break;
+ case SIOCGWG:
+ ret = wg_ioctl_get(sc, (struct wg_data_io *) data);
+ break;
+ /* Interface IOCTLs */
+ case SIOCSIFADDR:
+ SET(ifp->if_flags, IFF_UP);
+ /* FALLTHROUGH */
+ case SIOCSIFFLAGS:
+ if (ISSET(ifp->if_flags, IFF_UP))
+ ret = wg_up(sc);
+ else
+ wg_down(sc);
+ break;
+ case SIOCSIFMTU:
+ /* Arbitrary limits */
+ if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > 9000)
+ ret = EINVAL;
+ else
+ ifp->if_mtu = ifr->ifr_mtu;
+ break;
+ case SIOCADDMULTI:
+ case SIOCDELMULTI:
+ break;
+ default:
+ ret = ENOTTY;
+ }
+
+ return ret;
+}
+
+int
+wg_up(struct wg_softc *sc)
+{
+ struct wg_peer *peer;
+ int ret = 0;
+
+ NET_ASSERT_LOCKED();
+	/*
+	 * We use IFF_RUNNING for exclusive access here. We also take an
+	 * exclusive sc_lock as wg_bind may write to sc_udp_port. We drop
+	 * the NET_LOCK as we want to call socreate, sobind, etc. Once solock
+	 * is no longer identical to NET_LOCK, we may be able to avoid this.
+	 */
+ if (!ISSET(sc->sc_if.if_flags, IFF_RUNNING)) {
+ SET(sc->sc_if.if_flags, IFF_RUNNING);
+ NET_UNLOCK();
+
+ rw_enter_write(&sc->sc_lock);
+ /*
+ * If we successfully bind the socket, then enable the timers
+ * for the peer. This will send all staged packets and a
+ * keepalive if necessary.
+ */
+ ret = wg_bind(sc, &sc->sc_udp_port, &sc->sc_udp_rtable);
+ if (ret == 0) {
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ wg_timers_enable(&peer->p_timers);
+ wg_queue_out(sc, peer);
+ }
+ }
+ rw_exit_write(&sc->sc_lock);
+
+ NET_LOCK();
+ if (ret != 0)
+ CLR(sc->sc_if.if_flags, IFF_RUNNING);
+ }
+ return ret;
+}
+
+void
+wg_down(struct wg_softc *sc)
+{
+ struct wg_peer *peer;
+
+ NET_ASSERT_LOCKED();
+ if (!ISSET(sc->sc_if.if_flags, IFF_RUNNING))
+ return;
+ CLR(sc->sc_if.if_flags, IFF_RUNNING);
+ NET_UNLOCK();
+
+ /*
+ * We only need a read lock here, as we aren't writing to anything
+ * that isn't granularly locked.
+ */
+ rw_enter_read(&sc->sc_lock);
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ mq_purge(&peer->p_stage_queue);
+ wg_timers_disable(&peer->p_timers);
+ }
+
+ taskq_barrier(wg_handshake_taskq);
+ TAILQ_FOREACH(peer, &sc->sc_peer_seq, p_seq_entry) {
+ noise_remote_clear(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
+ }
+
+ wg_unbind(sc);
+ rw_exit_read(&sc->sc_lock);
+ NET_LOCK();
+}
+
+int
+wg_clone_create(struct if_clone *ifc, int unit)
+{
+ struct ifnet *ifp;
+ struct wg_softc *sc;
+ struct noise_upcall local_upcall;
+
+ KERNEL_ASSERT_LOCKED();
+
+ if (wg_counter == 0) {
+ wg_handshake_taskq = taskq_create("wg_handshake",
+ 2, IPL_NET, TASKQ_MPSAFE);
+ wg_crypt_taskq = taskq_create("wg_crypt",
+ ncpus, IPL_NET, TASKQ_MPSAFE);
+
+ if (wg_handshake_taskq == NULL || wg_crypt_taskq == NULL) {
+ if (wg_handshake_taskq != NULL)
+ taskq_destroy(wg_handshake_taskq);
+ if (wg_crypt_taskq != NULL)
+ taskq_destroy(wg_crypt_taskq);
+ wg_handshake_taskq = NULL;
+ wg_crypt_taskq = NULL;
+ return ENOTRECOVERABLE;
+ }
+ }
+ wg_counter++;
+
+ if ((sc = malloc(sizeof(*sc), M_DEVBUF, M_NOWAIT | M_ZERO)) == NULL)
+ goto ret_00;
+
+ local_upcall.u_arg = sc;
+ local_upcall.u_remote_get = wg_remote_get;
+ local_upcall.u_index_set = wg_index_set;
+ local_upcall.u_index_drop = wg_index_drop;
+
+ TAILQ_INIT(&sc->sc_peer_seq);
+
+ /* sc_if is initialised after everything else */
+ arc4random_buf(&sc->sc_secret, sizeof(sc->sc_secret));
+
+ rw_init(&sc->sc_lock, "wg");
+ noise_local_init(&sc->sc_local, &local_upcall);
+ if (cookie_checker_init(&sc->sc_cookie, &wg_ratelimit_pool) != 0)
+ goto ret_01;
+ sc->sc_udp_port = 0;
+ sc->sc_udp_rtable = 0;
+
+ rw_init(&sc->sc_so_lock, "wg_so");
+ sc->sc_so4 = NULL;
+#ifdef INET6
+ sc->sc_so6 = NULL;
+#endif
+
+ sc->sc_aip_num = 0;
+ if ((sc->sc_aip4 = art_alloc(0, 32, 0)) == NULL)
+ goto ret_02;
+#ifdef INET6
+ if ((sc->sc_aip6 = art_alloc(0, 128, 0)) == NULL)
+ goto ret_03;
+#endif
+
+ rw_init(&sc->sc_peer_lock, "wg_peer");
+ sc->sc_peer_num = 0;
+ if ((sc->sc_peer = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF,
+ M_NOWAIT, &sc->sc_peer_mask)) == NULL)
+ goto ret_04;
+
+ mtx_init(&sc->sc_index_mtx, IPL_NET);
+ if ((sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF,
+ M_NOWAIT, &sc->sc_index_mask)) == NULL)
+ goto ret_05;
+
+ task_set(&sc->sc_handshake, wg_handshake_worker, sc);
+ mq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES, IPL_NET);
+
+ task_set(&sc->sc_encap, wg_encap_worker, sc);
+ task_set(&sc->sc_decap, wg_decap_worker, sc);
+
+ bzero(&sc->sc_encap_ring, sizeof(sc->sc_encap_ring));
+ mtx_init(&sc->sc_encap_ring.r_mtx, IPL_NET);
+ bzero(&sc->sc_decap_ring, sizeof(sc->sc_decap_ring));
+ mtx_init(&sc->sc_decap_ring.r_mtx, IPL_NET);
+
+	/* We've set up the softc; now we can set up the ifnet. */
+ ifp = &sc->sc_if;
+ ifp->if_softc = sc;
+
+ snprintf(ifp->if_xname, sizeof(ifp->if_xname), "wg%d", unit);
+
+ ifp->if_mtu = DEFAULT_MTU;
+ ifp->if_flags = IFF_BROADCAST | IFF_MULTICAST | IFF_NOARP;
+ ifp->if_xflags = IFXF_CLONED;
+
+ ifp->if_ioctl = wg_ioctl;
+ ifp->if_start = wg_start;
+ ifp->if_output = wg_output;
+
+ ifp->if_type = IFT_WIREGUARD;
+ IFQ_SET_MAXLEN(&ifp->if_snd, IFQ_MAXLEN);
+
+ if_attach(ifp);
+ if_alloc_sadl(ifp);
+ if_counters_alloc(ifp);
+
+#if NBPFILTER > 0
+ bpfattach(&ifp->if_bpf, ifp, DLT_LOOP, sizeof(uint32_t));
+#endif
+
+ DPRINTF(sc, "Interface created\n");
+
+ return 0;
+ret_05:
+ hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF);
+ret_04:
+#ifdef INET6
+ free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6));
+ret_03:
+#endif
+ free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4));
+ret_02:
+ cookie_checker_deinit(&sc->sc_cookie);
+ret_01:
+ free(sc, M_DEVBUF, sizeof(*sc));
+ret_00:
+ return ENOBUFS;
+}
+
+int
+wg_clone_destroy(struct ifnet *ifp)
+{
+ struct wg_softc *sc = ifp->if_softc;
+ struct wg_peer *peer, *tpeer;
+
+ KERNEL_ASSERT_LOCKED();
+
+ rw_enter_write(&sc->sc_lock);
+ TAILQ_FOREACH_SAFE(peer, &sc->sc_peer_seq, p_seq_entry, tpeer)
+ wg_peer_destroy(peer);
+ rw_exit_write(&sc->sc_lock);
+
+ wg_unbind(sc);
+ if_detach(ifp);
+
+ wg_counter--;
+ if (wg_counter == 0) {
+ KASSERT(wg_handshake_taskq != NULL && wg_crypt_taskq != NULL);
+ taskq_destroy(wg_handshake_taskq);
+ taskq_destroy(wg_crypt_taskq);
+ wg_handshake_taskq = NULL;
+ wg_crypt_taskq = NULL;
+ }
+
+ DPRINTF(sc, "Destroyed interface\n");
+
+ hashfree(sc->sc_index, HASHTABLE_INDEX_SIZE, M_DEVBUF);
+ hashfree(sc->sc_peer, HASHTABLE_PEER_SIZE, M_DEVBUF);
+#ifdef INET6
+ free(sc->sc_aip6, M_RTABLE, sizeof(*sc->sc_aip6));
+#endif
+ free(sc->sc_aip4, M_RTABLE, sizeof(*sc->sc_aip4));
+ cookie_checker_deinit(&sc->sc_cookie);
+ free(sc, M_DEVBUF, sizeof(*sc));
+ return 0;
+}
+
+void
+wgattach(int nwg)
+{
+#ifdef WGTEST
+ cookie_test();
+ noise_test();
+#endif
+ if_clone_attach(&wg_cloner);
+
+ pool_init(&wg_aip_pool, sizeof(struct wg_aip), 0,
+ IPL_NET, 0, "wgaip", NULL);
+ pool_init(&wg_peer_pool, sizeof(struct wg_peer), 0,
+ IPL_NET, 0, "wgpeer", NULL);
+ pool_init(&wg_ratelimit_pool, sizeof(struct ratelimit_entry), 0,
+ IPL_NET, 0, "wgratelimit", NULL);
+}
diff --git a/sys/net/if_wg.h b/sys/net/if_wg.h
new file mode 100644
index 00000000000..fcbd3e167d8
--- /dev/null
+++ b/sys/net/if_wg.h
@@ -0,0 +1,107 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#ifndef __IF_WG_H__
+#define __IF_WG_H__
+
+#include <sys/limits.h>
+#include <sys/errno.h>
+
+#include <net/if.h>
+#include <netinet/in.h>
+
+
+/*
+ * This is the public interface to the WireGuard network interface.
+ *
+ * It is designed to be used by tools such as ifconfig(8) and wg(8).
+ */
+
+#define WG_KEY_LEN 32
+
+/*
+ * These ioctls do not need the NET_LOCK as they use their own locks to
+ * serialise access.
+ */
+#define SIOCSWG _IOWR('i', 210, struct wg_data_io)
+#define SIOCGWG _IOWR('i', 211, struct wg_data_io)
+
+#define a_ipv4 a_addr.addr_ipv4
+#define a_ipv6 a_addr.addr_ipv6
+
+struct wg_aip_io {
+ sa_family_t a_af;
+ int a_cidr;
+ union wg_aip_addr {
+ struct in_addr addr_ipv4;
+ struct in6_addr addr_ipv6;
+ } a_addr;
+};
+
+#define WG_PEER_HAS_PUBLIC (1 << 0)
+#define WG_PEER_HAS_PSK (1 << 1)
+#define WG_PEER_HAS_PKA (1 << 2)
+#define WG_PEER_HAS_ENDPOINT (1 << 3)
+#define WG_PEER_REPLACE_AIPS (1 << 4)
+#define WG_PEER_REMOVE (1 << 5)
+#define WG_PEER_UPDATE (1 << 6)
+
+#define p_sa p_endpoint.sa_sa
+#define p_sin p_endpoint.sa_sin
+#define p_sin6 p_endpoint.sa_sin6
+
+struct wg_peer_io {
+ int p_flags;
+ int p_protocol_version;
+ uint8_t p_public[WG_KEY_LEN];
+ uint8_t p_psk[WG_KEY_LEN];
+ uint16_t p_pka;
+ union wg_peer_endpoint {
+ struct sockaddr sa_sa;
+ struct sockaddr_in sa_sin;
+ struct sockaddr_in6 sa_sin6;
+ } p_endpoint;
+ uint64_t p_txbytes;
+ uint64_t p_rxbytes;
+ struct timespec p_last_handshake; /* nanotime */
+ size_t p_aips_count;
+ struct wg_aip_io p_aips[];
+};
+
+#define WG_INTERFACE_HAS_PUBLIC (1 << 0)
+#define WG_INTERFACE_HAS_PRIVATE (1 << 1)
+#define WG_INTERFACE_HAS_PORT (1 << 2)
+#define WG_INTERFACE_HAS_RTABLE (1 << 3)
+#define WG_INTERFACE_REPLACE_PEERS (1 << 4)
+
+struct wg_interface_io {
+ uint8_t i_flags;
+ in_port_t i_port;
+ int i_rtable;
+ uint8_t i_public[WG_KEY_LEN];
+ uint8_t i_private[WG_KEY_LEN];
+ size_t i_peers_count;
+ struct wg_peer_io i_peers[];
+};
+
+struct wg_data_io {
+ char wgd_name[IFNAMSIZ];
+ size_t wgd_size; /* total size of the memory pointed to by wgd_interface */
+ struct wg_interface_io *wgd_interface;
+};
+
+#endif /* __IF_WG_H__ */
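
To illustrate how this interface is meant to be driven, here is a minimal
userland sketch that reads a device's state with SIOCGWG. It is an
illustration built only from the structures above (the interface name "wg0"
and the error handling are placeholders), not a copy of ifconfig(8) or any
other real consumer. The key point is the size negotiation: wg_ioctl_get
always writes the required buffer size back into wgd_size, so the caller
grows its buffer until the kernel stops asking for more.

#include <sys/types.h>
#include <sys/ioctl.h>
#include <sys/socket.h>

#include <net/if.h>
#include <net/if_wg.h>

#include <err.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

int
main(void)
{
	struct wg_data_io wgdata;
	struct wg_interface_io *iface;
	struct wg_peer_io *peer;
	struct wg_aip_io *aip;
	size_t i, j, last_size = 0;
	int s;

	if ((s = socket(AF_INET, SOCK_DGRAM, 0)) == -1)
		err(1, "socket");

	memset(&wgdata, 0, sizeof(wgdata));
	strlcpy(wgdata.wgd_name, "wg0", sizeof(wgdata.wgd_name));

	/* Grow the buffer until the kernel no longer asks for more. */
	for (;;) {
		if (ioctl(s, SIOCGWG, &wgdata) == -1)
			err(1, "SIOCGWG");
		if (last_size >= wgdata.wgd_size)
			break;
		last_size = wgdata.wgd_size;
		if ((wgdata.wgd_interface = realloc(wgdata.wgd_interface,
		    last_size)) == NULL)
			err(1, "realloc");
	}

	/*
	 * Peers follow the interface header and each peer is followed by
	 * its allowed IPs; both are flexible array members, laid out
	 * exactly as wg_ioctl_get writes them.
	 */
	iface = wgdata.wgd_interface;
	peer = &iface->i_peers[0];
	for (i = 0; i < iface->i_peers_count; i++) {
		printf("peer %zu: %zu allowed ips\n", i, peer->p_aips_count);
		aip = &peer->p_aips[0];
		for (j = 0; j < peer->p_aips_count; j++)
			aip++;
		peer = (struct wg_peer_io *)aip;
	}

	free(wgdata.wgd_interface);
	close(s);
	return 0;
}

Setting state with SIOCSWG uses the same variable-length layout in the other
direction: the caller builds the interface/peer/aip sequence in memory and
the kernel walks it exactly as wg_ioctl_set does.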
diff --git a/sys/net/wg_cookie.c b/sys/net/wg_cookie.c
new file mode 100644
index 00000000000..85c97a60aef
--- /dev/null
+++ b/sys/net/wg_cookie.c
@@ -0,0 +1,697 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/param.h>
+#include <sys/rwlock.h>
+#include <sys/malloc.h> /* Because systm doesn't include M_NOWAIT, M_DEVBUF */
+#include <sys/pool.h>
+#include <sys/socket.h>
+
+#include <crypto/chachapoly.h>
+
+#include <net/wg_cookie.h>
+
+static void cookie_precompute_key(uint8_t *,
+ const uint8_t[COOKIE_INPUT_SIZE], const char *);
+static void cookie_macs_mac1(struct cookie_macs *, const void *, size_t,
+ const uint8_t[COOKIE_KEY_SIZE]);
+static void cookie_macs_mac2(struct cookie_macs *, const void *, size_t,
+ const uint8_t[COOKIE_COOKIE_SIZE]);
+static int cookie_timer_expired(struct timespec *, time_t, long);
+static void cookie_checker_make_cookie(struct cookie_checker *,
+ uint8_t[COOKIE_COOKIE_SIZE], struct sockaddr *);
+static int ratelimit_init(struct ratelimit *, struct pool *pool);
+static void ratelimit_deinit(struct ratelimit *);
+static void ratelimit_gc(struct ratelimit *, int);
+static int ratelimit_allow(struct ratelimit *, struct sockaddr *);
+
+/* Public Functions */
+void
+cookie_maker_init(struct cookie_maker *cp, uint8_t key[COOKIE_INPUT_SIZE])
+{
+ bzero(cp, sizeof(*cp));
+ cookie_precompute_key(cp->cp_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+ cookie_precompute_key(cp->cp_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+ rw_init(&cp->cp_lock, "cookie_maker");
+}
+
+int
+cookie_checker_init(struct cookie_checker *cc, struct pool *pool)
+{
+ int res;
+ bzero(cc, sizeof(*cc));
+
+ rw_init(&cc->cc_key_lock, "cookie_checker_key");
+ rw_init(&cc->cc_secret_lock, "cookie_checker_secret");
+
+ if ((res = ratelimit_init(&cc->cc_ratelimit_v4, pool)) != 0)
+ return res;
+#ifdef INET6
+ if ((res = ratelimit_init(&cc->cc_ratelimit_v6, pool)) != 0) {
+ ratelimit_deinit(&cc->cc_ratelimit_v4);
+ return res;
+ }
+#endif
+ return 0;
+}
+
+void
+cookie_checker_update(struct cookie_checker *cc,
+ uint8_t key[COOKIE_INPUT_SIZE])
+{
+ rw_enter_write(&cc->cc_key_lock);
+ if (key) {
+ cookie_precompute_key(cc->cc_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+ cookie_precompute_key(cc->cc_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+ } else {
+ bzero(cc->cc_mac1_key, sizeof(cc->cc_mac1_key));
+ bzero(cc->cc_cookie_key, sizeof(cc->cc_cookie_key));
+ }
+ rw_exit_write(&cc->cc_key_lock);
+}
+
+void
+cookie_checker_deinit(struct cookie_checker *cc)
+{
+ ratelimit_deinit(&cc->cc_ratelimit_v4);
+#ifdef INET6
+ ratelimit_deinit(&cc->cc_ratelimit_v6);
+#endif
+}
+
+void
+cookie_checker_create_payload(struct cookie_checker *cc,
+ struct cookie_macs *cm, uint8_t nonce[COOKIE_NONCE_SIZE],
+ uint8_t ecookie[COOKIE_ENCRYPTED_SIZE], struct sockaddr *sa)
+{
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ cookie_checker_make_cookie(cc, cookie, sa);
+ arc4random_buf(nonce, COOKIE_NONCE_SIZE);
+
+ rw_enter_read(&cc->cc_key_lock);
+ xchacha20poly1305_encrypt(ecookie, cookie, COOKIE_COOKIE_SIZE,
+ cm->mac1, COOKIE_MAC_SIZE, nonce, cc->cc_cookie_key);
+ rw_exit_read(&cc->cc_key_lock);
+
+ explicit_bzero(cookie, sizeof(cookie));
+}
+
+int
+cookie_maker_consume_payload(struct cookie_maker *cp,
+ uint8_t nonce[COOKIE_NONCE_SIZE], uint8_t ecookie[COOKIE_ENCRYPTED_SIZE])
+{
+ int ret = 0;
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ rw_enter_write(&cp->cp_lock);
+
+ if (cp->cp_mac1_valid == 0) {
+ ret = ETIMEDOUT;
+ goto error;
+ }
+
+ if (xchacha20poly1305_decrypt(cookie, ecookie, COOKIE_ENCRYPTED_SIZE,
+ cp->cp_mac1_last, COOKIE_MAC_SIZE, nonce, cp->cp_cookie_key) == 0) {
+ ret = EINVAL;
+ goto error;
+ }
+
+ memcpy(cp->cp_cookie, cookie, COOKIE_COOKIE_SIZE);
+ getnanouptime(&cp->cp_birthdate);
+ cp->cp_mac1_valid = 0;
+
+error:
+ rw_exit_write(&cp->cp_lock);
+ return ret;
+}
+
+void
+cookie_maker_mac(struct cookie_maker *cp, struct cookie_macs *cm, void *buf,
+ size_t len)
+{
+ rw_enter_read(&cp->cp_lock);
+
+ cookie_macs_mac1(cm, buf, len, cp->cp_mac1_key);
+
+ memcpy(cp->cp_mac1_last, cm->mac1, COOKIE_MAC_SIZE);
+ cp->cp_mac1_valid = 1;
+
+ if (!cookie_timer_expired(&cp->cp_birthdate,
+ COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY, 0))
+ cookie_macs_mac2(cm, buf, len, cp->cp_cookie);
+ else
+ bzero(cm->mac2, COOKIE_MAC_SIZE);
+
+ rw_exit_read(&cp->cp_lock);
+}
+
+int
+cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *cm,
+ void *buf, size_t len, int busy, struct sockaddr *sa)
+{
+ struct cookie_macs our_cm;
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ /* Validate incoming MACs */
+ rw_enter_read(&cc->cc_key_lock);
+ cookie_macs_mac1(&our_cm, buf, len, cc->cc_mac1_key);
+ rw_exit_read(&cc->cc_key_lock);
+
+	/* If mac1 is invalid, we want to drop the packet */
+ if (timingsafe_bcmp(our_cm.mac1, cm->mac1, COOKIE_MAC_SIZE) != 0)
+ return EINVAL;
+
+ if (busy != 0) {
+ cookie_checker_make_cookie(cc, cookie, sa);
+ cookie_macs_mac2(&our_cm, buf, len, cookie);
+
+ /* If the mac2 is invalid, we want to send a cookie response */
+ if (timingsafe_bcmp(our_cm.mac2, cm->mac2, COOKIE_MAC_SIZE) != 0)
+ return EAGAIN;
+
+		/* If the mac2 is valid, we may want to rate limit the peer.
+		 * ratelimit_allow will return either 0, meaning there is no
+		 * rate limiting, or ECONNREFUSED, meaning we should rate
+		 * limit (refuse) the peer. */
+ if (sa->sa_family == AF_INET)
+ return ratelimit_allow(&cc->cc_ratelimit_v4, sa);
+#ifdef INET6
+ else if (sa->sa_family == AF_INET6)
+ return ratelimit_allow(&cc->cc_ratelimit_v6, sa);
+#endif
+ else
+ return EAFNOSUPPORT;
+ }
+ return 0;
+}
+
+/* Private functions */
+static void
+cookie_precompute_key(uint8_t *key, const uint8_t input[COOKIE_INPUT_SIZE],
+ const char *label)
+{
+ struct blake2s_state blake;
+
+ blake2s_init(&blake, COOKIE_KEY_SIZE);
+ blake2s_update(&blake, label, strlen(label));
+ blake2s_update(&blake, input, COOKIE_INPUT_SIZE);
+ blake2s_final(&blake, key);
+}
+
+static void
+cookie_macs_mac1(struct cookie_macs *cm, const void *buf, size_t len,
+ const uint8_t key[COOKIE_KEY_SIZE])
+{
+ struct blake2s_state state;
+ blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_KEY_SIZE);
+ blake2s_update(&state, buf, len);
+ blake2s_final(&state, cm->mac1);
+}
+
+static void
+cookie_macs_mac2(struct cookie_macs *cm, const void *buf, size_t len,
+ const uint8_t key[COOKIE_COOKIE_SIZE])
+{
+ struct blake2s_state state;
+ blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_COOKIE_SIZE);
+ blake2s_update(&state, buf, len);
+ blake2s_update(&state, cm->mac1, COOKIE_MAC_SIZE);
+ blake2s_final(&state, cm->mac2);
+}
+
+static int
+cookie_timer_expired(struct timespec *birthdate, time_t sec, long nsec)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec };
+
+ if (birthdate->tv_sec == 0 && birthdate->tv_nsec == 0)
+ return ETIMEDOUT;
+
+ getnanouptime(&uptime);
+ timespecadd(birthdate, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+static void
+cookie_checker_make_cookie(struct cookie_checker *cc,
+ uint8_t cookie[COOKIE_COOKIE_SIZE], struct sockaddr *sa)
+{
+ struct blake2s_state state;
+
+ rw_enter_write(&cc->cc_secret_lock);
+ if (cookie_timer_expired(&cc->cc_secret_birthdate,
+ COOKIE_SECRET_MAX_AGE, 0)) {
+ arc4random_buf(cc->cc_secret, COOKIE_SECRET_SIZE);
+ getnanouptime(&cc->cc_secret_birthdate);
+ }
+ blake2s_init_key(&state, COOKIE_COOKIE_SIZE, cc->cc_secret,
+ COOKIE_SECRET_SIZE);
+ rw_exit_write(&cc->cc_secret_lock);
+
+ if (sa->sa_family == AF_INET) {
+ blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_addr,
+ sizeof(struct in_addr));
+ blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_port,
+ sizeof(in_port_t));
+ blake2s_final(&state, cookie);
+#ifdef INET6
+ } else if (sa->sa_family == AF_INET6) {
+ blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_addr,
+ sizeof(struct in6_addr));
+ blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_port,
+ sizeof(in_port_t));
+ blake2s_final(&state, cookie);
+#endif
+ } else {
+ arc4random_buf(cookie, COOKIE_COOKIE_SIZE);
+ }
+}
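
Taken together, the helpers above implement the cookie construction used by
the handshake messages (summarised from the code, with || denoting
concatenation):

    mac1_key   = BLAKE2s(COOKIE_MAC1_KEY_LABEL   || peer public key)
    cookie_key = BLAKE2s(COOKIE_COOKIE_KEY_LABEL || peer public key)

    msg.mac1 = keyed-BLAKE2s(mac1_key, msg up to mac1)
    msg.mac2 = keyed-BLAKE2s(cookie,   msg up to mac1 || mac1)
    cookie   = keyed-BLAKE2s(secret,   source address || source port)

The secret is random and regenerated once it is older than
COOKIE_SECRET_MAX_AGE (120 s), so a valid mac2 only proves recent
reachability at the claimed source address. The cookie itself travels
encrypted under cookie_key with XChaCha20-Poly1305, using mac1 as the
associated data (cookie_checker_create_payload above).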
+
+static int
+ratelimit_init(struct ratelimit *rl, struct pool *pool)
+{
+ rw_init(&rl->rl_lock, "ratelimit_lock");
+ arc4random_buf(&rl->rl_secret, sizeof(rl->rl_secret));
+ rl->rl_table = hashinit(RATELIMIT_SIZE, M_DEVBUF, M_NOWAIT,
+ &rl->rl_table_mask);
+ rl->rl_pool = pool;
+ rl->rl_table_num = 0;
+ return rl->rl_table == NULL ? ENOBUFS : 0;
+}
+
+static void
+ratelimit_deinit(struct ratelimit *rl)
+{
+ rw_enter_write(&rl->rl_lock);
+ ratelimit_gc(rl, 1);
+ hashfree(rl->rl_table, RATELIMIT_SIZE, M_DEVBUF);
+ rw_exit_write(&rl->rl_lock);
+}
+
+static void
+ratelimit_gc(struct ratelimit *rl, int force)
+{
+ size_t i;
+ struct ratelimit_entry *r, *tr;
+ struct timespec expiry;
+
+ rw_assert_wrlock(&rl->rl_lock);
+
+ if (force) {
+ for (i = 0; i < RATELIMIT_SIZE; i++) {
+ LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+ rl->rl_table_num--;
+ LIST_REMOVE(r, r_entry);
+ pool_put(rl->rl_pool, r);
+ }
+ }
+ return;
+ }
+
+ if ((cookie_timer_expired(&rl->rl_last_gc, ELEMENT_TIMEOUT, 0) &&
+ rl->rl_table_num > 0)) {
+ getnanouptime(&rl->rl_last_gc);
+ getnanouptime(&expiry);
+ expiry.tv_sec -= ELEMENT_TIMEOUT;
+
+ for (i = 0; i < RATELIMIT_SIZE; i++) {
+ LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+ if (timespeccmp(&r->r_last_time, &expiry, <)) {
+ rl->rl_table_num--;
+ LIST_REMOVE(r, r_entry);
+ pool_put(rl->rl_pool, r);
+ }
+ }
+ }
+ }
+}
+
+static int
+ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa)
+{
+ uint64_t key, tokens;
+ struct timespec diff;
+ struct ratelimit_entry *r;
+ int ret = ECONNREFUSED;
+
+ if (sa->sa_family == AF_INET)
+ key = SipHash24(&rl->rl_secret, &satosin(sa)->sin_addr,
+ IPV4_MASK_SIZE);
+#ifdef INET6
+ else if (sa->sa_family == AF_INET6)
+ key = SipHash24(&rl->rl_secret, &satosin6(sa)->sin6_addr,
+ IPV6_MASK_SIZE);
+#endif
+ else
+ return ret;
+
+ rw_enter_write(&rl->rl_lock);
+
+ LIST_FOREACH(r, &rl->rl_table[key & rl->rl_table_mask], r_entry) {
+ if (r->r_af != sa->sa_family)
+ continue;
+
+ if (r->r_af == AF_INET && bcmp(&r->r_in,
+ &satosin(sa)->sin_addr, IPV4_MASK_SIZE) != 0)
+ continue;
+
+#ifdef INET6
+ if (r->r_af == AF_INET6 && bcmp(&r->r_in6,
+ &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE) != 0)
+ continue;
+#endif
+
+		/* If we get to here, we've found an entry for the endpoint.
+		 * We apply a standard token bucket, by calculating the time
+		 * elapsed since last_time, adding that, and capping the
+		 * tokens at TOKEN_MAX. If the endpoint has no tokens left
+		 * (that is, tokens < INITIATION_COST) then we block the
+		 * request, otherwise we subtract INITIATION_COST and
+		 * return OK. */
+ diff = r->r_last_time;
+ getnanouptime(&r->r_last_time);
+ timespecsub(&r->r_last_time, &diff, &diff);
+
+ tokens = r->r_tokens + diff.tv_sec * NSEC_PER_SEC + diff.tv_nsec;
+
+ if (tokens > TOKEN_MAX)
+ tokens = TOKEN_MAX;
+
+ if (tokens >= INITIATION_COST) {
+ r->r_tokens = tokens - INITIATION_COST;
+ goto ok;
+ } else {
+ r->r_tokens = tokens;
+ goto error;
+ }
+ }
+
+ /* If we get to here, we didn't have an entry for the endpoint. */
+ ratelimit_gc(rl, 0);
+
+ /* Hard limit on number of entries */
+ if (rl->rl_table_num >= RATELIMIT_SIZE_MAX)
+ goto error;
+
+ /* Goto error if out of memory */
+ if ((r = pool_get(rl->rl_pool, PR_NOWAIT)) == NULL)
+ goto error;
+
+ rl->rl_table_num++;
+
+ /* Insert entry into the hashtable and ensure it's initialised */
+ LIST_INSERT_HEAD(&rl->rl_table[key & rl->rl_table_mask], r, r_entry);
+ r->r_af = sa->sa_family;
+ if (r->r_af == AF_INET)
+ memcpy(&r->r_in, &satosin(sa)->sin_addr, IPV4_MASK_SIZE);
+#ifdef INET6
+ else if (r->r_af == AF_INET6)
+ memcpy(&r->r_in6, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE);
+#endif
+
+ getnanouptime(&r->r_last_time);
+ r->r_tokens = TOKEN_MAX - INITIATION_COST;
+ok:
+ ret = 0;
+error:
+ rw_exit_write(&rl->rl_lock);
+ return ret;
+}
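
Plugging the constants from wg_cookie.h into the token-bucket comment above
makes the arithmetic concrete:

    INITIATION_COST = NSEC_PER_SEC / INITIATIONS_PER_SECOND
                    = 1000000000 / 20 = 50000000 ns   (one initiation per 50 ms)
    TOKEN_MAX       = INITIATION_COST * INITIATIONS_BURSTABLE
                    = 50000000 * 5    = 250000000 ns  (a burst of five)

    tokens = min(r_tokens + nanoseconds since last attempt, TOKEN_MAX)
    allow if tokens >= INITIATION_COST, consuming INITIATION_COST on success

An idle endpoint therefore earns one initiation credit every 50 ms, capped
at five, which is exactly what the WGTEST expectations below encode.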
+
+#ifdef WGTEST
+
+#define MESSAGE_LEN 64
+#define T_FAILED_ITER(test) do { \
+ printf("%s %s: failed. iter: %d\n", __func__, test, i); \
+ goto cleanup; \
+} while (0)
+#define T_FAILED(test) do { \
+ printf("%s %s: failed.\n", __func__, test); \
+ goto cleanup; \
+} while (0)
+#define T_PASSED printf("%s: passed.\n", __func__)
+
+static const struct expected_results {
+ int result;
+ int sleep_time;
+} rl_expected[] = {
+ [0 ... INITIATIONS_BURSTABLE - 1] = { 0, 0 },
+ [INITIATIONS_BURSTABLE] = { ECONNREFUSED, 0 },
+ [INITIATIONS_BURSTABLE + 1] = { 0, NSEC_PER_SEC / INITIATIONS_PER_SECOND },
+ [INITIATIONS_BURSTABLE + 2] = { ECONNREFUSED, 0 },
+ [INITIATIONS_BURSTABLE + 3] = { 0, (NSEC_PER_SEC / INITIATIONS_PER_SECOND) * 2 },
+ [INITIATIONS_BURSTABLE + 4] = { 0, 0 },
+ [INITIATIONS_BURSTABLE + 5] = { ECONNREFUSED, 0 }
+};
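
Read against the arithmetic above: the first INITIATIONS_BURSTABLE (5)
attempts drain a full bucket and the sixth is refused; sleeping for one
INITIATION_COST earns exactly one token, so the next attempt passes and the
one after it is refused; sleeping for two INITIATION_COSTs earns two tokens,
allowing two consecutive attempts before refusal again.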
+
+static void
+cookie_ratelimit_timings_test()
+{
+ struct ratelimit rl;
+ struct pool rl_pool;
+ struct sockaddr_in sin;
+#ifdef INET6
+ struct sockaddr_in6 sin6;
+#endif
+ int i;
+
+ pool_init(&rl_pool, sizeof(struct ratelimit_entry), 0,
+ IPL_NONE, 0, "rl", NULL);
+ ratelimit_init(&rl, &rl_pool);
+
+ sin.sin_family = AF_INET;
+#ifdef INET6
+ sin6.sin6_family = AF_INET6;
+#endif
+
+ for (i = 0; i < sizeof(rl_expected)/sizeof(*rl_expected); i++) {
+ if (rl_expected[i].sleep_time != 0)
+ tsleep_nsec(&rl, PWAIT, "rl", rl_expected[i].sleep_time);
+
+ /* The first v4 ratelimit_allow is against a constant address,
+ * and should be indifferent to the port. */
+ sin.sin_addr.s_addr = 0x01020304;
+ sin.sin_port = arc4random();
+
+ if (ratelimit_allow(&rl, sintosa(&sin)) != rl_expected[i].result)
+ T_FAILED_ITER("malicious v4");
+
+ /* The second ratelimit_allow is to test that an arbitrary
+ * address is still allowed. */
+ sin.sin_addr.s_addr += i + 1;
+ sin.sin_port = arc4random();
+
+ if (ratelimit_allow(&rl, sintosa(&sin)) != 0)
+ T_FAILED_ITER("non-malicious v4");
+
+#ifdef INET6
+ /* The first v6 ratelimit_allow is against a constant address,
+ * and should be indifferent to the port. We also mutate the
+ * lower 64 bits of the address as we want to ensure ratelimit
+ * occurs against the higher 64 bits (/64 network). */
+ sin6.sin6_addr.s6_addr32[0] = 0x01020304;
+ sin6.sin6_addr.s6_addr32[1] = 0x05060708;
+ sin6.sin6_addr.s6_addr32[2] = i;
+ sin6.sin6_addr.s6_addr32[3] = i;
+ sin6.sin6_port = arc4random();
+
+ if (ratelimit_allow(&rl, sin6tosa(&sin6)) != rl_expected[i].result)
+ T_FAILED_ITER("malicious v6");
+
+ /* Again, test that an address different to above is still
+ * allowed. */
+ sin6.sin6_addr.s6_addr32[0] += i + 1;
+ sin6.sin6_port = arc4random();
+
+		if (ratelimit_allow(&rl, sin6tosa(&sin6)) != 0)
+ T_FAILED_ITER("non-malicious v6");
+#endif
+ }
+ T_PASSED;
+cleanup:
+ ratelimit_deinit(&rl);
+ pool_destroy(&rl_pool);
+}
+
+static void
+cookie_ratelimit_capacity_test()
+{
+ struct ratelimit rl;
+ struct pool rl_pool;
+ struct sockaddr_in sin;
+ int i;
+
+ pool_init(&rl_pool, sizeof(struct ratelimit_entry), 0,
+ IPL_NONE, 0, "rl", NULL);
+ ratelimit_init(&rl, &rl_pool);
+
+ sin.sin_family = AF_INET;
+ sin.sin_port = 1234;
+
+ /* Here we test that the ratelimiter has an upper bound on the number
+ * of addresses to be limited */
+ for (i = 0; i <= RATELIMIT_SIZE_MAX; i++) {
+ sin.sin_addr.s_addr = i;
+ if (i == RATELIMIT_SIZE_MAX) {
+ if (ratelimit_allow(&rl, sintosa(&sin)) != ECONNREFUSED)
+ T_FAILED_ITER("reject");
+ } else {
+ if (ratelimit_allow(&rl, sintosa(&sin)) != 0)
+ T_FAILED_ITER("allow");
+ }
+ }
+ T_PASSED;
+cleanup:
+ ratelimit_deinit(&rl);
+ pool_destroy(&rl_pool);
+}
+
+static void
+cookie_mac_test()
+{
+ struct pool rl_pool;
+ struct cookie_checker checker;
+ struct cookie_maker maker;
+ struct cookie_macs cm;
+ struct sockaddr_in sin;
+ int res, i;
+
+ uint8_t nonce[COOKIE_NONCE_SIZE];
+ uint8_t cookie[COOKIE_ENCRYPTED_SIZE];
+ uint8_t shared[COOKIE_INPUT_SIZE];
+ uint8_t message[MESSAGE_LEN];
+
+ arc4random_buf(shared, COOKIE_INPUT_SIZE);
+ arc4random_buf(message, MESSAGE_LEN);
+
+ /* Init cookie_maker. */
+ cookie_maker_init(&maker, shared);
+
+ /* Init cookie_checker. */
+ pool_init(&rl_pool, sizeof(struct ratelimit_entry), 0,
+ IPL_NONE, 0, "rl", NULL);
+
+ if (cookie_checker_init(&checker, &rl_pool) != 0)
+ T_FAILED("cookie_checker_allocate");
+ cookie_checker_update(&checker, shared);
+
+ /* Create dummy sockaddr */
+ sin.sin_family = AF_INET;
+ sin.sin_len = sizeof(sin);
+ sin.sin_addr.s_addr = 1;
+ sin.sin_port = 51820;
+
+ /* MAC message */
+ cookie_maker_mac(&maker, &cm, message, MESSAGE_LEN);
+
+ /* Check we have a null mac2 */
+ for (i = 0; i < sizeof(cm.mac2); i++)
+ if (cm.mac2[i] != 0)
+ T_FAILED("validate_macs_noload_mac2_zeroed");
+
+ /* Validate all bytes are checked in mac1 */
+ for (i = 0; i < sizeof(cm.mac1); i++) {
+ cm.mac1[i] = ~cm.mac1[i];
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 0, sintosa(&sin)) != EINVAL)
+ T_FAILED("validate_macs_noload_munge");
+ cm.mac1[i] = ~cm.mac1[i];
+ }
+
+ /* Check mac2 is zeroed */
+ res = 0;
+ for (i = 0; i < sizeof(cm.mac2); i++)
+ res |= cm.mac2[i];
+ if (res != 0)
+ T_FAILED("validate_macs_mac2_checkzero");
+
+ /* Check we can successfully validate the MAC */
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 0, sintosa(&sin)) != 0)
+ T_FAILED("validate_macs_noload_normal");
+
+ /* Check we get a EAGAIN if no mac2 and under load */
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 1, sintosa(&sin)) != EAGAIN)
+ T_FAILED("validate_macs_load_normal");
+
+ /* Simulate a cookie message */
+ cookie_checker_create_payload(&checker, &cm, nonce, cookie, sintosa(&sin));
+
+ /* Validate all bytes are checked in cookie */
+ for (i = 0; i < sizeof(cookie); i++) {
+ cookie[i] = ~cookie[i];
+ if (cookie_maker_consume_payload(&maker, nonce, cookie) != EINVAL)
+ T_FAILED("consume_payload_munge");
+ cookie[i] = ~cookie[i];
+ }
+
+ /* Check we can actually consume the payload */
+ if (cookie_maker_consume_payload(&maker, nonce, cookie) != 0)
+ T_FAILED("consume_payload_normal");
+
+ /* Check replay isn't allowed */
+ if (cookie_maker_consume_payload(&maker, nonce, cookie) != ETIMEDOUT)
+ T_FAILED("consume_payload_normal_replay");
+
+ /* MAC message again, with MAC2 */
+ cookie_maker_mac(&maker, &cm, message, MESSAGE_LEN);
+
+ /* Check we added a mac2 */
+ res = 0;
+ for (i = 0; i < sizeof(cm.mac2); i++)
+ res |= cm.mac2[i];
+ if (res == 0)
+ T_FAILED("validate_macs_make_mac2");
+
+ /* Check we get OK if mac2 and under load */
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 1, sintosa(&sin)) != 0)
+ T_FAILED("validate_macs_load_normal_mac2");
+
+ sin.sin_addr.s_addr = ~sin.sin_addr.s_addr;
+ /* Check we get EAGAIN if we munge the source IP */
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 1, sintosa(&sin)) != EAGAIN)
+ T_FAILED("validate_macs_load_spoofip_mac2");
+ sin.sin_addr.s_addr = ~sin.sin_addr.s_addr;
+
+	/* Check we again get OK with mac2 while under load (retry) */
+ if (cookie_checker_validate_macs(&checker, &cm, message,
+ MESSAGE_LEN, 1, sintosa(&sin)) != 0)
+ T_FAILED("validate_macs_load_normal_mac2_retry");
+
+ printf("cookie_mac: passed.\n");
+cleanup:
+ cookie_checker_deinit(&checker);
+ pool_destroy(&rl_pool);
+}
+
+void
+cookie_test()
+{
+ cookie_ratelimit_timings_test();
+ cookie_ratelimit_capacity_test();
+ cookie_mac_test();
+}
+
+#endif /* WGTEST */
diff --git a/sys/net/wg_cookie.h b/sys/net/wg_cookie.h
new file mode 100644
index 00000000000..4eaff71b623
--- /dev/null
+++ b/sys/net/wg_cookie.h
@@ -0,0 +1,131 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#ifndef __COOKIE_H__
+#define __COOKIE_H__
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/rwlock.h>
+#include <sys/queue.h>
+
+#include <netinet/in.h>
+
+#include <crypto/chachapoly.h>
+#include <crypto/blake2s.h>
+#include <crypto/siphash.h>
+
+#define COOKIE_MAC_SIZE 16
+#define COOKIE_KEY_SIZE 32
+#define COOKIE_NONCE_SIZE XCHACHA20POLY1305_NONCE_SIZE
+#define COOKIE_COOKIE_SIZE 16
+#define COOKIE_SECRET_SIZE 32
+#define COOKIE_INPUT_SIZE 32
+#define COOKIE_ENCRYPTED_SIZE (COOKIE_COOKIE_SIZE + COOKIE_MAC_SIZE)
+
+#define COOKIE_MAC1_KEY_LABEL "mac1----"
+#define COOKIE_COOKIE_KEY_LABEL "cookie--"
+#define COOKIE_SECRET_MAX_AGE 120
+#define COOKIE_SECRET_LATENCY 5
+
+/* Constants for initiation rate limiting */
+#define RATELIMIT_SIZE (1 << 13)
+#define RATELIMIT_SIZE_MAX (RATELIMIT_SIZE * 8)
+#define NSEC_PER_SEC 1000000000LL
+#define INITIATIONS_PER_SECOND 20
+#define INITIATIONS_BURSTABLE 5
+#define INITIATION_COST (NSEC_PER_SEC / INITIATIONS_PER_SECOND)
+#define TOKEN_MAX (INITIATION_COST * INITIATIONS_BURSTABLE)
+#define ELEMENT_TIMEOUT 1
+#define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */
+#define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */
+
+struct cookie_macs {
+ uint8_t mac1[COOKIE_MAC_SIZE];
+ uint8_t mac2[COOKIE_MAC_SIZE];
+};
+
+struct ratelimit_entry {
+ LIST_ENTRY(ratelimit_entry) r_entry;
+ sa_family_t r_af;
+ union {
+ struct in_addr r_in;
+#ifdef INET6
+ struct in6_addr r_in6;
+#endif
+ };
+ struct timespec r_last_time; /* nanouptime */
+ uint64_t r_tokens;
+};
+
+struct ratelimit {
+ SIPHASH_KEY rl_secret;
+ struct pool *rl_pool;
+
+ struct rwlock rl_lock;
+ LIST_HEAD(, ratelimit_entry) *rl_table;
+ u_long rl_table_mask;
+ size_t rl_table_num;
+ struct timespec rl_last_gc; /* nanouptime */
+};
+
+struct cookie_maker {
+ uint8_t cp_mac1_key[COOKIE_KEY_SIZE];
+ uint8_t cp_cookie_key[COOKIE_KEY_SIZE];
+
+ struct rwlock cp_lock;
+ uint8_t cp_cookie[COOKIE_COOKIE_SIZE];
+ struct timespec cp_birthdate; /* nanouptime */
+ int cp_mac1_valid;
+ uint8_t cp_mac1_last[COOKIE_MAC_SIZE];
+};
+
+struct cookie_checker {
+ struct ratelimit cc_ratelimit_v4;
+#ifdef INET6
+ struct ratelimit cc_ratelimit_v6;
+#endif
+
+ struct rwlock cc_key_lock;
+ uint8_t cc_mac1_key[COOKIE_KEY_SIZE];
+ uint8_t cc_cookie_key[COOKIE_KEY_SIZE];
+
+ struct rwlock cc_secret_lock;
+ struct timespec cc_secret_birthdate; /* nanouptime */
+ uint8_t cc_secret[COOKIE_SECRET_SIZE];
+};
+
+void cookie_maker_init(struct cookie_maker *, uint8_t[COOKIE_INPUT_SIZE]);
+int cookie_checker_init(struct cookie_checker *, struct pool *);
+void cookie_checker_update(struct cookie_checker *,
+ uint8_t[COOKIE_INPUT_SIZE]);
+void cookie_checker_deinit(struct cookie_checker *);
+void cookie_checker_create_payload(struct cookie_checker *,
+ struct cookie_macs *cm, uint8_t[COOKIE_NONCE_SIZE],
+ uint8_t [COOKIE_ENCRYPTED_SIZE], struct sockaddr *);
+int cookie_maker_consume_payload(struct cookie_maker *,
+ uint8_t[COOKIE_NONCE_SIZE], uint8_t[COOKIE_ENCRYPTED_SIZE]);
+void cookie_maker_mac(struct cookie_maker *, struct cookie_macs *,
+ void *, size_t);
+int cookie_checker_validate_macs(struct cookie_checker *,
+ struct cookie_macs *, void *, size_t, int, struct sockaddr *);
+
+#ifdef WGTEST
+void cookie_test(void);
+#endif /* WGTEST */
+
+#endif /* __COOKIE_H__ */
diff --git a/sys/net/wg_noise.c b/sys/net/wg_noise.c
new file mode 100644
index 00000000000..66bdecee80e
--- /dev/null
+++ b/sys/net/wg_noise.c
@@ -0,0 +1,1344 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/param.h>
+#include <sys/atomic.h>
+#include <sys/rwlock.h>
+
+#include <crypto/blake2s.h>
+#include <crypto/curve25519.h>
+#include <crypto/chachapoly.h>
+
+#include <net/wg_noise.h>
+
+/* Private functions */
+static struct noise_keypair *
+ noise_remote_keypair_allocate(struct noise_remote *);
+static void
+ noise_remote_keypair_free(struct noise_remote *,
+ struct noise_keypair *);
+static uint32_t noise_remote_handshake_index_get(struct noise_remote *);
+static void noise_remote_handshake_index_drop(struct noise_remote *);
+
+static uint64_t noise_counter_send(struct noise_counter *);
+static int noise_counter_recv(struct noise_counter *, uint64_t);
+
+static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
+ size_t, size_t, size_t, size_t,
+ const uint8_t [NOISE_HASH_LEN]);
+static int noise_mix_dh(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN]);
+static int noise_mix_ss(
+ uint8_t ck[NOISE_HASH_LEN],
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t ss[NOISE_PUBLIC_KEY_LEN]);
+static void noise_mix_hash(
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t *,
+ size_t);
+static void noise_mix_psk(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t [NOISE_SYMMETRIC_KEY_LEN]);
+static void noise_param_init(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN]);
+
+static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t,
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ uint8_t [NOISE_HASH_LEN]);
+static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t,
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ uint8_t [NOISE_HASH_LEN]);
+static void noise_msg_ephemeral(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t src[NOISE_PUBLIC_KEY_LEN]);
+
+static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]);
+static int noise_timer_expired(struct timespec *, time_t, long);
+
+/* Set/Get noise parameters */
+void
+noise_local_init(struct noise_local *l, struct noise_upcall *upcall)
+{
+ bzero(l, sizeof(*l));
+ rw_init(&l->l_identity_lock, "noise_local_identity");
+ l->l_upcall = *upcall;
+}
+
+void
+noise_local_lock_identity(struct noise_local *l)
+{
+ rw_enter_write(&l->l_identity_lock);
+}
+
+void
+noise_local_unlock_identity(struct noise_local *l)
+{
+ rw_exit_write(&l->l_identity_lock);
+}
+
+int
+noise_local_set_private(struct noise_local *l,
+ uint8_t private[NOISE_PUBLIC_KEY_LEN])
+{
+ rw_assert_wrlock(&l->l_identity_lock);
+
+ memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN);
+ curve25519_clamp_secret(l->l_private);
+ l->l_has_identity = curve25519_generate_public(l->l_public, private);
+
+ return l->l_has_identity ? 0 : ENXIO;
+}
+
+int
+noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
+ uint8_t private[NOISE_PUBLIC_KEY_LEN])
+{
+ int ret = 0;
+ rw_enter_read(&l->l_identity_lock);
+ if (l->l_has_identity) {
+ if (public != NULL)
+ memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN);
+ if (private != NULL)
+ memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN);
+ } else {
+ ret = ENXIO;
+ }
+ rw_exit_read(&l->l_identity_lock);
+ return ret;
+}
+
+void
+noise_remote_init(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
+ struct noise_local *l)
+{
+ bzero(r, sizeof(*r));
+ memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+ rw_init(&r->r_handshake_lock, "noise_handshake");
+ rw_init(&r->r_keypair_lock, "noise_keypair");
+
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[0], kp_entry);
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[1], kp_entry);
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[2], kp_entry);
+
+ KASSERT(l != NULL);
+ r->r_local = l;
+
+ rw_enter_write(&l->l_identity_lock);
+ noise_remote_precompute(r);
+ rw_exit_write(&l->l_identity_lock);
+}
+
+int
+noise_remote_set_psk(struct noise_remote *r,
+ uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ int same;
+ rw_enter_write(&r->r_handshake_lock);
+ same = !timingsafe_bcmp(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
+ if (!same) {
+ memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
+ }
+ rw_exit_write(&r->r_handshake_lock);
+ return same ? EEXIST : 0;
+}
+
+int
+noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
+ uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
+ int ret;
+
+ if (public != NULL)
+ memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN);
+
+ rw_enter_read(&r->r_handshake_lock);
+ if (psk != NULL)
+ memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
+ ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
+ rw_exit_read(&r->r_handshake_lock);
+
+ /* If r_psk != null_psk return 0, else ENOENT (no psk) */
+ return ret ? 0 : ENOENT;
+}
+
+void
+noise_remote_precompute(struct noise_remote *r)
+{
+ struct noise_local *l = r->r_local;
+ rw_assert_wrlock(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
+ else if (!curve25519(r->r_ss, l->l_private, r->r_public))
+ bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
+
+ rw_enter_write(&r->r_handshake_lock);
+ noise_remote_handshake_index_drop(r);
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+}
+
+/* Handshake functions */
+int
+noise_create_initiation(struct noise_remote *r, uint32_t *s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_local *l = r->r_local;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ rw_enter_write(&r->r_handshake_lock);
+ if (!l->l_has_identity)
+ goto error;
+ noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public);
+
+ /* e */
+ curve25519_generate_secret(hs->hs_e);
+ if (curve25519_generate_public(ue, hs->hs_e) == 0)
+ goto error;
+ noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
+
+ /* es */
+ if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0)
+ goto error;
+
+ /* s */
+ noise_msg_encrypt(es, l->l_public,
+ NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash);
+
+ /* ss */
+ if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0)
+ goto error;
+
+ /* {t} */
+ noise_tai64n_now(ets);
+ noise_msg_encrypt(ets, ets,
+ NOISE_TIMESTAMP_LEN, key, hs->hs_hash);
+
+ noise_remote_handshake_index_drop(r);
+ hs->hs_state = CREATED_INITIATION;
+ hs->hs_local_index = noise_remote_handshake_index_get(r);
+ *s_idx = hs->hs_local_index;
+ ret = 0;
+error:
+ rw_exit_write(&r->r_handshake_lock);
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ return ret;
+}
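
The single-letter comments in this function are the message tokens of the
Noise IKpsk2 pattern that WireGuard instantiates. For the initiation above,
each token updates the chaining key ck and the transcript hash roughly as
follows (AEAD is ChaCha20-Poly1305 keyed with key, authenticating the hash):

    e  : hash, ck mixed with the fresh ephemeral ue     (noise_msg_ephemeral)
    es : ck, key <- KDF2(ck, DH(e, remote static))      (noise_mix_dh)
    s  : es = AEAD(key, local static); mix hash         (noise_msg_encrypt)
    ss : ck, key <- KDF2(ck, precomputed r_ss)          (noise_mix_ss)
    {t}: ets = AEAD(key, TAI64N timestamp); mix hash    (noise_msg_encrypt)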
+
+int
+noise_consume_initiation(struct noise_local *l, struct noise_remote **rp,
+ uint32_t s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
+{
+ struct noise_remote *r;
+ struct noise_handshake hs;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
+ uint8_t timestamp[NOISE_TIMESTAMP_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ goto error;
+ noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public);
+
+ /* e */
+ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
+
+ /* es */
+ if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0)
+ goto error;
+
+ /* s */
+ if (noise_msg_decrypt(r_public, es,
+ NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ /* Lookup the remote we received from */
+ if ((r = l->l_upcall.u_remote_get(l->l_upcall.u_arg, r_public)) == NULL)
+ goto error;
+
+ /* ss */
+ if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0)
+ goto error;
+
+ /* {t} */
+ if (noise_msg_decrypt(timestamp, ets,
+ NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ hs.hs_state = CONSUMED_INITIATION;
+ hs.hs_local_index = 0;
+ hs.hs_remote_index = s_idx;
+ memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);
+
+	/* We have successfully computed the same results; now we ensure that
+	 * this is not an initiation replay or a flood attack */
+ rw_enter_write(&r->r_handshake_lock);
+
+ /* Replay */
+ if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
+ memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
+ else
+ goto error_set;
+ /* Flood attack */
+ if (noise_timer_expired(&r->r_last_init, 0, REJECT_INTERVAL))
+ getnanouptime(&r->r_last_init);
+ else
+ goto error_set;
+
+ /* Ok, we're happy to accept this initiation now */
+ noise_remote_handshake_index_drop(r);
+ r->r_handshake = hs;
+ *rp = r;
+ ret = 0;
+error_set:
+ rw_exit_write(&r->r_handshake_lock);
+error:
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(&hs, sizeof(hs));
+ return ret;
+}
+
+int
+noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t e[NOISE_PUBLIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&r->r_local->l_identity_lock);
+ rw_enter_write(&r->r_handshake_lock);
+
+ if (hs->hs_state != CONSUMED_INITIATION)
+ goto error;
+
+ /* e */
+ curve25519_generate_secret(e);
+ if (curve25519_generate_public(ue, e) == 0)
+ goto error;
+ noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
+
+ /* ee */
+ if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0)
+ goto error;
+
+ /* se */
+ if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0)
+ goto error;
+
+ /* psk */
+ noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk);
+
+ /* {} */
+ noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash);
+
+ hs->hs_state = CREATED_RESPONSE;
+ hs->hs_local_index = noise_remote_handshake_index_get(r);
+ *r_idx = hs->hs_remote_index;
+ *s_idx = hs->hs_local_index;
+ ret = 0;
+error:
+ rw_exit_write(&r->r_handshake_lock);
+ rw_exit_read(&r->r_local->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(e, NOISE_PUBLIC_KEY_LEN);
+ return ret;
+}
+
+int
+noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+ struct noise_local *l = r->r_local;
+ struct noise_handshake hs;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ goto error;
+
+ rw_enter_read(&r->r_handshake_lock);
+ hs = r->r_handshake;
+ memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
+ rw_exit_read(&r->r_handshake_lock);
+
+ if (hs.hs_state != CREATED_INITIATION ||
+ hs.hs_local_index != r_idx)
+ goto error;
+
+ /* e */
+ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
+
+ /* ee */
+ if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0)
+ goto error;
+
+ /* se */
+ if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0)
+ goto error;
+
+ /* psk */
+ noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key);
+
+ /* {} */
+ if (noise_msg_decrypt(NULL, en,
+ 0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ hs.hs_remote_index = s_idx;
+
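+ /* Commit the result under the write lock, but only if the handshake
+ * we snapshotted above is still the one in place; another thread may
+ * have cleared or replaced it while we were computing. */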
+ rw_enter_write(&r->r_handshake_lock);
+ if (r->r_handshake.hs_state == hs.hs_state &&
+ r->r_handshake.hs_local_index == hs.hs_local_index) {
+ r->r_handshake = hs;
+ r->r_handshake.hs_state = CONSUMED_RESPONSE;
+ ret = 0;
+ }
+ rw_exit_write(&r->r_handshake_lock);
+error:
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(&hs, sizeof(hs));
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
+ return ret;
+}
+
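+/* Promote a completed handshake to a live keypair. The initiator has
+ * already authenticated the response, so its keypair goes straight
+ * into r_current. The responder parks its keypair in r_next until the
+ * first data packet confirms it; noise_remote_decrypt then slides
+ * next into current and reports the confirmation with ECONNRESET. */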
+int
+noise_remote_begin_session(struct noise_remote *r)
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_keypair kp, *next, *current, *previous;
+
+ rw_enter_write(&r->r_handshake_lock);
+
+ /* We now derive the keypair from the handshake */
+ if (hs->hs_state == CONSUMED_RESPONSE) {
+ kp.kp_is_initiator = 1;
+ noise_kdf(kp.kp_send, kp.kp_recv, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ hs->hs_ck);
+ } else if (hs->hs_state == CREATED_RESPONSE) {
+ kp.kp_is_initiator = 0;
+ noise_kdf(kp.kp_recv, kp.kp_send, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ hs->hs_ck);
+ } else {
+ rw_exit_write(&r->r_handshake_lock);
+ return EINVAL;
+ }
+
+ kp.kp_valid = 1;
+ kp.kp_local_index = hs->hs_local_index;
+ kp.kp_remote_index = hs->hs_remote_index;
+ getnanouptime(&kp.kp_birthdate);
+ bzero(&kp.kp_ctr, sizeof(kp.kp_ctr));
+ rw_init(&kp.kp_ctr.c_lock, "noise_counter");
+
+ /* Now we need to add_new_keypair */
+ rw_enter_write(&r->r_keypair_lock);
+ next = r->r_next;
+ current = r->r_current;
+ previous = r->r_previous;
+
+ if (kp.kp_is_initiator) {
+ if (next != NULL) {
+ r->r_next = NULL;
+ r->r_previous = next;
+ noise_remote_keypair_free(r, current);
+ } else {
+ r->r_previous = current;
+ }
+
+ noise_remote_keypair_free(r, previous);
+
+ r->r_current = noise_remote_keypair_allocate(r);
+ *r->r_current = kp;
+ } else {
+ noise_remote_keypair_free(r, next);
+ r->r_previous = NULL;
+ noise_remote_keypair_free(r, previous);
+
+ r->r_next = noise_remote_keypair_allocate(r);
+ *r->r_next = kp;
+ }
+ rw_exit_write(&r->r_keypair_lock);
+
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+
+ explicit_bzero(&kp, sizeof(kp));
+ return 0;
+}
+
+void
+noise_remote_clear(struct noise_remote *r)
+{
+ rw_enter_write(&r->r_handshake_lock);
+ noise_remote_handshake_index_drop(r);
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+
+ rw_enter_write(&r->r_keypair_lock);
+ noise_remote_keypair_free(r, r->r_next);
+ noise_remote_keypair_free(r, r->r_current);
+ noise_remote_keypair_free(r, r->r_previous);
+ r->r_next = NULL;
+ r->r_current = NULL;
+ r->r_previous = NULL;
+ rw_exit_write(&r->r_keypair_lock);
+}
+
+void
+noise_remote_expire_current(struct noise_remote *r)
+{
+ rw_enter_write(&r->r_keypair_lock);
+ if (r->r_next != NULL)
+ r->r_next->kp_valid = 0;
+ if (r->r_current != NULL)
+ r->r_current->kp_valid = 0;
+ rw_exit_write(&r->r_keypair_lock);
+}
+
+int
+noise_remote_ready(struct noise_remote *r)
+{
+ struct noise_keypair *kp;
+ int ret;
+
+ rw_enter_read(&r->r_keypair_lock);
+ /* kp_ctr isn't locked here, we're happy to accept a racy read. */
+ if ((kp = r->r_current) == NULL ||
+ !kp->kp_valid ||
+ noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
+ kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES)
+ ret = EINVAL;
+ else
+ ret = 0;
+ rw_exit_read(&r->r_keypair_lock);
+ return ret;
+}
+
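+/* Data-plane entry points. Both return 0 on success and EINVAL when no
+ * keypair is usable or a tolerance has been exceeded. ESTALE means the
+ * operation itself succeeded, but the caller should arrange for a new
+ * handshake soon. noise_remote_decrypt additionally returns ECONNRESET
+ * for a successfully decrypted packet that confirmed a responder-side
+ * keypair (see noise_remote_begin_session). */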
+int
+noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce,
+ uint8_t *buf, size_t buflen)
+{
+ struct noise_keypair *kp;
+ int ret = EINVAL;
+
+ rw_enter_read(&r->r_keypair_lock);
+ if ((kp = r->r_current) == NULL)
+ goto error;
+
+ /* We confirm that our values are within our tolerances. We want:
+ * - a valid keypair
+ * - our keypair to be less than REJECT_AFTER_TIME seconds old
+ * - our receive counter to be less than REJECT_AFTER_MESSAGES
+ * - our send counter to be less than REJECT_AFTER_MESSAGES
+ *
+ * kp_ctr isn't locked here, we're happy to accept a racy read. */
+ if (!kp->kp_valid ||
+ noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
+ ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
+ goto error;
+
+ /* We encrypt into the same buffer, so the caller must ensure that buf
+ * has NOISE_AUTHTAG_LEN bytes of extra space to store the MAC. The
+ * nonce and index are passed back out through the provided pointers. */
+ *r_idx = kp->kp_remote_index;
+ chacha20poly1305_encrypt(buf, buf, buflen,
+ NULL, 0, *nonce, kp->kp_send);
+
+ /* If our values are still within tolerances, but we are approaching
+ * the tolerances, we notify the caller with ESTALE that they should
+ * establish a new keypair. The current keypair can continue to be used
+ * until the tolerances are hit. We notify if:
+ * - our send counter is valid and not less than REKEY_AFTER_MESSAGES
+ * - we're the initiator and our keypair is older than
+ * REKEY_AFTER_TIME seconds */
+ ret = ESTALE;
+ if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
+ (kp->kp_is_initiator &&
+ noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0)))
+ goto error;
+
+ ret = 0;
+error:
+ rw_exit_read(&r->r_keypair_lock);
+ return ret;
+}
+
+int
+noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce,
+ uint8_t *buf, size_t buflen)
+{
+ struct noise_keypair *kp;
+ int ret = EINVAL;
+
+ /* We retrieve the keypair corresponding to the provided index. We
+ * attempt the current keypair first as that is most likely. We also
+ * want to make sure that the keypair is valid as it would be
+ * catastrophic to decrypt against a zero'ed keypair. */
+ rw_enter_read(&r->r_keypair_lock);
+
+ if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) {
+ kp = r->r_current;
+ } else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) {
+ kp = r->r_previous;
+ } else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) {
+ kp = r->r_next;
+ } else {
+ goto error;
+ }
+
+ /* We confirm that our values are within our tolerances. These values
+ * are the same as the encrypt routine.
+ *
+ * kp_ctr isn't locked here, we're happy to accept a racy read. */
+ if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
+ goto error;
+
+ /* Decrypt, then validate the counter. We don't want to validate the
+ * counter before decrypting as we do not know the message is authentic
+ * prior to decryption. */
+ if (chacha20poly1305_decrypt(buf, buf, buflen,
+ NULL, 0, nonce, kp->kp_recv) == 0)
+ goto error;
+
+ if (noise_counter_recv(&kp->kp_ctr, nonce) != 0)
+ goto error;
+
+ /* If we've received the handshake-confirming data packet, then move
+ * the next keypair into current. If we do slide the next keypair in,
+ * we skip the REKEY_AFTER_TIME_RECV check, which is safe as a data
+ * packet can't confirm a session that we are the INITIATOR of. */
+ if (kp == r->r_next) {
+ rw_exit_read(&r->r_keypair_lock);
+ rw_enter_write(&r->r_keypair_lock);
+ if (kp == r->r_next && kp->kp_local_index == r_idx) {
+ noise_remote_keypair_free(r, r->r_previous);
+ r->r_previous = r->r_current;
+ r->r_current = r->r_next;
+ r->r_next = NULL;
+
+ ret = ECONNRESET;
+ goto error;
+ }
+ rw_enter(&r->r_keypair_lock, RW_DOWNGRADE);
+ }
+
+ /* Similar to when we encrypt, we want to notify the caller when we
+ * are approaching our tolerances. We notify if:
+ * - we're the initiator and the current keypair is older than
+ * REKEY_AFTER_TIME_RECV seconds. */
+ ret = ESTALE;
+ kp = r->r_current;
+ if (kp != NULL &&
+ kp->kp_valid &&
+ kp->kp_is_initiator &&
+ noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME_RECV, 0))
+ goto error;
+
+ ret = 0;
+error:
+ rw_exit(&r->r_keypair_lock);
+ return ret;
+}
+
+/* Private functions - these should not be called outside this file under any
+ * circumstances. */
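+/* r_unused_keypairs is seeded with the three r_keypair slots in
+ * noise_remote_init, and every allocation here is preceded by a free
+ * in begin_session, so the list should never be empty at this point. */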
+static struct noise_keypair *
+noise_remote_keypair_allocate(struct noise_remote *r)
+{
+ struct noise_keypair *kp;
+ kp = SLIST_FIRST(&r->r_unused_keypairs);
+ SLIST_REMOVE_HEAD(&r->r_unused_keypairs, kp_entry);
+ return kp;
+}
+
+static void
+noise_remote_keypair_free(struct noise_remote *r, struct noise_keypair *kp)
+{
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ if (kp != NULL) {
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, kp, kp_entry);
+ u->u_index_drop(u->u_arg, kp->kp_local_index);
+ bzero(kp->kp_send, sizeof(kp->kp_send));
+ bzero(kp->kp_recv, sizeof(kp->kp_recv));
+ }
+}
+
+static uint32_t
+noise_remote_handshake_index_get(struct noise_remote *r)
+{
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ return u->u_index_set(u->u_arg, r);
+}
+
+static void
+noise_remote_handshake_index_drop(struct noise_remote *r)
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ rw_assert_wrlock(&r->r_handshake_lock);
+ if (hs->hs_state != HS_ZEROED)
+ u->u_index_drop(u->u_arg, hs->hs_local_index);
+}
+
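+/* Allocate the next send nonce. On LP64 a 64-bit atomic increment is
+ * available and lock-free; 32-bit platforms fall back to the counter
+ * lock, as a 64-bit increment is not atomic there. */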
+static uint64_t
+noise_counter_send(struct noise_counter *ctr)
+{
+#ifdef __LP64__
+ return atomic_inc_long_nv((u_long *)&ctr->c_send) - 1;
+#else
+ uint64_t ret;
+ rw_enter_write(&ctr->c_lock);
+ ret = ctr->c_send++;
+ rw_exit_write(&ctr->c_lock);
+ return ret;
+#endif
+}
+
+static int
+noise_counter_recv(struct noise_counter *ctr, uint64_t recv)
+{
+ uint64_t i, top, index_recv, index_ctr;
+ unsigned long bit;
+ int ret = EEXIST;
+
+ rw_enter_write(&ctr->c_lock);
+
+ /* Check that the recv counter is valid */
+ if (ctr->c_recv >= REJECT_AFTER_MESSAGES ||
+ recv >= REJECT_AFTER_MESSAGES)
+ goto error;
+
+ /* If the packet is out of the window, invalid */
+ if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv)
+ goto error;
+
+ /* If the new counter is ahead of the current counter, we'll need to
+ * zero out the bitmap that has previously been used */
+ index_recv = recv / COUNTER_BITS;
+ index_ctr = ctr->c_recv / COUNTER_BITS;
+
+ if (recv > ctr->c_recv) {
+ top = MIN(index_recv - index_ctr, COUNTER_NUM);
+ for (i = 1; i <= top; i++)
+ ctr->c_backtrack[
+ (i + index_ctr) & (COUNTER_NUM - 1)] = 0;
+ ctr->c_recv = recv;
+ }
+
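+ /* Worked example on LP64 (COUNTER_BITS == 64, COUNTER_NUM == 128):
+ * with c_recv == 100 and recv == 300, index_ctr == 1 and
+ * index_recv == 4, so the words in slots 2, 3 and 4 are zeroed
+ * above and c_recv becomes 300. Below, bit 44 (300 % 64) of
+ * slot 4 is tested and then marked. */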
+ index_recv %= COUNTER_NUM;
+ bit = 1ul << (recv % COUNTER_BITS);
+
+ if (ctr->c_backtrack[index_recv] & bit)
+ goto error;
+
+ ctr->c_backtrack[index_recv] |= bit;
+
+ ret = 0;
+error:
+ rw_exit_write(&ctr->c_lock);
+ return ret;
+}
+
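+/* This is Noise's HKDF (RFC 5869 style) built on BLAKE2s-HMAC: extract
+ * a pseudorandom key "sec" from the input "x" keyed by the chaining
+ * key "ck", then expand up to three outputs, each tagged with a
+ * counter byte (0x1, 0x2, 0x3) and chained off the previous block. */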
+static void
+noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x,
+ size_t a_len, size_t b_len, size_t c_len, size_t x_len,
+ const uint8_t ck[NOISE_HASH_LEN])
+{
+ uint8_t out[BLAKE2S_HASH_SIZE + 1];
+ uint8_t sec[BLAKE2S_HASH_SIZE];
+
+#ifdef DIAGNOSTIC
+ KASSERT(a_len <= BLAKE2S_HASH_SIZE && b_len <= BLAKE2S_HASH_SIZE &&
+ c_len <= BLAKE2S_HASH_SIZE);
+ KASSERT(!(b || b_len || c || c_len) || (a && a_len));
+ KASSERT(!(c || c_len) || (b && b_len));
+#endif
+
+ /* Extract entropy from "x" into sec */
+ blake2s_hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN);
+
+ if (a == NULL || a_len == 0)
+ goto out;
+
+ /* Expand first key: key = sec, data = 0x1 */
+ out[0] = 1;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE);
+ memcpy(a, out, a_len);
+
+ if (b == NULL || b_len == 0)
+ goto out;
+
+ /* Expand second key: key = sec, data = "a" || 0x2 */
+ out[BLAKE2S_HASH_SIZE] = 2;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1,
+ BLAKE2S_HASH_SIZE);
+ memcpy(b, out, b_len);
+
+ if (c == NULL || c_len == 0)
+ goto out;
+
+ /* Expand third key: key = sec, data = "b" || 0x3 */
+ out[BLAKE2S_HASH_SIZE] = 3;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1,
+ BLAKE2S_HASH_SIZE);
+ memcpy(c, out, c_len);
+
+out:
+ /* Clear sensitive data from stack */
+ explicit_bzero(sec, BLAKE2S_HASH_SIZE);
+ explicit_bzero(out, BLAKE2S_HASH_SIZE + 1);
+}
+
+static int
+noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t private[NOISE_PUBLIC_KEY_LEN],
+ const uint8_t public[NOISE_PUBLIC_KEY_LEN])
+{
+ uint8_t dh[NOISE_PUBLIC_KEY_LEN];
+
+ if (!curve25519(dh, private, public))
+ return EINVAL;
+ noise_kdf(ck, key, NULL, dh,
+ NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
+ explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN);
+ return 0;
+}
+
+static int
+noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
+{
+ static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
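+ /* An all-zero ss means the precomputation failed (for example, a
+ * zero or low-order public key); refuse to mix in a value that an
+ * attacker could force to a known constant. */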
+ if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
+ return ENOENT;
+ noise_kdf(ck, key, NULL, ss,
+ NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
+ return 0;
+}
+
+static void
+noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src,
+ size_t src_len)
+{
+ struct blake2s_state blake;
+
+ blake2s_init(&blake, NOISE_HASH_LEN);
+ blake2s_update(&blake, hash, NOISE_HASH_LEN);
+ blake2s_update(&blake, src, src_len);
+ blake2s_final(&blake, hash);
+}
+
+static void
+noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ uint8_t tmp[NOISE_HASH_LEN];
+
+ noise_kdf(ck, tmp, key, psk,
+ NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, ck);
+ noise_mix_hash(hash, tmp, NOISE_HASH_LEN);
+ explicit_bzero(tmp, NOISE_HASH_LEN);
+}
+
+static void
+noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ const uint8_t s[NOISE_PUBLIC_KEY_LEN])
+{
+ struct blake2s_state blake;
+
+ blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL,
+ NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0);
+ blake2s_init(&blake, NOISE_HASH_LEN);
+ blake2s_update(&blake, ck, NOISE_HASH_LEN);
+ blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME,
+ strlen(NOISE_IDENTIFIER_NAME));
+ blake2s_final(&blake, hash);
+
+ noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN);
+}
+
+static void
+noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
+{
+ /* Nonce always zero for Noise_IK */
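+ /* (Each handshake key encrypts at most one message, so a fixed
+ * zero nonce never repeats under the same key.) */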
+ chacha20poly1305_encrypt(dst, src, src_len,
+ hash, NOISE_HASH_LEN, 0, key);
+ noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN);
+}
+
+static int
+noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
+{
+ /* Nonce always zero for Noise_IK */
+ if (!chacha20poly1305_decrypt(dst, src, src_len,
+ hash, NOISE_HASH_LEN, 0, key))
+ return EINVAL;
+ noise_mix_hash(hash, src, src_len);
+ return 0;
+}
+
+static void
+noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ const uint8_t src[NOISE_PUBLIC_KEY_LEN])
+{
+ noise_mix_hash(hash, src, NOISE_PUBLIC_KEY_LEN);
+ noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
+ NOISE_PUBLIC_KEY_LEN, ck);
+}
+
+static void
+noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN])
+{
+ struct timespec time;
+ uint64_t sec;
+ uint32_t nsec;
+
+ getnanotime(&time);
+
+ /* Round down the nsec counter to limit precise timing leak. */
+ time.tv_nsec &= REJECT_INTERVAL_MASK;
+
+ /* https://cr.yp.to/libtai/tai64.html */
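+ /* 0x4000000000000000 is the TAI64 label for the 1970 epoch; the
+ * extra 0xa covers the 10 s TAI-UTC offset, ignoring later leap
+ * seconds. */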
+ sec = htobe64(0x400000000000000aULL + time.tv_sec);
+ nsec = htobe32(time.tv_nsec);
+
+ /* memcpy to output buffer, assuming output could be unaligned. */
+ memcpy(output, &sec, sizeof(sec));
+ memcpy(output + sizeof(sec), &nsec, sizeof(nsec));
+}
+
+static int
+noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec };
+
+ /* We don't really worry about a zeroed birthdate, to avoid the extra
+ * check on every encrypt/decrypt. This does mean that the r_last_init
+ * check may fail if getnanouptime is < REJECT_INTERVAL from 0. */
+
+ getnanouptime(&uptime);
+ timespecadd(birthdate, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+#ifdef WGTEST
+
+#define MESSAGE_LEN 64
+#define LARGE_MESSAGE_LEN 1420
+
+#define T_LIM (COUNTER_WINDOW_SIZE + 1)
+#define T_INIT do { \
+ bzero(&ctr, sizeof(ctr)); \
+ rw_init(&ctr.c_lock, "counter"); \
+} while (0)
+#define T(num, v, e) do { \
+ if (noise_counter_recv(&ctr, v) != e) { \
+ printf("%s, test %d: failed.\n", __func__, num); \
+ return; \
+ } \
+} while (0)
+#define T_FAILED(test) do { \
+ printf("%s %s: failed\n", __func__, test); \
+ return; \
+} while (0)
+#define T_PASSED printf("%s: passed.\n", __func__)
+
+static struct noise_local al, bl;
+static struct noise_remote ar, br;
+
+static struct noise_initiation {
+ uint32_t s_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
+} init;
+
+static struct noise_response {
+ uint32_t s_idx;
+ uint32_t r_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t en[0 + NOISE_AUTHTAG_LEN];
+} resp;
+
+static uint64_t nonce;
+static uint32_t index;
+static uint8_t data[MESSAGE_LEN + NOISE_AUTHTAG_LEN];
+static uint8_t largedata[LARGE_MESSAGE_LEN + NOISE_AUTHTAG_LEN];
+
+static struct noise_remote *
+upcall_get(void *x0, uint8_t *x1) { return x0; }
+static uint32_t
+upcall_set(void *x0, struct noise_remote *x1) { return 5; }
+static void
+upcall_drop(void *x0, uint32_t x1) { }
+
+static void
+noise_counter_test()
+{
+ struct noise_counter ctr;
+ int i;
+
+ T_INIT;
+ /* T(test number, nonce, expected_response) */
+ T( 1, 0, 0);
+ T( 2, 1, 0);
+ T( 3, 1, EEXIST);
+ T( 4, 9, 0);
+ T( 5, 8, 0);
+ T( 6, 7, 0);
+ T( 7, 7, EEXIST);
+ T( 8, T_LIM, 0);
+ T( 9, T_LIM - 1, 0);
+ T(10, T_LIM - 1, EEXIST);
+ T(11, T_LIM - 2, 0);
+ T(12, 2, 0);
+ T(13, 2, EEXIST);
+ T(14, T_LIM + 16, 0);
+ T(15, 3, EEXIST);
+ T(16, T_LIM + 16, EEXIST);
+ T(17, T_LIM * 4, 0);
+ T(18, T_LIM * 4 - (T_LIM - 1), 0);
+ T(19, 10, EEXIST);
+ T(20, T_LIM * 4 - T_LIM, EEXIST);
+ T(21, T_LIM * 4 - (T_LIM + 1), EEXIST);
+ T(22, T_LIM * 4 - (T_LIM - 2), 0);
+ T(23, T_LIM * 4 + 1 - T_LIM, EEXIST);
+ T(24, 0, EEXIST);
+ T(25, REJECT_AFTER_MESSAGES, EEXIST);
+ T(26, REJECT_AFTER_MESSAGES - 1, 0);
+ T(27, REJECT_AFTER_MESSAGES, EEXIST);
+ T(28, REJECT_AFTER_MESSAGES - 1, EEXIST);
+ T(29, REJECT_AFTER_MESSAGES - 2, 0);
+ T(30, REJECT_AFTER_MESSAGES + 1, EEXIST);
+ T(31, REJECT_AFTER_MESSAGES + 2, EEXIST);
+ T(32, REJECT_AFTER_MESSAGES - 2, EEXIST);
+ T(33, REJECT_AFTER_MESSAGES - 3, 0);
+ T(34, 0, EEXIST);
+
+ T_INIT;
+ for (i = 1; i <= COUNTER_WINDOW_SIZE; ++i)
+ T(35, i, 0);
+ T(36, 0, 0);
+ T(37, 0, EEXIST);
+
+ T_INIT;
+ for (i = 2; i <= COUNTER_WINDOW_SIZE + 1; ++i)
+ T(38, i, 0);
+ T(39, 1, 0);
+ T(40, 0, EEXIST);
+
+ T_INIT;
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0;)
+ T(41, i, 0);
+
+ T_INIT;
+ for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1;)
+ T(42, i, 0);
+ T(43, 0, EEXIST);
+
+ T_INIT;
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;)
+ T(44, i, 0);
+ T(45, COUNTER_WINDOW_SIZE + 1, 0);
+ T(46, 0, EEXIST);
+
+ T_INIT;
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;)
+ T(47, i, 0);
+ T(48, 0, 0);
+ T(49, COUNTER_WINDOW_SIZE + 1, 0);
+
+ T_PASSED;
+}
+
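+/* Test topology: "a" is Alice, "b" is Bob. ar is Alice's remote for
+ * Bob (initiations are created on it) and br is Bob's remote for
+ * Alice; each local's upcall resolves straight to its one remote. */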
+static void
+noise_handshake_init(struct noise_local *al, struct noise_remote *ar,
+ struct noise_local *bl, struct noise_remote *br)
+{
+ uint8_t apriv[NOISE_PUBLIC_KEY_LEN], bpriv[NOISE_PUBLIC_KEY_LEN];
+ uint8_t apub[NOISE_PUBLIC_KEY_LEN], bpub[NOISE_PUBLIC_KEY_LEN];
+ uint8_t psk[NOISE_SYMMETRIC_KEY_LEN];
+
+ struct noise_upcall upcall = {
+ .u_arg = NULL,
+ .u_remote_get = upcall_get,
+ .u_index_set = upcall_set,
+ .u_index_drop = upcall_drop,
+ };
+
+ upcall.u_arg = ar;
+ noise_local_init(al, &upcall);
+ upcall.u_arg = br;
+ noise_local_init(bl, &upcall);
+
+ arc4random_buf(apriv, NOISE_PUBLIC_KEY_LEN);
+ arc4random_buf(bpriv, NOISE_PUBLIC_KEY_LEN);
+
+ noise_local_lock_identity(al);
+ noise_local_set_private(al, apriv);
+ noise_local_unlock_identity(al);
+
+ noise_local_lock_identity(bl);
+ noise_local_set_private(bl, bpriv);
+ noise_local_unlock_identity(bl);
+
+ noise_local_keys(al, apub, NULL);
+ noise_local_keys(bl, bpub, NULL);
+
+ noise_remote_init(ar, bpub, al);
+ noise_remote_init(br, apub, bl);
+
+ arc4random_buf(psk, NOISE_SYMMETRIC_KEY_LEN);
+ noise_remote_set_psk(ar, psk);
+ noise_remote_set_psk(br, psk);
+}
+
+static void
+noise_handshake_test()
+{
+ struct noise_remote *r;
+ int i;
+
+ noise_handshake_init(&al, &ar, &bl, &br);
+
+ /* Create initiation */
+ if (noise_create_initiation(&ar, &init.s_idx,
+ init.ue, init.es, init.ets) != 0)
+ T_FAILED("create_initiation");
+
+ /* Check encrypted (es) validation */
+ for (i = 0; i < sizeof(init.es); i++) {
+ init.es[i] = ~init.es[i];
+ if (noise_consume_initiation(&bl, &r, init.s_idx,
+ init.ue, init.es, init.ets) != EINVAL)
+ T_FAILED("consume_initiation_es");
+ init.es[i] = ~init.es[i];
+ }
+
+ /* Check encrypted (ets) validation */
+ for (i = 0; i < sizeof(init.ets); i++) {
+ init.ets[i] = ~init.ets[i];
+ if (noise_consume_initiation(&bl, &r, init.s_idx,
+ init.ue, init.es, init.ets) != EINVAL)
+ T_FAILED("consume_initiation_ets");
+ init.ets[i] = ~init.ets[i];
+ }
+
+ /* Consume initiation properly */
+ if (noise_consume_initiation(&bl, &r, init.s_idx,
+ init.ue, init.es, init.ets) != 0)
+ T_FAILED("consume_initiation");
+ if (r != &br)
+ T_FAILED("remote_lookup");
+
+ /* Replay initiation */
+ if (noise_consume_initiation(&bl, &r, init.s_idx,
+ init.ue, init.es, init.ets) != EINVAL)
+ T_FAILED("consume_initiation_replay");
+ if (r != &br)
+ T_FAILED("remote_lookup_r_unchanged");
+
+ /* Create response */
+ if (noise_create_response(&br, &resp.s_idx,
+ &resp.r_idx, resp.ue, resp.en) != 0)
+ T_FAILED("create_response");
+
+ /* Check encrypted (en) validation */
+ for (i = 0; i < sizeof(resp.en); i++) {
+ resp.en[i] = ~resp.en[i];
+ if (noise_consume_response(&ar, resp.s_idx,
+ resp.r_idx, resp.ue, resp.en) != EINVAL)
+ T_FAILED("consume_response_en");
+ resp.en[i] = ~resp.en[i];
+ }
+
+ /* Consume response properly */
+ if (noise_consume_response(&ar, resp.s_idx,
+ resp.r_idx, resp.ue, resp.en) != 0)
+ T_FAILED("consume_response");
+
+ /* Derive keys on both sides */
+ if (noise_remote_begin_session(&ar) != 0)
+ T_FAILED("promote_ar");
+ if (noise_remote_begin_session(&br) != 0)
+ T_FAILED("promote_br");
+
+ for (i = 0; i < MESSAGE_LEN; i++)
+ data[i] = i;
+
+ /* Since Bob is the responder, he must not encrypt until confirmed */
+ if (noise_remote_encrypt(&br, &index, &nonce,
+ data, MESSAGE_LEN) != EINVAL)
+ T_FAILED("encrypt_kci_wait");
+
+ /* Alice now encrypts and has Bob decrypt */
+ if (noise_remote_encrypt(&ar, &index, &nonce,
+ data, MESSAGE_LEN) != 0)
+ T_FAILED("encrypt_akp");
+ if (noise_remote_decrypt(&br, index, nonce,
+ data, MESSAGE_LEN + NOISE_AUTHTAG_LEN) != ECONNRESET)
+ T_FAILED("decrypt_bkp");
+
+ for (i = 0; i < MESSAGE_LEN; i++)
+ if (data[i] != i)
+ T_FAILED("decrypt_message_akp_bkp");
+
+ /* Now that Bob has received confirmation, he can encrypt */
+ if (noise_remote_encrypt(&br, &index, &nonce,
+ data, MESSAGE_LEN) != 0)
+ T_FAILED("encrypt_kci_ready");
+ if (noise_remote_decrypt(&ar, index, nonce,
+ data, MESSAGE_LEN + NOISE_AUTHTAG_LEN) != 0)
+ T_FAILED("decrypt_akp");
+
+ for (i = 0; i < MESSAGE_LEN; i++)
+ if (data[i] != i)
+ T_FAILED("decrypt_message_bkp_akp");
+
+ T_PASSED;
+}
+
+static void
+noise_speed_test()
+{
+#define SPEED_ITER (1<<16)
+ struct timespec start, end;
+ struct noise_remote *r;
+ int nsec, i;
+
+#define NSEC 1000000000
+#define T_TIME_START(iter, size) do { \
+ printf("%s %d %d byte encryptions\n", __func__, iter, size); \
+ nanouptime(&start); \
+} while (0)
+#define T_TIME_END(iter, size) do { \
+ nanouptime(&end); \
+ timespecsub(&end, &start, &end); \
+ nsec = (end.tv_sec * NSEC + end.tv_nsec) / iter; \
+ printf("%s %d nsec/iter, %d iter/sec, %d byte/sec\n", \
+ __func__, nsec, NSEC / nsec, NSEC / nsec * size); \
+} while (0)
+#define T_TIME_START_SINGLE(name) do { \
+ printf("%s %s\n", __func__, name); \
+ nanouptime(&start); \
+} while (0)
+#define T_TIME_END_SINGLE() do { \
+ nanouptime(&end); \
+ timespecsub(&end, &start, &end); \
+ nsec = (end.tv_sec * NSEC + end.tv_nsec); \
+ printf("%s %d nsec/iter, %d iter/sec\n", \
+ __func__, nsec, NSEC / nsec); \
+} while (0)
+
+ noise_handshake_init(&al, &ar, &bl, &br);
+
+ T_TIME_START_SINGLE("create_initiation");
+ if (noise_create_initiation(&ar, &init.s_idx,
+ init.ue, init.es, init.ets) != 0)
+ T_FAILED("create_initiation");
+ T_TIME_END_SINGLE();
+
+ T_TIME_START_SINGLE("consume_initiation");
+ if (noise_consume_initiation(&bl, &r, init.s_idx,
+ init.ue, init.es, init.ets) != 0)
+ T_FAILED("consume_initiation");
+ T_TIME_END_SINGLE();
+
+ T_TIME_START_SINGLE("create_response");
+ if (noise_create_response(&br, &resp.s_idx,
+ &resp.r_idx, resp.ue, resp.en) != 0)
+ T_FAILED("create_response");
+ T_TIME_END_SINGLE();
+
+ T_TIME_START_SINGLE("consume_response");
+ if (noise_consume_response(&ar, resp.s_idx,
+ resp.r_idx, resp.ue, resp.en) != 0)
+ T_FAILED("consume_response");
+ T_TIME_END_SINGLE();
+
+ /* Derive keys on both sides */
+ T_TIME_START_SINGLE("derive_keys");
+ if (noise_remote_begin_session(&ar) != 0)
+ T_FAILED("begin_ar");
+ T_TIME_END_SINGLE();
+ if (noise_remote_begin_session(&br) != 0)
+ T_FAILED("begin_br");
+
+ /* Small data encryptions */
+ T_TIME_START(SPEED_ITER, MESSAGE_LEN);
+ for (i = 0; i < SPEED_ITER; i++) {
+ if (noise_remote_encrypt(&ar, &index, &nonce,
+ data, MESSAGE_LEN) != 0)
+ T_FAILED("encrypt_akp");
+ }
+ T_TIME_END(SPEED_ITER, MESSAGE_LEN);
+
+ /* Large data encryptions */
+ T_TIME_START(SPEED_ITER, LARGE_MESSAGE_LEN);
+ for (i = 0; i < SPEED_ITER; i++) {
+ if (noise_remote_encrypt(&ar, &index, &nonce,
+ largedata, LARGE_MESSAGE_LEN) != 0)
+ T_FAILED("encrypt_akp");
+ }
+ T_TIME_END(SPEED_ITER, LARGE_MESSAGE_LEN);
+}
+
+void
+noise_test()
+{
+ noise_counter_test();
+ noise_handshake_test();
+ noise_speed_test();
+}
+
+#endif /* WGTEST */
diff --git a/sys/net/wg_noise.h b/sys/net/wg_noise.h
new file mode 100644
index 00000000000..60ba7f85406
--- /dev/null
+++ b/sys/net/wg_noise.h
@@ -0,0 +1,195 @@
+/*
+ * Copyright (C) 2015-2020 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2020 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ */
+
+#ifndef __NOISE_H__
+#define __NOISE_H__
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/rwlock.h>
+
+#include <crypto/blake2s.h>
+#include <crypto/chachapoly.h>
+#include <crypto/curve25519.h>
+
+#define NOISE_PUBLIC_KEY_LEN CURVE25519_KEY_SIZE
+#define NOISE_SYMMETRIC_KEY_LEN CHACHA20POLY1305_KEY_SIZE
+#define NOISE_TIMESTAMP_LEN (sizeof(uint64_t) + sizeof(uint32_t))
+#define NOISE_AUTHTAG_LEN CHACHA20POLY1305_AUTHTAG_SIZE
+#define NOISE_HASH_LEN BLAKE2S_HASH_SIZE
+
+/* Protocol string constants */
+#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
+#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com"
+
+/* Constants for the counter */
+#define COUNTER_BITS_TOTAL 8192
+#define COUNTER_BITS (sizeof(unsigned long) * 8)
+#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS)
+#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS)
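+/* The usable window is one word short of the bitmap: the word holding
+ * the newest counter is only partially valid. */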
+
+/* Constants for the keypair */
+#define REKEY_AFTER_MESSAGES (1ull << 60)
+#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1)
+#define REKEY_AFTER_TIME 120
+#define REKEY_AFTER_TIME_RECV 165
+#define REJECT_AFTER_TIME 180
+#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */
+/* 24 = floor(log2(REJECT_INTERVAL)) */
+#define REJECT_INTERVAL_MASK (~((1ull<<24)-1))
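+/* 2^24 ns is ~16.8 ms, just under REJECT_INTERVAL (20 ms), so a masked
+ * timestamp can't leak timing finer than the permitted initiation
+ * rate. */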
+
+enum noise_state_hs {
+ HS_ZEROED = 0,
+ CREATED_INITIATION,
+ CONSUMED_INITIATION,
+ CREATED_RESPONSE,
+ CONSUMED_RESPONSE,
+};
+
+struct noise_handshake {
+ enum noise_state_hs hs_state;
+ uint32_t hs_local_index;
+ uint32_t hs_remote_index;
+ uint8_t hs_e[NOISE_PUBLIC_KEY_LEN];
+ uint8_t hs_hash[NOISE_HASH_LEN];
+ uint8_t hs_ck[NOISE_HASH_LEN];
+};
+
+struct noise_counter {
+ struct rwlock c_lock;
+ uint64_t c_send;
+ uint64_t c_recv;
+ unsigned long c_backtrack[COUNTER_NUM];
+};
+
+struct noise_keypair {
+ SLIST_ENTRY(noise_keypair) kp_entry;
+ int kp_valid;
+ int kp_is_initiator;
+ uint32_t kp_local_index;
+ uint32_t kp_remote_index;
+ uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN];
+ struct timespec kp_birthdate; /* nanouptime */
+ struct noise_counter kp_ctr;
+};
+
+struct noise_remote {
+ uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
+ struct noise_local *r_local;
+ uint8_t r_ss[NOISE_PUBLIC_KEY_LEN];
+
+ struct rwlock r_handshake_lock;
+ struct noise_handshake r_handshake;
+ uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t r_timestamp[NOISE_TIMESTAMP_LEN];
+ struct timespec r_last_init; /* nanouptime */
+
+ struct rwlock r_keypair_lock;
+ SLIST_HEAD(,noise_keypair) r_unused_keypairs;
+ struct noise_keypair *r_next, *r_current, *r_previous;
+ struct noise_keypair r_keypair[3]; /* 3: next, current, previous. */
+};
+
+struct noise_local {
+ struct rwlock l_identity_lock;
+ int l_has_identity;
+ uint8_t l_public[NOISE_PUBLIC_KEY_LEN];
+ uint8_t l_private[NOISE_PUBLIC_KEY_LEN];
+
+ struct noise_upcall {
+ void *u_arg;
+ struct noise_remote *
+ (*u_remote_get)(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]);
+ uint32_t
+ (*u_index_set)(void *, struct noise_remote *);
+ void (*u_index_drop)(void *, uint32_t);
+ } l_upcall;
+};
+
+/* Set/Get noise parameters */
+void noise_local_init(struct noise_local *, struct noise_upcall *);
+void noise_local_lock_identity(struct noise_local *);
+void noise_local_unlock_identity(struct noise_local *);
+int noise_local_set_private(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN]);
+int noise_local_keys(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN],
+ uint8_t[NOISE_PUBLIC_KEY_LEN]);
+
+void noise_remote_init(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN],
+ struct noise_local *);
+int noise_remote_set_psk(struct noise_remote *, uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+int noise_remote_keys(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN],
+ uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+
+/* Should be called anytime noise_local_set_private is called */
+void noise_remote_precompute(struct noise_remote *);
+
+/* Cryptographic functions */
+int noise_create_initiation(
+ struct noise_remote *,
+ uint32_t *s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]);
+
+int noise_consume_initiation(
+ struct noise_local *,
+ struct noise_remote **,
+ uint32_t s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]);
+
+int noise_create_response(
+ struct noise_remote *,
+ uint32_t *s_idx,
+ uint32_t *r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t en[0 + NOISE_AUTHTAG_LEN]);
+
+int noise_consume_response(
+ struct noise_remote *,
+ uint32_t s_idx,
+ uint32_t r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t en[0 + NOISE_AUTHTAG_LEN]);
+
+int noise_remote_begin_session(struct noise_remote *);
+void noise_remote_clear(struct noise_remote *);
+void noise_remote_expire_current(struct noise_remote *);
+
+int noise_remote_ready(struct noise_remote *);
+
+int noise_remote_encrypt(
+ struct noise_remote *,
+ uint32_t *r_idx,
+ uint64_t *nonce,
+ uint8_t *buf,
+ size_t buflen);
+int noise_remote_decrypt(
+ struct noise_remote *,
+ uint32_t r_idx,
+ uint64_t nonce,
+ uint8_t *buf,
+ size_t buflen);
+
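+/* Sketch of the intended encrypt-side calling pattern; illustrative
+ * only, the real consumer is if_wg.c and differs in detail.
+ * transmit(), schedule_handshake() and drop() are hypothetical:
+ *
+ *	error = noise_remote_encrypt(r, &idx, &nonce, buf, len);
+ *	if (error == 0)
+ *		transmit(buf);			packet is valid
+ *	else if (error == ESTALE) {
+ *		transmit(buf);			still valid, but rekey soon
+ *		schedule_handshake(r);
+ *	} else
+ *		drop(buf);			EINVAL: nothing encrypted
+ */
+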
+#ifdef WGTEST
+void noise_test(void);
+#endif /* WGTEST */
+
+#endif /* __NOISE_H__ */