author    Jason A. Donenfeld <Jason@zx2c4.com>  2020-05-04 16:09:47 -0600
committer Jason A. Donenfeld <Jason@zx2c4.com>  2020-05-04 18:04:01 -0600
commit    9dedf50dd0929207b5239488caaea0403089effe (patch)
tree      679ac06dd5f8fee7a6ed9b47a02ab59d6ee011e8
parent    wireguard: send/receive: cond_resched() when processing worker ringbuffers (diff)
wireguard: socket: do not hold locks while transmitting packets
Before, we followed this pattern for using the udp_tunnel api:

    rcu_read_lock_bh();
    sock = rcu_dereference(obj->sock);
    ...
    udp_tunnel_xmit_skb(..., sock, ...);
    rcu_read_unlock_bh();

This commit changes that to use a reference counter instead:

    rcu_read_lock_bh();
    sock = rcu_dereference(obj->sock);
    sock_hold(sock);
    rcu_read_unlock_bh();
    ...
    udp_tunnel_xmit_skb(..., sock, ...);
    sock_put(sock);

The advantage of the latter approach is that we no longer hold any
locks while udp_tunnel_xmit_skb runs, since it could be somewhat slow
on systems with advanced qdisc or netfilter configurations. This
should avoid potential RCU stalls in those situations. This commit
makes sure we're holding neither the rcu read lock nor the endpoint
read lock when udp_tunnel_xmit_skb is called.

Fixes: e7096c131e51 ("net: WireGuard secure network tunnel")
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--  drivers/net/wireguard/socket.c | 37
1 file changed, 27 insertions(+), 10 deletions(-)
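
For readers skimming the diff below, here is a minimal userspace C
sketch of the hold-then-unlock pattern the commit message describes.
It is an illustration, not kernel code: a pthread mutex stands in for
the RCU read-side critical section, C11 atomics stand in for the
kernel's socket refcount, and the names fake_sock, lookup_lock, and
send_packet are invented for this sketch.

/* Sketch only: models taking a reference so slow work runs unlocked. */
#include <pthread.h>
#include <stdatomic.h>
#include <stdlib.h>
#include <unistd.h>

struct fake_sock {
	atomic_int refcount;	/* models sk->sk_refcnt */
	int fd;			/* payload, unused here */
};

static pthread_mutex_t lookup_lock = PTHREAD_MUTEX_INITIALIZER;
static struct fake_sock *current_sock;	/* models wg->sock4 */

static void sock_hold(struct fake_sock *s)	/* local stand-in */
{
	atomic_fetch_add(&s->refcount, 1);
}

static void sock_put(struct fake_sock *s)	/* local stand-in */
{
	/* Dropping the last reference frees the object. */
	if (atomic_fetch_sub(&s->refcount, 1) == 1)
		free(s);
}

static int send_packet(void)
{
	struct fake_sock *s;

	/* Short critical section: look up the socket and pin it. */
	pthread_mutex_lock(&lookup_lock);
	s = current_sock;
	if (s)
		sock_hold(s);
	pthread_mutex_unlock(&lookup_lock);

	if (!s)
		return -1;	/* mirrors the -ENONET case in send4/send6 */

	/* Slow work proceeds with no lock held, only a reference. */
	usleep(1000);		/* stands in for udp_tunnel_xmit_skb() */
	sock_put(s);
	return 0;
}

int main(void)
{
	current_sock = calloc(1, sizeof(*current_sock));
	if (!current_sock)
		return 1;
	atomic_init(&current_sock->refcount, 1);	/* creator's reference */
	send_packet();
	sock_put(current_sock);	/* drop the creator's reference */
	return 0;
}

The point mirrored from the patch: the lookup lock is held only long
enough to pin the object, so the slow transmit path runs lock-free and
a concurrent teardown simply waits for the final sock_put.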
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index f9018027fc13..62716cca4c8a 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -18,7 +18,8 @@
#include <net/ipv6.h>
static int send4(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+ rwlock_t *endpoint_lock)
{
struct flowi4 fl = {
.saddr = endpoint->src4.s_addr,
@@ -37,6 +38,9 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock4);
+ if (likely(sock))
+ sock_hold(sock);
+ rcu_read_unlock_bh();
if (unlikely(!sock)) {
ret = -ENONET;
@@ -80,6 +84,8 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
if (cache)
dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
}
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
skb->ignore_df = 1;
udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
@@ -88,14 +94,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
goto out;
err:
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
kfree_skb(skb);
out:
- rcu_read_unlock_bh();
+ if (likely(sock))
+ sock_put(sock);
return ret;
}
static int send6(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+ rwlock_t *endpoint_lock)
{
#if IS_ENABLED(CONFIG_IPV6)
struct flowi6 fl = {
@@ -117,6 +127,9 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock6);
+ if (likely(sock))
+ sock_hold(sock);
+ rcu_read_unlock_bh();
if (unlikely(!sock)) {
ret = -ENONET;
@@ -147,6 +160,8 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
if (cache)
dst_cache_set_ip6(cache, dst, &fl.saddr);
}
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
skb->ignore_df = 1;
udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
@@ -155,9 +170,12 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
goto out;
err:
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
kfree_skb(skb);
out:
- rcu_read_unlock_bh();
+ if (likely(sock))
+ sock_put(sock);
return ret;
#else
return -EAFNOSUPPORT;
@@ -169,18 +187,17 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
size_t skb_len = skb->len;
int ret = -EAFNOSUPPORT;
- read_lock_bh(&peer->endpoint_lock);
+ read_lock_bh(&peer->endpoint_lock); /* Unlocked by send4/send6 */
if (peer->endpoint.addr.sa_family == AF_INET)
ret = send4(peer->device, skb, &peer->endpoint, ds,
- &peer->endpoint_cache);
+ &peer->endpoint_cache, &peer->endpoint_lock);
else if (peer->endpoint.addr.sa_family == AF_INET6)
ret = send6(peer->device, skb, &peer->endpoint, ds,
- &peer->endpoint_cache);
+ &peer->endpoint_cache, &peer->endpoint_lock);
else
dev_kfree_skb(skb);
if (likely(!ret))
peer->tx_bytes += skb_len;
- read_unlock_bh(&peer->endpoint_lock);
return ret;
}
@@ -221,9 +238,9 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
skb_put_data(skb, buffer, len);
if (endpoint.addr.sa_family == AF_INET)
- ret = send4(wg, skb, &endpoint, 0, NULL);
+ ret = send4(wg, skb, &endpoint, 0, NULL, NULL);
else if (endpoint.addr.sa_family == AF_INET6)
- ret = send6(wg, skb, &endpoint, 0, NULL);
+ ret = send6(wg, skb, &endpoint, 0, NULL, NULL);
/* No other possibilities if the endpoint is valid, which it is,
* as we checked above.
*/