From 1dd2b9f574ff2a45e6c39f9b4cd7094d2716d8c8 Mon Sep 17 00:00:00 2001 From: Matt Dunwoodie Date: Thu, 18 Mar 2021 16:23:42 +1100 Subject: Rework encap/decap routines This will make further work on in place decryption a lot easier. Additionally, it improves the readability as we can get rid of the difficult _len variables. The copy in and out of wg_pkt_data is also a cleaner solution than memcpy nonces and whatnot. --- sys/net/if_wg.c | 171 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 84 insertions(+), 87 deletions(-) diff --git a/sys/net/if_wg.c b/sys/net/if_wg.c index e34f5409d12..0d60b37b6b8 100644 --- a/sys/net/if_wg.c +++ b/sys/net/if_wg.c @@ -83,7 +83,7 @@ #define WG_PKT_COOKIE htole32(3) #define WG_PKT_DATA htole32(4) -#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1))) +#define WG_PKT_ALIGNMENT 16 #define WG_KEY_SIZE WG_KEY_LEN struct wg_pkt_initiation { @@ -114,7 +114,7 @@ struct wg_pkt_cookie { struct wg_pkt_data { uint32_t t; uint32_t r_idx; - uint8_t nonce[sizeof(uint64_t)]; + uint64_t nonce; uint8_t buf[]; }; @@ -1346,6 +1346,9 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt) m = pkt->p_mbuf; e = &pkt->p_endpoint; + if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) + goto error; + switch (*mtod(m, uint32_t *)) { case WG_PKT_INITIATION: init = mtod(m, struct wg_pkt_initiation *); @@ -1480,53 +1483,42 @@ wg_handshake_worker(void *_sc) void wg_encap(struct wg_softc *sc, struct wg_packet *pkt) { - struct wg_pkt_data *data; + struct wg_pkt_data data; struct wg_peer *peer; - struct mbuf *m, *mc; - size_t padding_len, plaintext_len, out_len; - uint64_t nonce; - int res; + struct mbuf *m, *ms; + int res, pad, off, len; peer = pkt->p_peer; m = pkt->p_mbuf; - plaintext_len = min(WG_PKT_WITH_PADDING(m->m_pkthdr.len), pkt->p_mtu); - padding_len = plaintext_len - m->m_pkthdr.len; - out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN; + /* Calculate what padding we need to add then limit it to the mtu of + * the interface. This is done to ensure we don't "over pad" a packet + * that is just under the MTU. */ + pad = (-m->m_pkthdr.len) & (WG_PKT_ALIGNMENT - 1); + if (m->m_pkthdr.len + pad > pkt->p_mtu) + pad = pkt->p_mtu - m->m_pkthdr.len; + + /* Pad the packet */ + if (pad != 0) { + if ((ms = m_makespace(m, m->m_pkthdr.len, pad, &off)) == NULL) + goto error_free; + bzero(mtod(ms, uint8_t *) + off, pad); + } - /* - * For the time being we allocate a new packet with sufficient size to - * hold the encrypted data and headers. It would be difficult to - * overcome as p_encap_queue (mbuf_list) holds a reference to the mbuf. - * If we m_makespace or similar, we risk corrupting that list. - * Additionally, we only pass a buf and buf length to - * noise_remote_encrypt. Technically it would be possible to teach - * noise_remote_encrypt about mbufs, but we would need to sort out the - * p_encap_queue situation first. - */ - if ((mc = m_clget(NULL, M_NOWAIT, out_len)) == NULL) + /* TODO teach noise_remote_encrypt about mbufs. Currently we have to + * resort to m_pullup to create an encryptable buffer. */ + len = m->m_pkthdr.len; + if (m_makespace(m, len, NOISE_AUTHTAG_LEN, &off) == NULL) + goto error_free; + if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) goto error; - data = mtod(mc, struct wg_pkt_data *); - m_copydata(m, 0, m->m_pkthdr.len, data->buf); - bzero(data->buf + m->m_pkthdr.len, padding_len); - data->t = WG_PKT_DATA; - - /* - * Copy the flow hash from the inner packet to the outer packet, so - * that fq_codel can property separate streams, rather than falling - * back to random buckets. - */ - mc->m_pkthdr.ph_flowid = m->m_pkthdr.ph_flowid; - - res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce, - data->buf, plaintext_len); - nonce = htole64(nonce); /* Wire format is little endian. */ - memcpy(data->nonce, &nonce, sizeof(data->nonce)); + /* Do encryption */ + res = noise_remote_encrypt(&peer->p_remote, &data.r_idx, &data.nonce, + mtod(m, uint8_t *), len); if (__predict_false(res == EINVAL)) { - m_freem(mc); - goto error; + goto error_free; } else if (__predict_false(res == ESTALE)) { wg_timers_event_want_initiation(&peer->p_timers); } else if (__predict_false(res != 0)) { @@ -1534,14 +1526,17 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) } /* A packet with length 0 is a keepalive packet */ - if (__predict_false(m->m_pkthdr.len == 0)) + if (__predict_false(len == 0)) DPRINTF(sc, "Sending keepalive packet to peer %llu\n", peer->p_id); - mc->m_pkthdr.ph_loopcnt = m->m_pkthdr.ph_loopcnt; - mc->m_flags &= ~(M_MCAST | M_BCAST); - mc->m_len = out_len; - m_calchdrlen(mc); + /* Put header into packet */ + if ((m = m_prepend(m, sizeof(struct wg_pkt_data), M_NOWAIT)) == NULL) + goto error; + + data.t = WG_PKT_DATA; + data.nonce = htole64(data.nonce); + memcpy(mtod(m, void *), &data, sizeof(struct wg_pkt_data)); /* * We would count ifc_opackets, ifc_obytes of m here, except if_snd @@ -1549,16 +1544,18 @@ wg_encap(struct wg_softc *sc, struct wg_packet *pkt) counters_pkt(sc->sc_if.if_counters, ifc_opackets, ifc_obytes, m->m_pkthdr.len); */ - wg_peer_counters_add(peer, mc->m_pkthdr.len, 0); - - /* TODO this is temporary, and will be replaced with proper mbuf handling */ - m_freem(m); - pkt->p_mbuf = mc; + wg_peer_counters_add(peer, m->m_pkthdr.len, 0); + m->m_flags &= ~(M_MCAST | M_BCAST); + pkt->p_mbuf = m; pkt->p_state = WG_PACKET_CRYPTED; task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_out); return; + +error_free: + m_freem(m); error: + pkt->p_mbuf = NULL; pkt->p_state = WG_PACKET_DEAD; task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_out); } @@ -1566,35 +1563,32 @@ error: void wg_decap(struct wg_softc *sc, struct wg_packet *pkt) { - struct ip *ip; - struct ip6_hdr *ip6; - struct wg_pkt_data *data; + struct wg_pkt_data data; struct wg_peer *peer, *allowed_peer; struct mbuf *m; - size_t payload_len; - uint64_t nonce; + struct ip *ip; + struct ip6_hdr *ip6; int res, len; peer = pkt->p_peer; m = pkt->p_mbuf; + len = m->m_pkthdr.len; - /* - * Likewise to wg_encap, we pass a buf and buf length to - * noise_remote_decrypt. Again, possible to teach it about mbufs - * but need to get over the p_decap_queue situation first. However, - * we do not need to allocate a new mbuf as the decrypted packet is - * strictly smaller than encrypted. We just set t_mbuf to m and - * wg_deliver_in knows how to deal with that. - */ - data = mtod(m, struct wg_pkt_data *); - payload_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data); - memcpy(&nonce, data->nonce, sizeof(nonce)); - nonce = le64toh(nonce); /* Wire format is little endian. */ - res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce, - data->buf, payload_len); + /* Read index, nonce and then adjust to remove the header. */ + memcpy(&data, mtod(m, void *), sizeof(struct wg_pkt_data)); + m_adj(m, sizeof(struct wg_pkt_data)); - if (__predict_false(res == EINVAL)) { + /* TODO teach noise_remote_decrypt about mbufs. Currently we have to + * resort to m_pullup to create an decryptable buffer. */ + if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) { goto error; + } + + res = noise_remote_decrypt(&peer->p_remote, data.r_idx, + le64toh(data.nonce), mtod(m, void *), m->m_pkthdr.len); + + if (__predict_false(res == EINVAL)) { + goto error_free; } else if (__predict_false(res == ECONNRESET)) { wg_timers_event_handshake_complete(&peer->p_timers); } else if (__predict_false(res == ESTALE)) { @@ -1603,13 +1597,11 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) panic("unexpected response: %d\n", res); } - wg_peer_set_endpoint(peer, &pkt->p_endpoint); - - wg_peer_counters_add(peer, 0, m->m_pkthdr.len); - - m_adj(m, sizeof(struct wg_pkt_data)); m_adj(m, -NOISE_AUTHTAG_LEN); + wg_peer_set_endpoint(peer, &pkt->p_endpoint); + wg_peer_counters_add(peer, 0, len); + counters_pkt(sc->sc_if.if_counters, ifc_ipackets, ifc_ibytes, m->m_pkthdr.len); @@ -1656,13 +1648,13 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) } else { DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from " "peer %llu\n", peer->p_id); - goto error; + goto error_free; } if (__predict_false(peer != allowed_peer)) { DPRINTF(sc, "Packet has unallowed src IP from peer " "%llu\n", peer->p_id); - goto error; + goto error_free; } /* tunneled packet was not offloaded */ @@ -1676,10 +1668,14 @@ wg_decap(struct wg_softc *sc, struct wg_packet *pkt) #endif /* NPF > 0 */ done: + pkt->p_mbuf = m; pkt->p_state = WG_PACKET_CRYPTED; task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_in); return; +error_free: + m_freem(m); error: + pkt->p_mbuf = NULL; pkt->p_state = WG_PACKET_DEAD; task_add(net_tq(sc->sc_if.if_index), &peer->p_deliver_in); } @@ -1993,6 +1989,9 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, return NULL; } + /* Save a copy of the src/dst address to update the peer's endpoint. We + * only want to update it if we validate the packet cryptographically + * so that occurs later in `wg_peer_set_endpoint`. */ if (ip != NULL) { pkt->p_endpoint.e_remote.r_sa.sa_len = sizeof(struct sockaddr_in); pkt->p_endpoint.e_remote.r_sa.sa_family = AF_INET; @@ -2015,18 +2014,10 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, /* m has a IP/IPv6 header of hlen length, we don't need it anymore. */ m_adj(m, hlen); - /* TODO rework to not do a pullup(pkthdr.len) */ - /* - * Ensure mbuf is contiguous over full length of packet. This is done - * os we can directly read the handshake values in wg_handshake, and so - * we can decrypt a transport packet by passing a single buffer to - * noise_remote_decrypt in wg_decap. - */ - if ((m = m_pullup(m, m->m_pkthdr.len)) == NULL) { - counters_inc(sc->sc_if.if_counters, ifc_ierrors); + if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) { + counters_inc(sc->sc_if.if_counters, ifc_iqdrops); goto error_packet; } - pkt->p_mbuf = m; if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) && *mtod(m, uint32_t *) == WG_PKT_INITIATION) || @@ -2043,11 +2034,17 @@ wg_input(void *_sc, struct mbuf *m, struct ip *ip, struct ip6_hdr *ip6, goto error_mbuf; } task_add(wg_handshake_taskq, &sc->sc_handshake); + } else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) { - data = mtod(m, struct wg_pkt_data *); + /* Pullup whole header to read r_idx below. */ + if ((m = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL) { + counters_inc(sc->sc_if.if_counters, ifc_iqdrops); + goto error_packet; + } + data = mtod(m, struct wg_pkt_data *); if ((remote = wg_index_get(sc, data->r_idx)) == NULL) goto error_mbuf; -- cgit v1.2.3-59-g8ed1b