diff options
Diffstat (limited to 'src/if_wg.c')
-rw-r--r-- | src/if_wg.c | 1332 |
1 files changed, 649 insertions, 683 deletions
diff --git a/src/if_wg.c b/src/if_wg.c index 70f6b4b..e583f14 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -76,12 +76,8 @@ __FBSDID("$FreeBSD$"); #include "version.h" #include "if_wg.h" -/* It'd be nice to use IF_MAXMTU, but that means more complicated mbuf allocations, - * so instead just do the biggest mbuf we can easily allocate minus the usual maximum - * IPv6 overhead of 80 bytes. If somebody wants bigger frames, we can revisit this. */ -#define MAX_MTU (MJUM16BYTES - 80) - -#define DEFAULT_MTU 1420 +#define DEFAULT_MTU 1420 +#define MAX_MTU (IF_MAXMTU - 80) #define MAX_STAGED_PKT 128 #define MAX_QUEUED_PKT 1024 @@ -102,7 +98,7 @@ __FBSDID("$FreeBSD$"); #define WG_PKT_COOKIE htole32(3) #define WG_PKT_DATA htole32(4) -#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1))) +#define WG_PKT_ALIGNMENT 16 #define WG_KEY_SIZE 32 struct wg_pkt_initiation { @@ -133,7 +129,7 @@ struct wg_pkt_cookie { struct wg_pkt_data { uint32_t t; uint32_t r_idx; - uint8_t nonce[sizeof(uint64_t)]; + uint64_t nonce; uint8_t buf[]; }; @@ -154,16 +150,6 @@ struct wg_endpoint { } e_local; }; -struct wg_tag { - struct m_tag t_tag; - struct wg_endpoint t_endpoint; - struct noise_keypair *t_keypair; - uint64_t t_nonce; - struct mbuf *t_mbuf; - int t_done; - int t_mtu; -}; - struct wg_timers { /* t_lock is for blocking wg_timers_event_* when setting t_disabled. 
*/ struct rwlock t_lock; @@ -190,9 +176,28 @@ struct wg_aip { struct wg_peer *r_peer; }; +struct wg_packet { + STAILQ_ENTRY(wg_packet) p_serial; + STAILQ_ENTRY(wg_packet) p_parallel; + struct wg_endpoint p_endpoint; + struct noise_keypair *p_keypair; + uint64_t p_nonce; + struct mbuf *p_mbuf; + int p_mtu; + sa_family_t p_af; + enum wg_ring_state { + WG_PACKET_UNCRYPTED, + WG_PACKET_CRYPTED, + WG_PACKET_DEAD, + } p_state; +}; + +STAILQ_HEAD(wg_packet_list ,wg_packet); + struct wg_queue { - struct mtx q_mtx; - struct mbufq q; + struct mtx q_mtx; + struct wg_packet_list q_queue; + size_t q_len; }; struct wg_peer { @@ -208,8 +213,8 @@ struct wg_peer { struct wg_endpoint p_endpoint; struct wg_queue p_stage_queue; - struct wg_queue p_encap_queue; - struct wg_queue p_decap_queue; + struct wg_queue p_encrypt_serial; + struct wg_queue p_decrypt_serial; struct grouptask p_send_initiation; struct grouptask p_send_keepalive; @@ -265,17 +270,16 @@ struct wg_softc { TAILQ_HEAD(,wg_peer) sc_peers; size_t sc_peers_num; - struct mbufq sc_handshake_queue; - struct grouptask sc_handshake; - struct noise_local *sc_local; struct cookie_checker sc_cookie; - struct buf_ring *sc_encap_ring; - struct buf_ring *sc_decap_ring; + struct grouptask sc_handshake; + struct wg_queue sc_handshake_queue; struct grouptask *sc_encrypt; struct grouptask *sc_decrypt; + struct wg_queue sc_encrypt_parallel; + struct wg_queue sc_decrypt_parallel; struct sx sc_lock; }; @@ -292,10 +296,8 @@ struct wg_softc { #define GROUPTASK_DRAIN(gtask) \ gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task) -#define MTAG_WIREGUARD 0xBEAD -#define M_ENQUEUED M_PROTO1 - static int clone_count; +static uma_zone_t wg_packet_zone; static uma_zone_t ratelimit_zone; static volatile unsigned long peer_counter = 0; static const char wgname[] = "wg"; @@ -314,15 +316,12 @@ VNET_DEFINE_STATIC(struct if_clone *, wg_cloner); #define V_wg_cloner VNET(wg_cloner) #define WG_CAPS IFCAP_LINKSTATE -#define ph_family PH_loc.eight[5] 
struct wg_timespec64 { uint64_t tv_sec; uint64_t tv_nsec; }; -static struct wg_tag *wg_tag_get(struct mbuf *); -static struct wg_endpoint *wg_mbuf_endpoint_get(struct mbuf *); static int wg_socket_init(struct wg_softc *, in_port_t); static int wg_socket_bind(struct socket *, struct socket *, in_port_t *); static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *); @@ -351,14 +350,6 @@ static void wg_timers_enable(struct wg_timers *); static void wg_timers_disable(struct wg_timers *); static void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t); static void wg_timers_get_last_handshake(struct wg_timers *, struct wg_timespec64 *); -static void wg_queue_init(struct wg_queue *, const char *); -static void wg_queue_deinit(struct wg_queue *); -static void wg_queue_purge(struct wg_queue *); -static struct mbuf *wg_queue_dequeue(struct wg_queue *, struct wg_tag **); -static int wg_queue_len(struct wg_queue *); -static int wg_queue_in(struct wg_peer *, struct mbuf *); -static void wg_queue_out(struct wg_peer *); -static void wg_queue_stage(struct wg_peer *, struct mbuf *); static int wg_aip_init(struct wg_aip_table *); static void wg_aip_destroy(struct wg_aip_table *); static void wg_aip_populate_aip4(struct wg_aip *, const struct in_addr *, uint8_t); @@ -368,31 +359,44 @@ static int wg_peer_remove(struct radix_node *, void *); static void wg_peer_remove_all(struct wg_softc *); static int wg_aip_delete(struct wg_aip_table *, struct wg_peer *); static struct wg_peer *wg_aip_lookup(struct wg_aip_table *, struct mbuf *, enum route_direction); -static int wg_cookie_validate_packet(struct cookie_checker *, struct mbuf *, int); static struct wg_peer *wg_peer_alloc(struct wg_softc *); static void wg_peer_free_deferred(struct noise_remote *); static void wg_peer_destroy(struct wg_peer *); static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t); static void wg_send_initiation(struct wg_peer *); static void wg_send_response(struct wg_peer 
*); -static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct mbuf *); -static void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *); +static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *); +static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *); static void wg_peer_clear_src(struct wg_peer *); static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *); -static void wg_deliver_out(struct wg_peer *); -static void wg_deliver_in(struct wg_peer *); static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t); static void wg_send_keepalive(struct wg_peer *); -static void wg_handshake(struct wg_softc *, struct mbuf *); -static void wg_encap(struct wg_softc *, struct mbuf *); -static void wg_decap(struct wg_softc *, struct mbuf *); +static void wg_handshake(struct wg_softc *, struct wg_packet *); +static void wg_encrypt(struct wg_softc *, struct wg_packet *); +static void wg_decrypt(struct wg_softc *, struct wg_packet *); static void wg_softc_handshake_receive(struct wg_softc *); static void wg_softc_decrypt(struct wg_softc *); static void wg_softc_encrypt(struct wg_softc *); -static int wg_update_endpoint_addrs(struct wg_endpoint *, const struct sockaddr *, struct ifnet *); -static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); static void wg_encrypt_dispatch(struct wg_softc *); static void wg_decrypt_dispatch(struct wg_softc *); +static void wg_deliver_out(struct wg_peer *); +static void wg_deliver_in(struct wg_peer *); +static struct wg_packet *wg_packet_alloc(struct mbuf *); +static void wg_packet_free(struct wg_packet *); +static void wg_queue_init(struct wg_queue *, const char *); +static void wg_queue_deinit(struct wg_queue *); +static size_t wg_queue_len(struct wg_queue *); +static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *); +static struct wg_packet 
*wg_queue_dequeue_handshake(struct wg_queue *); +static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *); +static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *); +static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *); +static void wg_queue_purge(struct wg_queue *); +static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *); +static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *); +static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *); +static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); +static void wg_peer_send_staged(struct wg_peer *); static void crypto_taskq_setup(struct wg_softc *); static void crypto_taskq_destroy(struct wg_softc *); static int wg_clone_create(struct if_clone *, int, caddr_t); @@ -428,8 +432,8 @@ wg_peer_alloc(struct wg_softc *sc) rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint"); wg_queue_init(&peer->p_stage_queue, "stageq"); - wg_queue_init(&peer->p_encap_queue, "txq"); - wg_queue_init(&peer->p_decap_queue, "rxq"); + wg_queue_init(&peer->p_encrypt_serial, "txq"); + wg_queue_init(&peer->p_decrypt_serial, "rxq"); GROUPTASK_INIT(&peer->p_send_initiation, 0, (gtask_fn_t *)wg_send_initiation, peer); taskqgroup_attach(qgroup_wg_tqg, &peer->p_send_initiation, peer, NULL, NULL, "wg initiation"); @@ -483,8 +487,8 @@ wg_peer_destroy(struct wg_peer *peer) taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv); taskqgroup_detach(qgroup_wg_tqg, &peer->p_send); - wg_queue_deinit(&peer->p_decap_queue); - wg_queue_deinit(&peer->p_encap_queue); + wg_queue_deinit(&peer->p_decrypt_serial); + wg_queue_deinit(&peer->p_encrypt_serial); wg_queue_deinit(&peer->p_stage_queue); /* Final cleanup */ @@ -495,29 +499,31 @@ wg_peer_destroy(struct wg_peer *peer) } static void -wg_peer_set_endpoint_from_tag(struct wg_peer *peer, struct wg_tag *t) +wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e) { - 
struct wg_endpoint *e = &t->t_endpoint; - MPASS(e->e_remote.r_sa.sa_family != 0); if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0) return; + rw_wlock(&peer->p_endpoint_lock); peer->p_endpoint = *e; + rw_wunlock(&peer->p_endpoint_lock); } static void wg_peer_clear_src(struct wg_peer *peer) { - rw_rlock(&peer->p_endpoint_lock); + rw_wlock(&peer->p_endpoint_lock); bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local)); - rw_runlock(&peer->p_endpoint_lock); + rw_wunlock(&peer->p_endpoint_lock); } static void -wg_peer_get_endpoint(struct wg_peer *p, struct wg_endpoint *e) +wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e) { - memcpy(e, &p->p_endpoint, sizeof(*e)); + rw_rlock(&peer->p_endpoint_lock); + *e = peer->p_endpoint; + rw_runlock(&peer->p_endpoint_lock); } /* Allowed IP */ @@ -1035,33 +1041,6 @@ retry: DPRINTF(sc, "Unable to send packet: %d\n", ret); } -/* TODO Tag */ -static struct wg_tag * -wg_tag_get(struct mbuf *m) -{ - struct m_tag *tag; - - tag = m_tag_find(m, MTAG_WIREGUARD, NULL); - if (tag == NULL) { - tag = m_tag_get(MTAG_WIREGUARD, sizeof(struct wg_tag), M_NOWAIT|M_ZERO); - m_tag_prepend(m, tag); - MPASS(!SLIST_EMPTY(&m->m_pkthdr.tags)); - MPASS(m_tag_locate(m, MTAG_ABI_COMPAT, MTAG_WIREGUARD, NULL) == tag); - } - return (struct wg_tag *)tag; -} - -static struct wg_endpoint * -wg_mbuf_endpoint_get(struct mbuf *m) -{ - struct wg_tag *hdr; - - if ((hdr = wg_tag_get(m)) == NULL) - return (NULL); - - return (&hdr->t_endpoint); -} - /* Timers */ static void wg_timers_init(struct wg_timers *t) @@ -1392,17 +1371,15 @@ out: static void wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx, - struct mbuf *m) + struct wg_endpoint *e) { struct wg_pkt_cookie pkt; - struct wg_endpoint *e; DPRINTF(sc, "Sending cookie response for denied handshake message\n"); pkt.t = WG_PKT_COOKIE; pkt.r_idx = idx; - e = wg_mbuf_endpoint_get(m); cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce, pkt.ec, &e->e_remote.r_sa); 
wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt)); @@ -1411,162 +1388,131 @@ wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx, static void wg_send_keepalive(struct wg_peer *peer) { - struct mbuf *m = NULL; - struct wg_tag *t; - struct epoch_tracker et; + struct wg_packet *pkt; + struct mbuf *m; - if (wg_queue_len(&peer->p_stage_queue) != 0) { - NET_EPOCH_ENTER(et); + if (wg_queue_len(&peer->p_stage_queue) > 0) goto send; - } if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL) return; - if ((t = wg_tag_get(m)) == NULL) { - m_freem(m); + if ((pkt = wg_packet_alloc(m)) == NULL) return; - } - t->t_mbuf = NULL; - t->t_done = 0; - t->t_mtu = 0; /* MTU == 0 OK for keepalive */ - NET_EPOCH_ENTER(et); - wg_queue_stage(peer, m); + pkt->p_mtu = 0; + wg_queue_push_staged(&peer->p_stage_queue, pkt); + DPRINTF(peer->p_sc, "Sending keepalive packet to peer %lu\n", peer->p_id); send: - wg_queue_out(peer); - NET_EPOCH_EXIT(et); -} - -static int -wg_cookie_validate_packet(struct cookie_checker *checker, struct mbuf *m, - int under_load) -{ - struct wg_pkt_initiation *init; - struct wg_pkt_response *resp; - struct cookie_macs *macs; - struct wg_endpoint *e; - int type, size; - void *data; - - type = *mtod(m, uint32_t *); - data = m->m_data; - e = wg_mbuf_endpoint_get(m); - if (type == WG_PKT_INITIATION) { - init = mtod(m, struct wg_pkt_initiation *); - macs = &init->m; - size = sizeof(*init) - sizeof(*macs); - } else if (type == WG_PKT_RESPONSE) { - resp = mtod(m, struct wg_pkt_response *); - macs = &resp->m; - size = sizeof(*resp) - sizeof(*macs); - } else - return 0; - - return (cookie_checker_validate_macs(checker, macs, data, size, - under_load, &e->e_remote.r_sa)); + wg_peer_send_staged(peer); } - static void -wg_handshake(struct wg_softc *sc, struct mbuf *m) +wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_initiation *init; - struct wg_pkt_response *resp; - struct noise_remote *remote; + struct wg_pkt_initiation *init; + struct 
wg_pkt_response *resp; struct wg_pkt_cookie *cook; - struct wg_peer *peer; - struct wg_tag *t; - - /* This is global, so that our load calculation applies to the whole - * system. We don't care about races with it at all. - */ - static struct timeval wg_last_underload; - static const struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 }; - bool packet_needs_cookie = false; - int underload, res; - - underload = mbufq_len(&sc->sc_handshake_queue) >= - MAX_QUEUED_HANDSHAKES / 8; - if (underload) + struct wg_endpoint *e; + struct wg_peer *peer; + struct mbuf *m; + struct noise_keypair *keypair; + struct noise_remote *remote = NULL; + int res, underload = 0; + static struct timeval wg_last_underload; /* microuptime */ + static const struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 }; + + if (wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES/8) { getmicrouptime(&wg_last_underload); - else if (wg_last_underload.tv_sec != 0) { + underload = 1; + } else if (wg_last_underload.tv_sec != 0) { if (!ratecheck(&wg_last_underload, &underload_interval)) underload = 1; else bzero(&wg_last_underload, sizeof(wg_last_underload)); } - res = wg_cookie_validate_packet(&sc->sc_cookie, m, underload); + m = pkt->p_mbuf; + e = &pkt->p_endpoint; - if (res == EINVAL) { - DPRINTF(sc, "Invalid initiation MAC\n"); - goto free; - } else if (res == ECONNREFUSED) { - DPRINTF(sc, "Handshake ratelimited\n"); - goto free; - } else if (res == EAGAIN) { - packet_needs_cookie = true; - } else if (res != 0) { - DPRINTF(sc, "Unexpected handshake ratelimit response: %d\n", res); - goto free; - } + if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) + goto error; - t = wg_tag_get(m); switch (*mtod(m, uint32_t *)) { case WG_PKT_INITIATION: init = mtod(m, struct wg_pkt_initiation *); - if (packet_needs_cookie) { - wg_send_cookie(sc, &init->m, init->s_idx, m); - goto free; + res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m, + init, sizeof(*init) - sizeof(init->m), + underload, 
&e->e_remote.r_sa); + + if (res == EINVAL) { + DPRINTF(sc, "Invalid initiation MAC\n"); + goto error; + } else if (res == ECONNREFUSED) { + DPRINTF(sc, "Handshake ratelimited\n"); + goto error; + } else if (res == EAGAIN) { + wg_send_cookie(sc, &init->m, init->s_idx, e); + goto error; + } else if (res != 0) { + panic("unexpected response: %d\n", res); } + if (noise_consume_initiation(sc->sc_local, &remote, init->s_idx, init->ue, init->es, init->ets) != 0) { - DPRINTF(sc, "Invalid handshake initiation"); - goto free; + DPRINTF(sc, "Invalid handshake initiation\n"); + goto error; } peer = noise_remote_arg(remote); - DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", - (unsigned long long)peer->p_id); - counter_u64_add(peer->p_rx_bytes, sizeof(*init)); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*init)); - wg_peer_set_endpoint_from_tag(peer, t); + + DPRINTF(sc, "Receiving handshake initiation from peer %lu\n", peer->p_id); + + wg_peer_set_endpoint(peer, e); wg_send_response(peer); break; case WG_PKT_RESPONSE: resp = mtod(m, struct wg_pkt_response *); - if (packet_needs_cookie) { - wg_send_cookie(sc, &resp->m, resp->s_idx, m); - goto free; + res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m, + resp, sizeof(*resp) - sizeof(resp->m), + underload, &e->e_remote.r_sa); + + if (res == EINVAL) { + DPRINTF(sc, "Invalid response MAC\n"); + goto error; + } else if (res == ECONNREFUSED) { + DPRINTF(sc, "Handshake ratelimited\n"); + goto error; + } else if (res == EAGAIN) { + wg_send_cookie(sc, &resp->m, resp->s_idx, e); + goto error; + } else if (res != 0) { + panic("unexpected response: %d\n", res); } - if (noise_consume_response(sc->sc_local, &remote, resp->s_idx, - resp->r_idx, resp->ue, resp->en) != 0) { + + if (noise_consume_response(sc->sc_local, &remote, + resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) { DPRINTF(sc, "Invalid handshake response\n"); - goto free; + goto error; } peer = 
noise_remote_arg(remote); + DPRINTF(sc, "Receiving handshake response from peer %lu\n", peer->p_id); - DPRINTF(sc, "Receiving handshake response from peer %llu\n", - (unsigned long long)peer->p_id); - + wg_peer_set_endpoint(peer, e); wg_timers_event_session_derived(&peer->p_timers); wg_timers_event_handshake_complete(&peer->p_timers); - - counter_u64_add(peer->p_rx_bytes, sizeof(*resp)); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*resp)); - wg_peer_set_endpoint_from_tag(peer, t); break; case WG_PKT_COOKIE: cook = mtod(m, struct wg_pkt_cookie *); - if ((remote = noise_remote_index_lookup(sc->sc_local, - cook->r_idx)) == NULL) { - DPRINTF(sc, "Unknown cookie index\n"); - goto free; + if ((remote = noise_remote_index_lookup(sc->sc_local, cook->r_idx)) == NULL) { + if ((keypair = noise_keypair_lookup(sc->sc_local, cook->r_idx)) == NULL) { + DPRINTF(sc, "Unknown cookie index\n"); + goto error; + } + remote = noise_keypair_remote(keypair); + noise_keypair_put(keypair); } peer = noise_remote_arg(remote); @@ -1576,191 +1522,198 @@ wg_handshake(struct wg_softc *sc, struct mbuf *m) DPRINTF(sc, "Receiving cookie response\n"); } else { DPRINTF(sc, "Could not decrypt cookie response\n"); + goto error; } - noise_remote_put(remote); - goto free; + + goto not_authenticated; default: - goto free; + panic("invalid packet in handshake queue"); } - MPASS(peer != NULL); + wg_timers_event_any_authenticated_packet_received(&peer->p_timers); wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); -free: - m_freem(m); +not_authenticated: + counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len); +error: + if (remote != NULL) + noise_remote_put(remote); + wg_packet_free(pkt); } static void wg_softc_handshake_receive(struct wg_softc *sc) { - struct mbuf *m; - - while ((m = 
mbufq_dequeue(&sc->sc_handshake_queue)) != NULL) - wg_handshake(sc, m); + struct wg_packet *pkt; + while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL) + wg_handshake(sc, pkt); } -/* TODO Encrypt */ static void -wg_encap(struct wg_softc *sc, struct mbuf *m) +wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_data *data; - size_t padding_len, plaintext_len, out_len; - struct noise_remote *remote; - struct mbuf *mc; - struct wg_peer *peer; - struct wg_tag *t; - uint64_t nonce; - int allocation_order; - - NET_EPOCH_ASSERT(); - t = wg_tag_get(m); - remote = noise_keypair_remote(t->t_keypair); - peer = noise_remote_arg(remote); - /* We can put the remote as we still hold a keypair ref */ - noise_remote_put(remote); - - plaintext_len = MIN(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu); - padding_len = plaintext_len - m->m_pkthdr.len; - out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN; - - if (out_len <= MCLBYTES) - allocation_order = MCLBYTES; - else if (out_len <= MJUMPAGESIZE) - allocation_order = MJUMPAGESIZE; - else if (out_len <= MJUM9BYTES) - allocation_order = MJUM9BYTES; - else if (out_len <= MJUM16BYTES) - allocation_order = MJUM16BYTES; - else + struct wg_pkt_data data; + struct wg_peer *peer; + struct mbuf *m; + uint32_t idx; + uint8_t zeroed[NOISE_AUTHTAG_LEN] = { 0 }; + int pad, len; + + peer = noise_keypair_remote_arg(pkt->p_keypair); + m = pkt->p_mbuf; + + /* Calculate what padding we need to add then limit it to the mtu of + * the interface. This is done to ensure we don't "over pad" a packet + * that is just under the MTU. */ + pad = (-m->m_pkthdr.len) & (WG_PKT_ALIGNMENT - 1); + if (m->m_pkthdr.len + pad > pkt->p_mtu) + pad = pkt->p_mtu - m->m_pkthdr.len; + + /* Pad the packet */ + if (pad != 0 && !m_append(m, pad, zeroed)) goto error; - if ((mc = m_getjcl(M_NOWAIT, MT_DATA, M_PKTHDR, allocation_order)) == NULL) + /* TODO teach noise_keypair_encrypt about mbufs. 
Currently we have to + * resort to m_defrag to create an encryptable buffer. */ + len = m->m_pkthdr.len; + if (!m_append(m, NOISE_AUTHTAG_LEN, zeroed)) + goto error; + /* TODO this is buffer overflow territory */ + if ((m = m_defrag(m, M_NOWAIT)) == NULL) goto error; - data = mtod(mc, struct wg_pkt_data *); - m_copydata(m, 0, m->m_pkthdr.len, data->buf); - bzero(data->buf + m->m_pkthdr.len, padding_len); - - data->t = WG_PKT_DATA; - - noise_keypair_encrypt(t->t_keypair, &data->r_idx, t->t_nonce, - data->buf, plaintext_len); + /* Do encryption */ + noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, mtod(m, uint8_t *), len); - nonce = htole64(t->t_nonce); /* Wire format is little endian. */ - memcpy(data->nonce, &nonce, sizeof(data->nonce)); + /* Put header into packet */ + M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT); + if (m == NULL) + goto error; - /* A packet with length 0 is a keepalive packet */ - if (m->m_pkthdr.len == 0) - DPRINTF(sc, "Sending keepalive packet to peer %llu\n", - (unsigned long long)peer->p_id); - /* - * Set the correct output value here since it will be copied - * when we move the pkthdr in send. - */ - mc->m_len = mc->m_pkthdr.len = out_len; - mc->m_flags &= ~(M_MCAST | M_BCAST); + data.t = WG_PKT_DATA; + data.r_idx = idx; + data.nonce = htole64(pkt->p_nonce); + memcpy(mtod(m, void *), &data, sizeof(struct wg_pkt_data)); - t->t_mbuf = mc; - error: - /* XXX membar ? 
*/ - t->t_done = 1; + /* TODO reset packet metadata */ + pkt->p_mbuf = m; + pkt->p_state = WG_PACKET_CRYPTED; + GROUPTASK_ENQUEUE(&peer->p_send); + return; +error: + pkt->p_mbuf = m; + pkt->p_state = WG_PACKET_DEAD; GROUPTASK_ENQUEUE(&peer->p_send); } static void -wg_decap(struct wg_softc *sc, struct mbuf *m) +wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_data *data; - struct wg_peer *peer, *routed_peer; - struct noise_remote *remote; - struct wg_tag *t; - size_t plaintext_len; - uint8_t version; - uint64_t nonce; - int res; + struct wg_pkt_data data; + struct wg_peer *peer, *allowed_peer; + struct mbuf *m; + struct ip *ip; + struct ip6_hdr *ip6; + int res, len; - NET_EPOCH_ASSERT(); - data = mtod(m, struct wg_pkt_data *); - plaintext_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data); + peer = noise_keypair_remote_arg(pkt->p_keypair); + m = pkt->p_mbuf; + len = m->m_pkthdr.len; - t = wg_tag_get(m); - remote = noise_keypair_remote(t->t_keypair); - peer = noise_remote_arg(remote); - /* We can put the remote as we still hold a keypair ref */ - noise_remote_put(remote); + /* Read index, nonce and then adjust to remove the header. */ + memcpy(&data, mtod(m, void *), sizeof(struct wg_pkt_data)); + m_adj(m, sizeof(struct wg_pkt_data)); - memcpy(&nonce, data->nonce, sizeof(nonce)); - t->t_nonce = le64toh(nonce); /* Wire format is little endian. */ + /* TODO teach noise_keypair_decrypt about mbufs. Currently we have to + * resort to m_defrag to create an decryptable buffer. 
*/ + /* TODO this is buffer overflow territory */ + if ((m = m_defrag(m, M_NOWAIT)) == NULL) { + goto error; + } - res = noise_keypair_decrypt(t->t_keypair, t->t_nonce, data->buf, - plaintext_len); + pkt->p_nonce = le64toh(data.nonce); + res = noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, mtod(m, void *), m->m_pkthdr.len); - if (__predict_false(res)) { - if (res == EINVAL) { - goto error; - } else if (res == ECONNRESET) { - wg_timers_event_handshake_complete(&peer->p_timers); - } else { - panic("unexpected response: %d\n", res); - } + if (__predict_false(res == EINVAL)) { + goto error; + } else if (__predict_false(res == ECONNRESET)) { + wg_timers_event_handshake_complete(&peer->p_timers); + } else if (__predict_false(res != 0)) { + panic("unexpected response: %d\n", res); } - wg_peer_set_endpoint_from_tag(peer, t); - /* Remove the data header, and crypto mac tail from the packet */ - m_adj(m, sizeof(struct wg_pkt_data)); m_adj(m, -NOISE_AUTHTAG_LEN); /* A packet with length 0 is a keepalive packet */ - if (m->m_pkthdr.len == 0) { - DPRINTF(peer->p_sc, "Receiving keepalive packet from peer " - "%llu\n", (unsigned long long)peer->p_id); + if (__predict_false(m->m_pkthdr.len == 0)) { + DPRINTF(sc, "Receiving keepalive packet from peer " + "%lu\n", peer->p_id); goto done; } - version = mtod(m, struct ip *)->ip_v; - if (!((version == 4 && m->m_pkthdr.len >= sizeof(struct ip)) || - (version == 6 && m->m_pkthdr.len >= sizeof(struct ip6_hdr)))) { - DPRINTF(peer->p_sc, "Packet is neither ipv4 nor ipv6 from peer " - "%llu\n", (unsigned long long)peer->p_id); + /* + * We can let the network stack handle the intricate validation of the + * IP header, we just worry about the sizeof and the version, so we can + * read the source address in wg_aip_lookup. 
+ */ + ip = mtod(m, struct ip *); + ip6 = mtod(m, struct ip6_hdr *); + + if (m->m_pkthdr.len >= sizeof(struct ip) && ip->ip_v == IPVERSION) { + pkt->p_af = AF_INET; + + len = ntohs(ip->ip_len); + if (len >= sizeof(struct ip) && len < m->m_pkthdr.len) + m_adj(m, len - m->m_pkthdr.len); + + allowed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN); + } else if (m->m_pkthdr.len >= sizeof(struct ip6_hdr) && + (ip6->ip6_vfc & IPV6_VERSION_MASK) == IPV6_VERSION) { + pkt->p_af = AF_INET6; + + len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr); + if (len < m->m_pkthdr.len) + m_adj(m, len - m->m_pkthdr.len); + + allowed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN); + } else { + DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from " "peer %lu\n", peer->p_id); goto error; } - routed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN); - if (routed_peer != peer) { - DPRINTF(peer->p_sc, "Packet has unallowed src IP from peer " - "%llu\n", (unsigned long long)peer->p_id); + if (__predict_false(peer != allowed_peer)) { + DPRINTF(sc, "Packet has unallowed src IP from peer " "%lu\n", peer->p_id); goto error; } + /* TODO reset packet metadata */ done: - t->t_mbuf = m; + pkt->p_mbuf = m; + pkt->p_state = WG_PACKET_CRYPTED; + GROUPTASK_ENQUEUE(&peer->p_recv); + return; error: - t->t_done = 1; + pkt->p_mbuf = m; + pkt->p_state = WG_PACKET_DEAD; GROUPTASK_ENQUEUE(&peer->p_recv); } static void wg_softc_decrypt(struct wg_softc *sc) { - struct epoch_tracker et; - struct mbuf *m; - - NET_EPOCH_ENTER(et); - while ((m = buf_ring_dequeue_mc(sc->sc_decap_ring)) != NULL) - wg_decap(sc, m); - NET_EPOCH_EXIT(et); + struct wg_packet *pkt; + while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL) + wg_decrypt(sc, pkt); } static void wg_softc_encrypt(struct wg_softc *sc) { - struct mbuf *m; - struct epoch_tracker et; - - NET_EPOCH_ENTER(et); - while ((m = buf_ring_dequeue_mc(sc->sc_encap_ring)) != NULL) - wg_encap(sc, m); - NET_EPOCH_EXIT(et); + struct wg_packet *pkt; + while 
((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL) + wg_encrypt(sc, pkt); } static void @@ -1786,460 +1739,471 @@ wg_decrypt_dispatch(struct wg_softc *sc) static void wg_deliver_out(struct wg_peer *peer) { - struct epoch_tracker et; - struct wg_tag *t; - struct mbuf *m; - struct wg_endpoint endpoint; - size_t len; - int ret; - - NET_EPOCH_ENTER(et); + struct wg_endpoint endpoint; + struct wg_softc *sc = peer->p_sc; + struct wg_packet *pkt; + struct mbuf *m; + int rc, len, data_sent = 0; wg_peer_get_endpoint(peer, &endpoint); - while ((m = wg_queue_dequeue(&peer->p_encap_queue, &t)) != NULL) { - /* t_mbuf will contain the encrypted packet */ - if (t->t_mbuf == NULL) { + while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) { + if (pkt->p_state == WG_PACKET_CRYPTED) { + m = pkt->p_mbuf; + pkt->p_mbuf = NULL; + + len = m->m_pkthdr.len; + + rc = wg_send(sc, &endpoint, m); + if (rc == 0) { + wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); + wg_timers_event_any_authenticated_packet_sent(&peer->p_timers); + if (len > (sizeof(struct wg_pkt_data)+NOISE_AUTHTAG_LEN)) + data_sent = 1; + counter_u64_add(peer->p_tx_bytes, len); + } else if (rc == EADDRNOTAVAIL) { + wg_peer_clear_src(peer); + wg_peer_get_endpoint(peer, &endpoint); + goto error; + } else { + goto error; + } + } else { +error: if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OERRORS, 1); - m_freem(m); - continue; - } - len = t->t_mbuf->m_pkthdr.len; - noise_keypair_put(t->t_keypair); - ret = wg_send(peer->p_sc, &endpoint, t->t_mbuf); - - if (ret == 0) { - wg_timers_event_any_authenticated_packet_traversal( - &peer->p_timers); - wg_timers_event_any_authenticated_packet_sent( - &peer->p_timers); - - if (m->m_pkthdr.len != 0) - wg_timers_event_data_sent(&peer->p_timers); - counter_u64_add(peer->p_tx_bytes, len); - } else if (ret == EADDRNOTAVAIL) { - wg_peer_clear_src(peer); - wg_peer_get_endpoint(peer, &endpoint); } - m_freem(m); - - if 
(noise_keep_key_fresh_send(peer->p_remote) == 0) - wg_timers_event_want_initiation(&peer->p_timers); + wg_packet_free(pkt); } - NET_EPOCH_EXIT(et); + if (data_sent) + wg_timers_event_data_sent(&peer->p_timers); + if (noise_keep_key_fresh_send(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } static void wg_deliver_in(struct wg_peer *peer) { - struct mbuf *m; - struct ifnet *ifp; - struct wg_softc *sc; - struct epoch_tracker et; - struct wg_tag *t; - uint32_t af; - u_int isr; - - NET_EPOCH_ENTER(et); - sc = peer->p_sc; - ifp = sc->sc_ifp; - - while ((m = wg_queue_dequeue(&peer->p_decap_queue, &t)) != NULL) { - /* t_mbuf will contain the encrypted packet */ - if (t->t_mbuf == NULL) { - if_inc_counter(ifp, IFCOUNTER_IERRORS, 1); - m_freem(m); - continue; - } - MPASS(m == t->t_mbuf); - - if (noise_keypair_nonce_check(t->t_keypair, t->t_nonce) != 0) { + struct wg_softc *sc = peer->p_sc; + struct ifnet *ifp = sc->sc_ifp; + struct wg_packet *pkt; + struct mbuf *m; + uint32_t af; + int data_recv = 0; + + while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) { + if (pkt->p_state == WG_PACKET_CRYPTED) { + m = pkt->p_mbuf; + if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0) + goto error; + + wg_timers_event_any_authenticated_packet_received(&peer->p_timers); + wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); + wg_peer_set_endpoint(peer, &pkt->p_endpoint); + + counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len + + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); + + if (m->m_pkthdr.len == 0) + goto free; + + MPASS(pkt->p_af == AF_INET || pkt->p_af == AF_INET6); + pkt->p_mbuf = NULL; + data_recv = 1; + + m->m_flags &= ~(M_MCAST | M_BCAST); + m->m_pkthdr.rcvif = ifp; + + af = pkt->p_af; + BPF_MTAP2(ifp, &af, sizeof(af), m); + + 
CURVNET_SET(ifp->if_vnet); + M_SETFIB(m, ifp->if_fib); + if (pkt->p_af == AF_INET) + netisr_dispatch(NETISR_IP, m); + if (pkt->p_af == AF_INET6) + netisr_dispatch(NETISR_IPV6, m); + CURVNET_RESTORE(); + } else { +error: if_inc_counter(ifp, IFCOUNTER_IERRORS, 1); - noise_keypair_put(t->t_keypair); - m_freem(m); - continue; - } - noise_keypair_put(t->t_keypair); - - wg_timers_event_any_authenticated_packet_received( - &peer->p_timers); - wg_timers_event_any_authenticated_packet_traversal( - &peer->p_timers); - - counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); - if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN); - - if (noise_keep_key_fresh_recv(peer->p_remote) == 0) - wg_timers_event_want_initiation(&peer->p_timers); - - if (m->m_pkthdr.len == 0) { - m_freem(m); - continue; - } - - m->m_flags &= ~(M_MCAST | M_BCAST); - m->m_pkthdr.rcvif = ifp; - switch (mtod(m, struct ip *)->ip_v) { - case 4: - isr = NETISR_IP; - af = AF_INET; - break; - case 6: - isr = NETISR_IPV6; - af = AF_INET6; - break; - default: - m_freem(m); - goto done; } +free: + wg_packet_free(pkt); + } - BPF_MTAP2(ifp, &af, sizeof(af), m); - CURVNET_SET(ifp->if_vnet); - M_SETFIB(m, ifp->if_fib); - netisr_dispatch(isr, m); - CURVNET_RESTORE(); -done: + if (data_recv) wg_timers_event_data_received(&peer->p_timers); - } - NET_EPOCH_EXIT(et); + if (noise_keep_key_fresh_recv(peer->p_remote)) + wg_timers_event_want_initiation(&peer->p_timers); } -static int -wg_queue_in(struct wg_peer *peer, struct mbuf *m) +static struct wg_packet * +wg_packet_alloc(struct mbuf *m) { - struct buf_ring *parallel = peer->p_sc->sc_decap_ring; - struct wg_queue *serial = &peer->p_decap_queue; - struct wg_tag *t; - int rc; + struct wg_packet *pkt; - MPASS(wg_tag_get(m) != NULL); - - mtx_lock(&serial->q_mtx); - if ((rc = mbufq_enqueue(&serial->q, m)) == ENOBUFS) { + if 
((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT)) == NULL) { m_freem(m); - if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); - } else { - m->m_flags |= M_ENQUEUED; - rc = buf_ring_enqueue(parallel, m); - if (rc == ENOBUFS) { - t = wg_tag_get(m); - t->t_done = 1; - } + return NULL; } - mtx_unlock(&serial->q_mtx); - return (rc); + + pkt->p_keypair = NULL; + pkt->p_mbuf = m; + return pkt; } static void -wg_queue_stage(struct wg_peer *peer, struct mbuf *m) +wg_packet_free(struct wg_packet *pkt) { - struct wg_queue *q = &peer->p_stage_queue; - mtx_lock(&q->q_mtx); - STAILQ_INSERT_TAIL(&q->q.mq_head, m, m_stailqpkt); - q->q.mq_len++; - while (mbufq_full(&q->q)) { - m = mbufq_dequeue(&q->q); - if (m) { - m_freem(m); - if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); - } - } - mtx_unlock(&q->q_mtx); + if (pkt->p_keypair != NULL) + noise_keypair_put(pkt->p_keypair); + if (pkt->p_mbuf != NULL) + m_freem(pkt->p_mbuf); + uma_zfree(wg_packet_zone, pkt); } static void -wg_queue_out(struct wg_peer *peer) +wg_queue_init(struct wg_queue *queue, const char *name) { - struct buf_ring *parallel = peer->p_sc->sc_encap_ring; - struct wg_queue *serial = &peer->p_encap_queue; - struct wg_tag *t; - struct noise_keypair *keypair; - struct mbufq staged; - struct mbuf *m; + mtx_init(&queue->q_mtx, name, NULL, MTX_DEF); + STAILQ_INIT(&queue->q_queue); + queue->q_len = 0; +} - if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) { - if (wg_queue_len(&peer->p_stage_queue)) - wg_timers_event_want_initiation(&peer->p_timers); - return; - } +static void +wg_queue_deinit(struct wg_queue *queue) +{ + wg_queue_purge(queue); + mtx_destroy(&queue->q_mtx); +} - /* We first "steal" the staged queue to a local queue, so that we can do these - * remaining operations without having to hold the staged queue mutex. 
*/ - STAILQ_INIT(&staged.mq_head); - mtx_lock(&peer->p_stage_queue.q_mtx); - STAILQ_SWAP(&staged.mq_head, &peer->p_stage_queue.q.mq_head, mbuf); - staged.mq_len = peer->p_stage_queue.q.mq_len; - peer->p_stage_queue.q.mq_len = 0; - staged.mq_maxlen = peer->p_stage_queue.q.mq_maxlen; - mtx_unlock(&peer->p_stage_queue.q_mtx); +static size_t +wg_queue_len(struct wg_queue *queue) +{ + size_t len; + mtx_lock(&queue->q_mtx); + len = queue->q_len; + mtx_unlock(&queue->q_mtx); + return len; +} - while ((m = mbufq_dequeue(&staged)) != NULL) { - if ((t = wg_tag_get(m)) == NULL) { - m_freem(m); - continue; - } - if (noise_keypair_nonce_next(keypair, &t->t_nonce) != 0) { - /* TODO if we get here, it means we are about to - * overflow this keypair sending nonce. We should place - * this back on the staged queue. */ - m_freem(m); - continue; - } - t->t_keypair = noise_keypair_ref(keypair); - mtx_lock(&serial->q_mtx); - if (mbufq_enqueue(&serial->q, m) != 0) { - m_freem(m); - noise_keypair_put(keypair); - if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); - } else { - m->m_flags |= M_ENQUEUED; - if (buf_ring_enqueue(parallel, m)) { - t = wg_tag_get(m); - t->t_done = 1; - } - } - mtx_unlock(&serial->q_mtx); +static int +wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt) +{ + int ret = 0; + mtx_lock(&hs->q_mtx); + if (hs->q_len < MAX_QUEUED_HANDSHAKES) { + STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel); + hs->q_len++; + } else { + ret = ENOBUFS; } - wg_encrypt_dispatch(peer->p_sc); + mtx_unlock(&hs->q_mtx); + if (ret != 0) + wg_packet_free(pkt); + return ret; } -static struct mbuf * -wg_queue_dequeue(struct wg_queue *q, struct wg_tag **t) +static struct wg_packet * +wg_queue_dequeue_handshake(struct wg_queue *hs) { - struct mbuf *m_, *m; - - m = NULL; - mtx_lock(&q->q_mtx); - m_ = mbufq_first(&q->q); - if (m_ != NULL && (*t = wg_tag_get(m_))->t_done) { - m = mbufq_dequeue(&q->q); - m->m_flags &= ~M_ENQUEUED; + struct wg_packet *pkt; + 
mtx_lock(&hs->q_mtx); + if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) { + STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel); + hs->q_len--; } - mtx_unlock(&q->q_mtx); - return (m); + mtx_unlock(&hs->q_mtx); + return pkt; } -static int -wg_queue_len(struct wg_queue *q) +static void +wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt) { - /* This access races. We might consider adding locking here. */ - return (mbufq_len(&q->q)); + struct wg_packet *old = NULL; + + mtx_lock(&staged->q_mtx); + if (staged->q_len >= MAX_STAGED_PKT) { + old = STAILQ_FIRST(&staged->q_queue); + STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel); + staged->q_len--; + } + STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel); + staged->q_len++; + mtx_unlock(&staged->q_mtx); + + if (old != NULL) + wg_packet_free(old); } static void -wg_queue_init(struct wg_queue *q, const char *name) +wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list) { - mtx_init(&q->q_mtx, name, NULL, MTX_DEF); - mbufq_init(&q->q, MAX_QUEUED_PKT); + struct wg_packet *pkt, *tpkt; + STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt) + wg_queue_push_staged(staged, pkt); } static void -wg_queue_deinit(struct wg_queue *q) +wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list) { - wg_queue_purge(q); - mtx_destroy(&q->q_mtx); + mtx_lock(&staged->q_mtx); + *list = staged->q_queue; + STAILQ_INIT(&staged->q_queue); + staged->q_len = 0; + mtx_unlock(&staged->q_mtx); } static void -wg_queue_purge(struct wg_queue *q) +wg_queue_purge(struct wg_queue *staged) { - mtx_lock(&q->q_mtx); - mbufq_drain(&q->q); - mtx_unlock(&q->q_mtx); + struct wg_packet_list list; + struct wg_packet *pkt, *tpkt; + wg_queue_delist_staged(staged, &list); + STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) + wg_packet_free(pkt); } static int -wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa, - struct ifnet *rcvif) +wg_queue_both(struct wg_queue *parallel, struct wg_queue 
*serial, struct wg_packet *pkt) { - const struct sockaddr_in *sa4; -#ifdef INET6 - const struct sockaddr_in6 *sa6; -#endif - int ret = 0; + pkt->p_state = WG_PACKET_UNCRYPTED; - /* - * UDP passes a 2-element sockaddr array: first element is the - * source addr/port, second the destination addr/port. - */ - if (srcsa->sa_family == AF_INET) { - sa4 = (const struct sockaddr_in *)srcsa; - e->e_remote.r_sin = sa4[0]; - e->e_local.l_in = sa4[1].sin_addr; -#ifdef INET6 - } else if (srcsa->sa_family == AF_INET6) { - sa6 = (const struct sockaddr_in6 *)srcsa; - e->e_remote.r_sin6 = sa6[0]; - e->e_local.l_in6 = sa6[1].sin6_addr; -#endif + mtx_lock(&serial->q_mtx); + if (serial->q_len < MAX_QUEUED_PKT) { + serial->q_len++; + STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial); } else { - ret = EAFNOSUPPORT; + mtx_unlock(&serial->q_mtx); + wg_packet_free(pkt); + return ENOBUFS; } + mtx_unlock(&serial->q_mtx); - return (ret); + mtx_lock(&parallel->q_mtx); + if (parallel->q_len < MAX_QUEUED_PKT) { + parallel->q_len++; + STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel); + } else { + mtx_unlock(&parallel->q_mtx); + pkt->p_state = WG_PACKET_DEAD; + return ENOBUFS; + } + mtx_unlock(&parallel->q_mtx); + + return 0; +} + +static struct wg_packet * +wg_queue_dequeue_serial(struct wg_queue *serial) +{ + struct wg_packet *pkt = NULL; + mtx_lock(&serial->q_mtx); + if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) { + serial->q_len--; + pkt = STAILQ_FIRST(&serial->q_queue); + STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial); + } + mtx_unlock(&serial->q_mtx); + return pkt; +} + +static struct wg_packet * +wg_queue_dequeue_parallel(struct wg_queue *parallel) +{ + struct wg_packet *pkt = NULL; + mtx_lock(&parallel->q_mtx); + if (parallel->q_len > 0) { + parallel->q_len--; + pkt = STAILQ_FIRST(&parallel->q_queue); + STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel); + } + mtx_unlock(&parallel->q_mtx); + return pkt; } static void -wg_input(struct mbuf *m0, int offset, struct inpcb 
*inpcb, - const struct sockaddr *srcsa, void *_sc) +wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, + const struct sockaddr *sa, void *_sc) { - struct wg_pkt_data *pkt_data; - struct wg_endpoint *e; - struct wg_softc *sc = _sc; - struct mbuf *m; - int pktlen, pkttype; - struct noise_keypair *keypair; - struct noise_remote *remote; - struct wg_tag *t; - void *data; + const struct sockaddr_in *sin; + const struct sockaddr_in6 *sin6; + struct wg_pkt_data *data; + struct wg_packet *pkt; + struct wg_peer *peer; + struct wg_softc *sc = _sc; - /* Caller provided us with srcsa, no need for this header. */ - m_adj(m0, offset + sizeof(struct udphdr)); + /* Caller provided us with `sa`, no need for this header. */ + m_adj(m, offset + sizeof(struct udphdr)); - /* - * Ensure mbuf has at least enough contiguous data to peel off our - * headers at the beginning, and make a jumbo contigious copy if we've - * got a jumbo frame. This is pretty sloppy, and we should just fix the - * crypto routines to deal with mbuf clusters instead. 
- */ - if (!m0->m_next) - m = m0; - else { - int allocation_order; - - if (m0->m_pkthdr.len <= MCLBYTES) - allocation_order = MCLBYTES; - else if (m0->m_pkthdr.len <= MJUMPAGESIZE) - allocation_order = MJUMPAGESIZE; - else if (m0->m_pkthdr.len <= MJUM9BYTES) - allocation_order = MJUM9BYTES; - else if (m0->m_pkthdr.len <= MJUM16BYTES) - allocation_order = MJUM16BYTES; - else { - m_freem(m0); - return; - } - if ((m = m_getjcl(M_NOWAIT, MT_DATA, M_PKTHDR, allocation_order)) == NULL) { - m_freem(m0); - return; - } - m->m_len = m->m_pkthdr.len = m0->m_pkthdr.len; - m_copydata(m0, 0, m0->m_pkthdr.len, mtod(m, void *)); - m_freem(m0); + /* Pullup enough to read packet type */ + if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + return; } - data = mtod(m, void *); - pkttype = *(uint32_t *)data; - t = wg_tag_get(m); - if (t == NULL) { - goto free; + + if ((pkt = wg_packet_alloc(m)) == NULL) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + return; } - e = wg_mbuf_endpoint_get(m); - if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) { - goto free; + /* Save send/recv address and port for later. 
*/ + if (sa->sa_family == AF_INET) { + sin = (const struct sockaddr_in *)sa; + pkt->p_endpoint.e_remote.r_sin = sin[0]; + pkt->p_endpoint.e_local.l_in = sin[1].sin_addr; + } else if (sa->sa_family == AF_INET6) { + sin6 = (const struct sockaddr_in6 *)sa; + pkt->p_endpoint.e_remote.r_sin6 = sin6[0]; + pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr; + } else { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); + goto error; } - pktlen = m->m_pkthdr.len; + if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) && + *mtod(m, uint32_t *) == WG_PKT_INITIATION) || + (m->m_pkthdr.len == sizeof(struct wg_pkt_response) && + *mtod(m, uint32_t *) == WG_PKT_RESPONSE) || + (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) && + *mtod(m, uint32_t *) == WG_PKT_COOKIE)) { - if ((pktlen == sizeof(struct wg_pkt_initiation) && - pkttype == WG_PKT_INITIATION) || - (pktlen == sizeof(struct wg_pkt_response) && - pkttype == WG_PKT_RESPONSE) || - (pktlen == sizeof(struct wg_pkt_cookie) && - pkttype == WG_PKT_COOKIE)) { - if (mbufq_enqueue(&sc->sc_handshake_queue, m) == 0) { - GROUPTASK_ENQUEUE(&sc->sc_handshake); - } else { + if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); DPRINTF(sc, "Dropping handshake packet\n"); - m_freem(m); } - } else if (pktlen >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN - && pkttype == WG_PKT_DATA) { - pkt_data = data; - keypair = noise_keypair_lookup(sc->sc_local, pkt_data->r_idx); - if (keypair == NULL) { - if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); - m_freem(m); - } else if (buf_ring_count(sc->sc_decap_ring) > MAX_QUEUED_PKT) { - if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); - noise_keypair_put(keypair); - m_freem(m); - } else { - t->t_keypair = keypair; - t->t_mbuf = NULL; - t->t_done = 0; + GROUPTASK_ENQUEUE(&sc->sc_handshake); + } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) + + NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) { - remote = 
noise_keypair_remote(keypair); - wg_queue_in(noise_remote_arg(remote), m); - wg_decrypt_dispatch(sc); - noise_remote_put(remote); - } + /* Pullup whole header to read r_idx below. */ + if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL) + goto error; + + data = mtod(pkt->p_mbuf, struct wg_pkt_data *); + if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL) + goto error; + + peer = noise_keypair_remote_arg(pkt->p_keypair); + if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0) + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + wg_decrypt_dispatch(sc); } else { -free: - m_freem(m); + goto error; + } + return; +error: + if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); + wg_packet_free(pkt); + return; +} + +static void +wg_peer_send_staged(struct wg_peer *peer) +{ + struct wg_packet_list list; + struct noise_keypair *keypair; + struct wg_packet *pkt, *tpkt; + struct wg_softc *sc = peer->p_sc; + + wg_queue_delist_staged(&peer->p_stage_queue, &list); + + if (STAILQ_EMPTY(&list)) + return; + + if ((keypair = noise_keypair_current(peer->p_remote)) == NULL) + goto error; + + STAILQ_FOREACH(pkt, &list, p_parallel) { + if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0) + goto error_keypair; } + STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) { + pkt->p_keypair = noise_keypair_ref(keypair); + if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0) + if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); + } + wg_encrypt_dispatch(peer->p_sc); + noise_keypair_put(keypair); + return; + +error_keypair: + noise_keypair_put(keypair); +error: + wg_queue_enlist_staged(&peer->p_stage_queue, &list); + wg_timers_event_want_initiation(&peer->p_timers); } static int wg_transmit(struct ifnet *ifp, struct mbuf *m) { - struct wg_softc *sc = ifp->if_softc; - sa_family_t family; - struct epoch_tracker et; - struct wg_peer *peer; - struct wg_tag *t; - uint32_t af; - int rc = 0; + struct 
epoch_tracker et; + struct wg_packet *pkt = m->m_pkthdr.PH_loc.ptr; + struct wg_softc *sc = ifp->if_softc; + struct wg_peer *peer; + uint32_t af = pkt->p_af; + int rc = 0; + sa_family_t peer_af; + NET_EPOCH_ENTER(et); /* Work around lifetime issue in the ipv6 mld code. */ - if (__predict_false((ifp->if_flags & IFF_DYING) || !sc)) - return (ENXIO); - - if ((t = wg_tag_get(m)) == NULL) { - rc = ENOBUFS; - goto early_out; + if (__predict_false((ifp->if_flags & IFF_DYING) || !sc)) { + rc = ENXIO; + goto err; } - af = m->m_pkthdr.ph_family; + BPF_MTAP2(ifp, &af, sizeof(af), m); - NET_EPOCH_ENTER(et); - peer = wg_aip_lookup(&sc->sc_aips, m, OUT); + if (pkt->p_af == AF_INET) { + peer = wg_aip_lookup(&sc->sc_aips, m, OUT); + } else if (pkt->p_af == AF_INET6) { + peer = wg_aip_lookup(&sc->sc_aips, m, OUT); + } else { + rc = EAFNOSUPPORT; + goto err; + } + if (__predict_false(peer == NULL)) { rc = ENOKEY; goto err; } - family = peer->p_endpoint.e_remote.r_sa.sa_family; - if (__predict_false(family != AF_INET && family != AF_INET6)) { + peer_af = peer->p_endpoint.e_remote.r_sa.sa_family; + if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) { DPRINTF(sc, "No valid endpoint has been configured or " "discovered for peer %llu\n", (unsigned long long)peer->p_id); - rc = EHOSTUNREACH; goto err; } - t->t_mbuf = NULL; - t->t_done = 0; - t->t_mtu = ifp->if_mtu; - wg_queue_stage(peer, m); - wg_queue_out(peer); + wg_queue_push_staged(&peer->p_stage_queue, pkt); + wg_peer_send_staged(peer); NET_EPOCH_EXIT(et); - return (rc); + return (0); err: NET_EPOCH_EXIT(et); -early_out: if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1); - /* TODO: send ICMP unreachable */ - m_free(m); + /* TODO: send ICMP unreachable? 
*/ + wg_packet_free(pkt); return (rc); } static int -wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt) +wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *ro) { - m->m_pkthdr.ph_family = sa->sa_family; + struct wg_packet *pkt; + + if ((pkt = wg_packet_alloc(m)) == NULL) + return ENOBUFS; + + pkt->p_af = sa->sa_family; + pkt->p_mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : ifp->if_mtu; + m->m_pkthdr.PH_loc.ptr = pkt; + return (wg_transmit(ifp, m)); } @@ -2702,10 +2666,8 @@ wg_up(struct wg_softc *sc) rc = wg_socket_init(sc, sc->sc_socket.so_port); if (rc == 0) { - TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { + TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) wg_timers_enable(&peer->p_timers); - wg_queue_out(peer); - } if_link_state_change(sc->sc_ifp, LINK_STATE_UP); } else { ifp->if_drv_flags &= ~IFF_DRV_RUNNING; @@ -2734,7 +2696,7 @@ wg_down(struct wg_softc *sc) wg_timers_disable(&peer->p_timers); } - mbufq_drain(&sc->sc_handshake_queue); + wg_queue_purge(&sc->sc_handshake_queue); TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { noise_remote_handshake_clear(peer->p_remote); @@ -2804,14 +2766,14 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) atomic_add_int(&clone_count, 1); ifp->if_capabilities = ifp->if_capenable = WG_CAPS; - mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES); + wg_queue_init(&sc->sc_handshake_queue, "hsq"); sx_init(&sc->sc_lock, "wg softc lock"); - sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); - sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); - GROUPTASK_INIT(&sc->sc_handshake, 0, - (gtask_fn_t *)wg_softc_handshake_receive, sc); + GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc); taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation"); crypto_taskq_setup(sc); + wg_queue_init(&sc->sc_encrypt_parallel, "encp"); + 
wg_queue_init(&sc->sc_decrypt_parallel, "decp"); + wg_aip_init(&sc->sc_aips); @@ -2886,8 +2848,9 @@ wg_clone_destroy(struct ifnet *ifp) sx_destroy(&sc->sc_lock); taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake); crypto_taskq_destroy(sc); - buf_ring_free(sc->sc_encap_ring, M_WG); - buf_ring_free(sc->sc_decap_ring, M_WG); + wg_queue_deinit(&sc->sc_handshake_queue); + wg_queue_deinit(&sc->sc_encrypt_parallel); + wg_queue_deinit(&sc->sc_decrypt_parallel); wg_aip_destroy(&sc->sc_aips); @@ -2992,6 +2955,8 @@ wg_module_init(void) [PR_METHOD_REMOVE] = wg_prison_remove, }; + wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet), + NULL, NULL, NULL, NULL, 0, 0); ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit), NULL, NULL, NULL, NULL, 0, 0); wg_osd_jail_slot = osd_jail_register(NULL, methods); @@ -3001,6 +2966,7 @@ static void wg_module_deinit(void) { + uma_zdestroy(wg_packet_zone); uma_zdestroy(ratelimit_zone); osd_jail_deregister(wg_osd_jail_slot); |