From 9dedf50dd0929207b5239488caaea0403089effe Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 4 May 2020 16:09:47 -0600 Subject: wireguard: socket: do not hold locks while transmitting packets Before, we followed this pattern for using the udp_tunnel api: rcu_read_lock_bh(); sock = rcu_dereference(obj->sock); ... udp_tunnel_xmit_skb(..., sock, ...); rcu_read_unlock_bh(); This commit changes that to use a reference counter instead: rcu_read_lock_bh(); sock = rcu_dereference(obj->sock); sock_hold(sock); rcu_read_unlock_bh(); ... udp_tunnel_xmit_skb(..., sock, ...); sock_put(sock); The advantage of the latter approach is that we now no longer hold any locks while udp_tunnel_xmit_skb runs, since it could be somewhat slow on systems with advanced qdisc or netfilter configurations. This should avoid potential RCU stalls in those situations. This commit makes sure we're holding neither the rcu read lock nor the endpoint read lock when udp_tunnel_xmit_skb is called. Fixes: e7096c131e51 ("net: WireGuard secure network tunnel") Signed-off-by: Jason A. Donenfeld --- drivers/net/wireguard/socket.c | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c index f9018027fc13..62716cca4c8a 100644 --- a/drivers/net/wireguard/socket.c +++ b/drivers/net/wireguard/socket.c @@ -18,7 +18,8 @@ #include static int send4(struct wg_device *wg, struct sk_buff *skb, - struct endpoint *endpoint, u8 ds, struct dst_cache *cache) + struct endpoint *endpoint, u8 ds, struct dst_cache *cache, + rwlock_t *endpoint_lock) { struct flowi4 fl = { .saddr = endpoint->src4.s_addr, @@ -37,6 +38,9 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock4); + if (likely(sock)) + sock_hold(sock); + rcu_read_unlock_bh(); if (unlikely(!sock)) { ret = -ENONET; @@ -80,6 +84,8 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, if (cache) dst_cache_set_ip4(cache, &rt->dst, fl.saddr); } + if (endpoint_lock) + read_unlock_bh(endpoint_lock); skb->ignore_df = 1; udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, @@ -88,14 +94,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb, goto out; err: + if (endpoint_lock) + read_unlock_bh(endpoint_lock); kfree_skb(skb); out: - rcu_read_unlock_bh(); + if (likely(sock)) + sock_put(sock); return ret; } static int send6(struct wg_device *wg, struct sk_buff *skb, - struct endpoint *endpoint, u8 ds, struct dst_cache *cache) + struct endpoint *endpoint, u8 ds, struct dst_cache *cache, + rwlock_t *endpoint_lock) { #if IS_ENABLED(CONFIG_IPV6) struct flowi6 fl = { @@ -117,6 +127,9 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, rcu_read_lock_bh(); sock = rcu_dereference_bh(wg->sock6); + if (likely(sock)) + sock_hold(sock); + rcu_read_unlock_bh(); if (unlikely(!sock)) { ret = -ENONET; @@ -147,6 +160,8 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, if (cache) dst_cache_set_ip6(cache, dst, &fl.saddr); } + if (endpoint_lock) + read_unlock_bh(endpoint_lock); skb->ignore_df = 1; udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, @@ -155,9 +170,12 @@ static int send6(struct wg_device *wg, struct sk_buff *skb, goto out; err: + if (endpoint_lock) + read_unlock_bh(endpoint_lock); kfree_skb(skb); out: - rcu_read_unlock_bh(); + if (likely(sock)) + sock_put(sock); return ret; #else return -EAFNOSUPPORT; @@ -169,18 +187,17 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds) size_t skb_len = skb->len; int ret = -EAFNOSUPPORT; - read_lock_bh(&peer->endpoint_lock); + read_lock_bh(&peer->endpoint_lock); /* Unlocked by send4/send6 */ if (peer->endpoint.addr.sa_family == AF_INET) ret = send4(peer->device, skb, &peer->endpoint, ds, - &peer->endpoint_cache); + &peer->endpoint_cache, &peer->endpoint_lock); else if (peer->endpoint.addr.sa_family == AF_INET6) ret = send6(peer->device, skb, &peer->endpoint, ds, - &peer->endpoint_cache); + &peer->endpoint_cache, &peer->endpoint_lock); else dev_kfree_skb(skb); if (likely(!ret)) peer->tx_bytes += skb_len; - read_unlock_bh(&peer->endpoint_lock); return ret; } @@ -221,9 +238,9 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg, skb_put_data(skb, buffer, len); if (endpoint.addr.sa_family == AF_INET) - ret = send4(wg, skb, &endpoint, 0, NULL); + ret = send4(wg, skb, &endpoint, 0, NULL, NULL); else if (endpoint.addr.sa_family == AF_INET6) - ret = send6(wg, skb, &endpoint, 0, NULL); + ret = send6(wg, skb, &endpoint, 0, NULL, NULL); /* No other possibilities if the endpoint is valid, which it is, * as we checked above. */ -- cgit v1.2.3-59-g8ed1b