diff options
Diffstat (limited to 'src/if_wg.c')
-rw-r--r-- | src/if_wg.c | 181 |
1 files changed, 78 insertions, 103 deletions
diff --git a/src/if_wg.c b/src/if_wg.c index 11b8394..ac824d7 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -265,6 +265,8 @@ struct wg_softc { struct grouptask *sc_decrypt; struct wg_queue sc_encrypt_parallel; struct wg_queue sc_decrypt_parallel; + u_int sc_encrypt_last_cpu; + u_int sc_decrypt_last_cpu; struct sx sc_lock; }; @@ -377,7 +379,7 @@ static void wg_queue_purge(struct wg_queue *); static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *); static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *); static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *); -static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); +static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); static void wg_peer_send_staged(struct wg_peer *); static int wg_clone_create(struct if_clone *, int, caddr_t); static void wg_qflush(struct ifnet *); @@ -407,18 +409,10 @@ wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE]) sx_assert(&sc->sc_lock, SX_XLOCKED); - if ((peer = malloc(sizeof(*peer), M_WG, M_NOWAIT | M_ZERO)) == NULL) - goto free_none; - - if ((peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key)) == NULL) - goto free_peer; - - if ((peer->p_tx_bytes = counter_u64_alloc(M_NOWAIT)) == NULL) - goto free_remote; - - if ((peer->p_rx_bytes = counter_u64_alloc(M_NOWAIT)) == NULL) - goto free_tx_bytes; - + peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO); + peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key); + peer->p_tx_bytes = counter_u64_alloc(M_WAITOK); + peer->p_rx_bytes = counter_u64_alloc(M_WAITOK); peer->p_id = peer_counter++; peer->p_sc = sc; @@ -452,14 +446,6 @@ wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE]) peer->p_aips_num = 0; return (peer); -free_tx_bytes: - counter_u64_free(peer->p_tx_bytes); -free_remote: - noise_remote_free(peer->p_remote, NULL); -free_peer: - free(peer, M_WG); -free_none: - return NULL; } static void @@ -558,8 +544,7 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void struct wg_aip *aip; int i, ret = 0; - if ((aip = malloc(sizeof(*aip), M_WG, M_NOWAIT | M_ZERO)) == NULL) - return (ENOBUFS); + aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO); aip->a_peer = peer; aip->a_af = af; @@ -572,6 +557,7 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void aip->a_addr.ip &= aip->a_mask.ip; aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr); break; +#ifdef INET6 case AF_INET6: if (cidr > 128) cidr = 128; root = sc->sc_aip6; @@ -581,6 +567,7 @@ wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void aip->a_addr.ip6[i] &= aip->a_mask.ip6[i]; aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr); break; +#endif default: free(aip, M_WG); return (EAFNOSUPPORT); @@ -706,7 +693,7 @@ wg_socket_init(struct wg_softc *sc, in_port_t port) if (rc) goto out; - rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc); + rc = udp_set_kernel_tunneling(so4, (udp_tun_func_t)wg_input, NULL, sc); /* * udp_set_kernel_tunneling can only fail if there is already a tunneling function set. * This should never happen with a new socket. @@ -717,7 +704,7 @@ wg_socket_init(struct wg_softc *sc, in_port_t port) rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, td); if (rc) goto out; - rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc); + rc = udp_set_kernel_tunneling(so6, (udp_tun_func_t)wg_input, NULL, sc); MPASS(rc == 0); #endif @@ -891,13 +878,13 @@ wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m) if (e->e_local.l_in.s_addr != INADDR_ANY) control = sbcreatecontrol((caddr_t)&e->e_local.l_in, sizeof(struct in_addr), IP_SENDSRCADDR, - IPPROTO_IP); + IPPROTO_IP, M_NOWAIT); #ifdef INET6 } else if (e->e_remote.r_sa.sa_family == AF_INET6) { if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6)) control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6, sizeof(struct in6_pktinfo), IPV6_PKTINFO, - IPPROTO_IPV6); + IPPROTO_IPV6, M_NOWAIT); #endif } else { m_freem(m); @@ -1476,10 +1463,8 @@ wg_mbuf_reset(struct mbuf *m) m_tag_delete(m, t); } - if (m->m_pkthdr.csum_flags & CSUM_SND_TAG) { - m_snd_tag_rele(m->m_pkthdr.snd_tag); - m->m_pkthdr.snd_tag = NULL; - } + KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0, + ("%s: mbuf %p has a send tag", __func__, m)); m->m_pkthdr.csum_flags = 0; m->m_pkthdr.PH_per.sixtyfour[0] = 0; @@ -1639,21 +1624,22 @@ wg_softc_encrypt(struct wg_softc *sc) static void wg_encrypt_dispatch(struct wg_softc *sc) { - for (int i = 0; i < mp_ncpus; i++) { - if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED) - continue; - GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]); - } + /* + * The update to encrypt_last_cpu is racey such that we may + * reschedule the task for the same CPU multiple times, but + * the race doesn't really matter. + */ + u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus; + sc->sc_encrypt_last_cpu = cpu; + GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]); } static void wg_decrypt_dispatch(struct wg_softc *sc) { - for (int i = 0; i < mp_ncpus; i++) { - if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED) - continue; - GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]); - } + u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus; + sc->sc_decrypt_last_cpu = cpu; + GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]); } static void @@ -1676,10 +1662,10 @@ wg_deliver_out(struct wg_peer *peer) len = m->m_pkthdr.len; + wg_timers_event_any_authenticated_packet_traversal(peer); + wg_timers_event_any_authenticated_packet_sent(peer); rc = wg_send(sc, &endpoint, m); if (rc == 0) { - wg_timers_event_any_authenticated_packet_traversal(peer); - wg_timers_event_any_authenticated_packet_sent(peer); if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN)) wg_timers_event_data_sent(peer); counter_u64_add(peer->p_tx_bytes, len); @@ -1802,11 +1788,7 @@ wg_queue_deinit(struct wg_queue *queue) static size_t wg_queue_len(struct wg_queue *queue) { - size_t len; - mtx_lock(&queue->q_mtx); - len = queue->q_len; - mtx_unlock(&queue->q_mtx); - return (len); + return (queue->q_len); } static int @@ -1869,9 +1851,9 @@ wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list) static void wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list) { + STAILQ_INIT(list); mtx_lock(&staged->q_mtx); - *list = staged->q_queue; - STAILQ_INIT(&staged->q_queue); + STAILQ_CONCAT(list, &staged->q_queue); staged->q_len = 0; mtx_unlock(&staged->q_mtx); } @@ -1944,7 +1926,7 @@ wg_queue_dequeue_parallel(struct wg_queue *parallel) return (pkt); } -static void +static bool wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, const struct sockaddr *sa, void *_sc) { @@ -1963,7 +1945,7 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, m = m_unshare(m, M_NOWAIT); if (!m) { if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); - return; + return true; } /* Caller provided us with `sa`, no need for this header. */ @@ -1972,13 +1954,13 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, /* Pullup enough to read packet type */ if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) { if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); - return; + return true; } if ((pkt = wg_packet_alloc(m)) == NULL) { if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); m_freem(m); - return; + return true; } /* Save send/recv address and port for later. */ @@ -2025,11 +2007,11 @@ wg_input(struct mbuf *m, int offset, struct inpcb *inpcb, } else { goto error; } - return; + return true; error: if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); wg_packet_free(pkt); - return; + return true; } static void @@ -2280,10 +2262,7 @@ wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) wg_aip_remove_all(sc, peer); } if (peer == NULL) { - if ((peer = wg_peer_alloc(sc, pub_key)) == NULL) { - err = ENOMEM; - goto out; - } + peer = wg_peer_alloc(sc, pub_key); need_insert = true; } if (nvlist_exists_binary(nvl, "endpoint")) { @@ -2380,10 +2359,8 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) if (wgd->wgd_size >= UINT32_MAX / 2) return (E2BIG); - if ((nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_NOWAIT | M_ZERO)) == NULL) - return (ENOMEM); + nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO); - sx_xlock(&sc->sc_lock); err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size); if (err) goto out; @@ -2392,6 +2369,7 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) err = EBADMSG; goto out; } + sx_xlock(&sc->sc_lock); if (nvlist_exists_bool(nvl, "replace-peers") && nvlist_get_bool(nvl, "replace-peers")) wg_peer_destroy_all(sc); @@ -2399,12 +2377,12 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) uint64_t new_port = nvlist_get_number(nvl, "listen-port"); if (new_port > UINT16_MAX) { err = EINVAL; - goto out; + goto out_locked; } if (new_port != sc->sc_socket.so_port) { if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) { if ((err = wg_socket_init(sc, new_port)) != 0) - goto out; + goto out_locked; } else sc->sc_socket.so_port = new_port; } @@ -2413,7 +2391,7 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) const void *key = nvlist_get_binary(nvl, "private-key", &size); if (size != WG_KEY_SIZE) { err = EINVAL; - goto out; + goto out_locked; } if (noise_local_keys(sc->sc_local, NULL, private) != 0 || @@ -2447,11 +2425,11 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie"); if (user_cookie > UINT32_MAX) { err = EINVAL; - goto out; + goto out_locked; } err = wg_socket_set_cookie(sc, user_cookie); if (err) - goto out; + goto out_locked; } if (nvlist_exists_nvlist_array(nvl, "peers")) { size_t peercount; @@ -2461,15 +2439,16 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) for (int i = 0; i < peercount; i++) { err = wg_peer_add(sc, nvl_peers[i]); if (err != 0) - goto out; + goto out_locked; } } +out_locked: + sx_xunlock(&sc->sc_lock); nvlist_destroy(nvl); out: explicit_bzero(nvlpacked, wgd->wgd_size); free(nvlpacked, M_TEMP); - sx_xunlock(&sc->sc_lock); return (err); } @@ -2505,10 +2484,7 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) } peer_count = sc->sc_peers_num; if (peer_count) { - if ((nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_NOWAIT | M_ZERO)) == NULL) { - err = ENOMEM; - goto err; - } + nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); i = 0; TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { if (i >= peer_count) @@ -2537,10 +2513,7 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) aip_count = peer->p_aips_num; if (aip_count) { - if ((nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_NOWAIT | M_ZERO)) == NULL) { - err = ENOMEM; - goto err_peer; - } + nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); j = 0; LIST_FOREACH(aip, &peer->p_aips, a_entry) { if (j >= aip_count) @@ -2554,10 +2527,13 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) if (aip->a_af == AF_INET) { nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in)); nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip)); - } else if (aip->a_af == AF_INET6) { + } +#ifdef INET6 + else if (aip->a_af == AF_INET6) { nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6)); nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL)); } +#endif } nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count); err_aip: @@ -2573,9 +2549,12 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) for (i = 0; i < peer_count; ++i) nvlist_destroy(nvl_peers[i]); free(nvl_peers, M_NVLIST); - if (err) + if (err) { + sx_sunlock(&sc->sc_lock); goto err; + } } + sx_sunlock(&sc->sc_lock); packed = nvlist_pack(nvl, &size); if (!packed) { err = ENOMEM; @@ -2589,10 +2568,6 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) err = ENOSPC; goto out; } - if (!wgd->wgd_data) { - err = EFAULT; - goto out; - } err = copyout(packed, wgd->wgd_data, size); wgd->wgd_size = size; @@ -2601,7 +2576,6 @@ out: free(packed, M_NVLIST); err: nvlist_destroy(nvl); - sx_sunlock(&sc->sc_lock); return (err); } @@ -2743,17 +2717,13 @@ wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) struct wg_softc *sc; struct ifnet *ifp; - if ((sc = malloc(sizeof(*sc), M_WG, M_NOWAIT | M_ZERO)) == NULL) - goto free_none; + sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO); - if ((sc->sc_local = noise_local_alloc(sc)) == NULL) - goto free_sc; + sc->sc_local = noise_local_alloc(sc); - if ((sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_NOWAIT | M_ZERO)) == NULL) - goto free_local; + sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO); - if ((sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_NOWAIT | M_ZERO)) == NULL) - goto free_encrypt; + sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO); if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY)) goto free_decrypt; @@ -2821,13 +2791,9 @@ free_aip4: free(sc->sc_aip4, M_RTABLE); free_decrypt: free(sc->sc_decrypt, M_WG); -free_encrypt: free(sc->sc_encrypt, M_WG); -free_local: noise_local_free(sc->sc_local, NULL); -free_sc: free(sc, M_WG); -free_none: return (ENOMEM); } @@ -3014,17 +2980,25 @@ wg_module_init(void) if ((wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet), NULL, NULL, NULL, NULL, 0, 0)) == NULL) goto free_none; - if (cookie_init() != 0) + ret = crypto_init(); + if (ret != 0) goto free_zone; + if (cookie_init() != 0) + goto free_crypto; wg_osd_jail_slot = osd_jail_register(NULL, methods); ret = ENOTRECOVERABLE; if (!wg_run_selftests()) - goto free_zone; + goto free_all; return (0); +free_all: + osd_jail_deregister(wg_osd_jail_slot); + cookie_deinit(); +free_crypto: + crypto_deinit(); free_zone: uma_zdestroy(wg_packet_zone); free_none: @@ -3038,16 +3012,17 @@ wg_module_deinit(void) VNET_LIST_RLOCK(); VNET_FOREACH(vnet_iter) { struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner); - if (!clone) - continue; - if_clone_detach(clone); - VNET_VNET(vnet_iter, wg_cloner) = NULL; + if (clone) { + if_clone_detach(clone); + VNET_VNET(vnet_iter, wg_cloner) = NULL; + } } VNET_LIST_RUNLOCK(); NET_EPOCH_WAIT(); MPASS(LIST_EMPTY(&wg_list)); osd_jail_deregister(wg_osd_jail_slot); cookie_deinit(); + crypto_deinit(); uma_zdestroy(wg_packet_zone); } |