aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src/socket.c
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2016-11-10 23:28:35 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2016-11-15 02:42:36 +0100
commit46de6f340ae179741fb475ad6858a7a56a34af4a (patch)
tree132f9cdb49215e98285a70fb0decde84b2b34d81 /src/socket.c
parentcurve25519: use kmalloc in order to not overflow stack (diff)
downloadwireguard-monolithic-historical-46de6f340ae179741fb475ad6858a7a56a34af4a.tar.xz
wireguard-monolithic-historical-46de6f340ae179741fb475ad6858a7a56a34af4a.zip
socket: keep track of src address in sending packets
Diffstat (limited to 'src/socket.c')
-rw-r--r--src/socket.c104
1 files changed, 54 insertions, 50 deletions
diff --git a/src/socket.c b/src/socket.c
index 069d45d..20e82fd 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -13,11 +13,12 @@
#include <net/udp_tunnel.h>
#include <net/ipv6.h>
-static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in *addr, uint8_t ds, struct dst_cache *cache)
+static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, uint8_t ds, struct dst_cache *cache)
{
struct flowi4 fl = {
- .daddr = addr->sin_addr.s_addr,
- .fl4_dport = addr->sin_port,
+ .saddr = endpoint->src4.s_addr,
+ .daddr = endpoint->addr4.sin_addr.s_addr,
+ .fl4_dport = endpoint->addr4.sin_port,
.fl4_sport = htons(wg->incoming_port),
.flowi4_proto = IPPROTO_UDP
};
@@ -44,12 +45,12 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
rt = ip_route_output_flow(sock_net(sock), &fl, sock);
if (unlikely(IS_ERR(rt))) {
ret = PTR_ERR(rt);
- net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret);
+ net_dbg_ratelimited("No route to %pISpfsc, error %d\n", &endpoint->addr_storage, ret);
goto err;
} else if (unlikely(rt->dst.dev == skb->dev)) {
dst_release(&rt->dst);
ret = -ELOOP;
- net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr);
+ net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", &endpoint->addr_storage);
goto err;
}
if (cache)
@@ -70,14 +71,15 @@ out:
return ret;
}
-static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in6 *addr, uint8_t ds, struct dst_cache *cache)
+static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, uint8_t ds, struct dst_cache *cache)
{
#if IS_ENABLED(CONFIG_IPV6)
struct flowi6 fl = {
- .daddr = addr->sin6_addr,
- .fl6_dport = addr->sin6_port,
+ .saddr = endpoint->src6,
+ .daddr = endpoint->addr6.sin6_addr,
+ .fl6_dport = endpoint->addr6.sin6_port,
.fl6_sport = htons(wg->incoming_port),
- .flowi6_oif = addr->sin6_scope_id,
+ .flowi6_oif = endpoint->addr6.sin6_scope_id,
.flowi6_proto = IPPROTO_UDP
/* TODO: addr->sin6_flowinfo */
};
@@ -103,12 +105,12 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
security_sk_classify_flow(sock, flowi6_to_flowi(&fl));
ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl);
if (unlikely(ret)) {
- net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret);
+ net_dbg_ratelimited("No route to %pISpfsc, error %d\n", &endpoint->addr_storage, ret);
goto err;
} else if (unlikely(dst->dev == skb->dev)) {
dst_release(dst);
ret = -ELOOP;
- net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr);
+ net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", &endpoint->addr_storage);
goto err;
}
if (cache)
@@ -138,10 +140,10 @@ int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, ui
int ret = -EAFNOSUPPORT;
read_lock_bh(&peer->endpoint_lock);
- if (peer->endpoint_addr.ss_family == AF_INET)
- ret = send4(peer->device, skb, (struct sockaddr_in *)&peer->endpoint_addr, ds, &peer->endpoint_cache);
- else if (peer->endpoint_addr.ss_family == AF_INET6)
- ret = send6(peer->device, skb, (struct sockaddr_in6 *)&peer->endpoint_addr, ds, &peer->endpoint_cache);
+ if (peer->endpoint.addr_storage.ss_family == AF_INET)
+ ret = send4(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache);
+ else if (peer->endpoint.addr_storage.ss_family == AF_INET6)
+ ret = send6(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache);
if (likely(!ret))
peer->tx_bytes += skb_len;
read_unlock_bh(&peer->endpoint_lock);
@@ -163,11 +165,11 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu
{
int ret = 0;
struct sk_buff *skb;
- struct sockaddr_storage addr = { 0 };
+ struct endpoint endpoint;
if (unlikely(!in_skb))
return -EINVAL;
- ret = socket_addr_from_skb(&addr, in_skb);
+ ret = socket_endpoint_from_skb(&endpoint, in_skb);
if (unlikely(ret < 0))
return ret;
@@ -177,60 +179,62 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu
skb_reserve(skb, SKB_HEADER_LEN);
memcpy(skb_put(skb, len), out_buffer, len);
- if (addr.ss_family == AF_INET)
- ret = send4(wg, skb, (struct sockaddr_in *)&addr, 0, NULL);
- else if (addr.ss_family == AF_INET6)
- ret = send6(wg, skb, (struct sockaddr_in6 *)&addr, 0, NULL);
+ if (endpoint.addr_storage.ss_family == AF_INET)
+ ret = send4(wg, skb, &endpoint, 0, NULL);
+ else if (endpoint.addr_storage.ss_family == AF_INET6)
+ ret = send6(wg, skb, &endpoint, 0, NULL);
else
ret = -EAFNOSUPPORT;
return ret;
}
-int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb)
+int socket_endpoint_from_skb(struct endpoint *endpoint, struct sk_buff *skb)
{
- struct iphdr *ip4;
- struct ipv6hdr *ip6;
- struct udphdr *udp;
- struct sockaddr_in *addr4;
- struct sockaddr_in6 *addr6;
-
- addr4 = (struct sockaddr_in *)sockaddr;
- addr6 = (struct sockaddr_in6 *)sockaddr;
- ip4 = ip_hdr(skb);
- ip6 = ipv6_hdr(skb);
- udp = udp_hdr(skb);
- if (ip4->version == 4) {
- addr4->sin_family = AF_INET;
- addr4->sin_port = udp->source;
- addr4->sin_addr.s_addr = ip4->saddr;
- } else if (ip4->version == 6) {
- addr6->sin6_family = AF_INET6;
- addr6->sin6_port = udp->source;
- addr6->sin6_addr = ip6->saddr;
- addr6->sin6_scope_id = ipv6_iface_scope_id(&ip6->saddr, skb->skb_iif);
- /* TODO: addr6->sin6_flowinfo */
+ memset(endpoint, 0, sizeof(struct endpoint));
+ if (ip_hdr(skb)->version == 4) {
+ endpoint->addr4.sin_family = AF_INET;
+ endpoint->addr4.sin_port = udp_hdr(skb)->source;
+ endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
+ endpoint->src4.s_addr = ip_hdr(skb)->daddr;
+ } else if (ip_hdr(skb)->version == 6) {
+ endpoint->addr6.sin6_family = AF_INET6;
+ endpoint->addr6.sin6_port = udp_hdr(skb)->source;
+ endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
+ endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif);
+ /* TODO: endpoint->addr6.sin6_flowinfo */
+ endpoint->src6 = ipv6_hdr(skb)->daddr;
} else
return -EINVAL;
return 0;
}
-void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr)
+void socket_set_peer_endpoint(struct wireguard_peer *peer, struct endpoint *endpoint)
{
- if (sockaddr->ss_family == AF_INET) {
+ if (endpoint->addr_storage.ss_family == AF_INET) {
read_lock_bh(&peer->endpoint_lock);
- if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in)))
+ if (likely(peer->endpoint.addr4.sin_family == AF_INET &&
+ peer->endpoint.addr4.sin_port == endpoint->addr4.sin_port &&
+ peer->endpoint.addr4.sin_addr.s_addr == endpoint->addr4.sin_addr.s_addr &&
+ peer->endpoint.src4.s_addr == endpoint->src4.s_addr))
goto out;
read_unlock_bh(&peer->endpoint_lock);
write_lock_bh(&peer->endpoint_lock);
- memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in));
- } else if (sockaddr->ss_family == AF_INET6) {
+ peer->endpoint.addr4 = endpoint->addr4;
+ peer->endpoint.src4 = endpoint->src4;
+ } else if (endpoint->addr_storage.ss_family == AF_INET6) {
read_lock_bh(&peer->endpoint_lock);
- if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in6)))
+ if (likely(peer->endpoint.addr6.sin6_family == AF_INET6 &&
+ peer->endpoint.addr6.sin6_port == endpoint->addr6.sin6_port &&
+ /* TODO: peer->endpoint.addr6.sin6_flowinfo == endpoint->addr6.sin6_flowinfo && */
+ ipv6_addr_equal(&peer->endpoint.addr6.sin6_addr, &endpoint->addr6.sin6_addr) &&
+ peer->endpoint.addr6.sin6_scope_id == endpoint->addr6.sin6_scope_id &&
+ ipv6_addr_equal(&peer->endpoint.src6, &endpoint->src6)))
goto out;
read_unlock_bh(&peer->endpoint_lock);
write_lock_bh(&peer->endpoint_lock);
- memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6));
+ peer->endpoint.addr6 = endpoint->addr6;
+ peer->endpoint.src6 = endpoint->src6;
} else
return;
dst_cache_reset(&peer->endpoint_cache);