From 8c2797b04a058eaa7d7c64e6b263b8b720a24547 Mon Sep 17 00:00:00 2001 From: John Baldwin Date: Wed, 10 Nov 2021 16:02:11 -0800 Subject: if_wg: wgc_get/set: use M_WAITOK with malloc() This reduces the edge cases which need handling, and M_WAITOK is safe to use in this context. While here, narrow the scope of the sc_lock to the code that interacts with the softc, but not copyin/copyout, malloc, and nvlist_pack calls before and after interacting with the softc. Signed-off-by: John Baldwin --- src/if_wg.c | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/src/if_wg.c b/src/if_wg.c index 6f53a38..bf4d070 100644 --- a/src/if_wg.c +++ b/src/if_wg.c @@ -2378,10 +2378,8 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) if (wgd->wgd_size >= UINT32_MAX / 2) return (E2BIG); - if ((nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_NOWAIT | M_ZERO)) == NULL) - return (ENOMEM); + nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO); - sx_xlock(&sc->sc_lock); err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size); if (err) goto out; @@ -2390,6 +2388,7 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) err = EBADMSG; goto out; } + sx_xlock(&sc->sc_lock); if (nvlist_exists_bool(nvl, "replace-peers") && nvlist_get_bool(nvl, "replace-peers")) wg_peer_destroy_all(sc); @@ -2397,12 +2396,12 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) uint64_t new_port = nvlist_get_number(nvl, "listen-port"); if (new_port > UINT16_MAX) { err = EINVAL; - goto out; + goto out_locked; } if (new_port != sc->sc_socket.so_port) { if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) { if ((err = wg_socket_init(sc, new_port)) != 0) - goto out; + goto out_locked; } else sc->sc_socket.so_port = new_port; } @@ -2411,7 +2410,7 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) const void *key = nvlist_get_binary(nvl, "private-key", &size); if (size != WG_KEY_SIZE) { err = EINVAL; - goto out; + goto out_locked; } if (noise_local_keys(sc->sc_local, NULL, private) != 0 || @@ -2445,11 +2444,11 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie"); if (user_cookie > UINT32_MAX) { err = EINVAL; - goto out; + goto out_locked; } err = wg_socket_set_cookie(sc, user_cookie); if (err) - goto out; + goto out_locked; } if (nvlist_exists_nvlist_array(nvl, "peers")) { size_t peercount; @@ -2459,15 +2458,16 @@ wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) for (int i = 0; i < peercount; i++) { err = wg_peer_add(sc, nvl_peers[i]); if (err != 0) - goto out; + goto out_locked; } } +out_locked: + sx_xunlock(&sc->sc_lock); nvlist_destroy(nvl); out: explicit_bzero(nvlpacked, wgd->wgd_size); free(nvlpacked, M_TEMP); - sx_xunlock(&sc->sc_lock); return (err); } @@ -2503,10 +2503,7 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) } peer_count = sc->sc_peers_num; if (peer_count) { - if ((nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_NOWAIT | M_ZERO)) == NULL) { - err = ENOMEM; - goto err; - } + nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); i = 0; TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) { if (i >= peer_count) @@ -2535,10 +2532,7 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) aip_count = peer->p_aips_num; if (aip_count) { - if ((nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_NOWAIT | M_ZERO)) == NULL) { - err = ENOMEM; - goto err_peer; - } + nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO); j = 0; LIST_FOREACH(aip, &peer->p_aips, a_entry) { if (j >= aip_count) @@ -2574,9 +2568,12 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) for (i = 0; i < peer_count; ++i) nvlist_destroy(nvl_peers[i]); free(nvl_peers, M_NVLIST); - if (err) + if (err) { + sx_sunlock(&sc->sc_lock); goto err; + } } + sx_sunlock(&sc->sc_lock); packed = nvlist_pack(nvl, &size); if (!packed) { err = ENOMEM; @@ -2590,10 +2587,6 @@ wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) err = ENOSPC; goto out; } - if (!wgd->wgd_data) { - err = EFAULT; - goto out; - } err = copyout(packed, wgd->wgd_data, size); wgd->wgd_size = size; @@ -2602,7 +2595,6 @@ out: free(packed, M_NVLIST); err: nvlist_destroy(nvl); - sx_sunlock(&sc->sc_lock); return (err); } -- cgit v1.2.3-59-g8ed1b