aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/drivers/net/wireguard/socket.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/net/wireguard/socket.c')
-rw-r--r--drivers/net/wireguard/socket.c37
1 files changed, 27 insertions, 10 deletions
diff --git a/drivers/net/wireguard/socket.c b/drivers/net/wireguard/socket.c
index f9018027fc13..62716cca4c8a 100644
--- a/drivers/net/wireguard/socket.c
+++ b/drivers/net/wireguard/socket.c
@@ -18,7 +18,8 @@
#include <net/ipv6.h>
static int send4(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+ rwlock_t *endpoint_lock)
{
struct flowi4 fl = {
.saddr = endpoint->src4.s_addr,
@@ -37,6 +38,9 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock4);
+ if (likely(sock))
+ sock_hold(sock);
+ rcu_read_unlock_bh();
if (unlikely(!sock)) {
ret = -ENONET;
@@ -80,6 +84,8 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
if (cache)
dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
}
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
skb->ignore_df = 1;
udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
@@ -88,14 +94,18 @@ static int send4(struct wg_device *wg, struct sk_buff *skb,
goto out;
err:
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
kfree_skb(skb);
out:
- rcu_read_unlock_bh();
+ if (likely(sock))
+ sock_put(sock);
return ret;
}
static int send6(struct wg_device *wg, struct sk_buff *skb,
- struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+ struct endpoint *endpoint, u8 ds, struct dst_cache *cache,
+ rwlock_t *endpoint_lock)
{
#if IS_ENABLED(CONFIG_IPV6)
struct flowi6 fl = {
@@ -117,6 +127,9 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
rcu_read_lock_bh();
sock = rcu_dereference_bh(wg->sock6);
+ if (likely(sock))
+ sock_hold(sock);
+ rcu_read_unlock_bh();
if (unlikely(!sock)) {
ret = -ENONET;
@@ -147,6 +160,8 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
if (cache)
dst_cache_set_ip6(cache, dst, &fl.saddr);
}
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
skb->ignore_df = 1;
udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
@@ -155,9 +170,12 @@ static int send6(struct wg_device *wg, struct sk_buff *skb,
goto out;
err:
+ if (endpoint_lock)
+ read_unlock_bh(endpoint_lock);
kfree_skb(skb);
out:
- rcu_read_unlock_bh();
+ if (likely(sock))
+ sock_put(sock);
return ret;
#else
return -EAFNOSUPPORT;
@@ -169,18 +187,17 @@ int wg_socket_send_skb_to_peer(struct wg_peer *peer, struct sk_buff *skb, u8 ds)
size_t skb_len = skb->len;
int ret = -EAFNOSUPPORT;
- read_lock_bh(&peer->endpoint_lock);
+ read_lock_bh(&peer->endpoint_lock); /* Unlocked by send4/send6 */
if (peer->endpoint.addr.sa_family == AF_INET)
ret = send4(peer->device, skb, &peer->endpoint, ds,
- &peer->endpoint_cache);
+ &peer->endpoint_cache, &peer->endpoint_lock);
else if (peer->endpoint.addr.sa_family == AF_INET6)
ret = send6(peer->device, skb, &peer->endpoint, ds,
- &peer->endpoint_cache);
+ &peer->endpoint_cache, &peer->endpoint_lock);
else
dev_kfree_skb(skb);
if (likely(!ret))
peer->tx_bytes += skb_len;
- read_unlock_bh(&peer->endpoint_lock);
return ret;
}
@@ -221,9 +238,9 @@ int wg_socket_send_buffer_as_reply_to_skb(struct wg_device *wg,
skb_put_data(skb, buffer, len);
if (endpoint.addr.sa_family == AF_INET)
- ret = send4(wg, skb, &endpoint, 0, NULL);
+ ret = send4(wg, skb, &endpoint, 0, NULL, NULL);
else if (endpoint.addr.sa_family == AF_INET6)
- ret = send6(wg, skb, &endpoint, 0, NULL);
+ ret = send6(wg, skb, &endpoint, 0, NULL, NULL);
/* No other possibilities if the endpoint is valid, which it is,
* as we checked above.
*/