summaryrefslogtreecommitdiffstatshomepage
path: root/src/socket.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/socket.c')
-rw-r--r--src/socket.c14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/socket.c b/src/socket.c
index cdb0f47..1dff57a 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -44,9 +44,14 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
if (!rt) {
security_sk_classify_flow(sock, flowi4_to_flowi(&fl));
+ if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) {
+ endpoint->src4.s_addr = endpoint->src_if4 = fl.saddr = 0;
+ if (cache)
+ dst_cache_reset(cache);
+ }
rt = ip_route_output_flow(sock_net(sock), &fl, sock);
- if (unlikely(endpoint->src4.s_addr && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && !inet_confirm_addr(sock_net(sock), rcu_dereference_bh(rt->dst.dev->ip_ptr), 0, fl.saddr, RT_SCOPE_HOST))))) {
- endpoint->src4.s_addr = fl.saddr = 0;
+ if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && rt->dst.dev->ifindex != endpoint->src_if4)))) {
+ endpoint->src4.s_addr = endpoint->src_if4 = fl.saddr = 0;
if (cache)
dst_cache_reset(cache);
if (!IS_ERR(rt))
@@ -204,6 +209,7 @@ int socket_endpoint_from_skb(struct endpoint *endpoint, struct sk_buff *skb)
endpoint->addr4.sin_port = udp_hdr(skb)->source;
endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr;
endpoint->src4.s_addr = ip_hdr(skb)->daddr;
+ endpoint->src_if4 = skb->skb_iif;
} else if (skb->protocol == htons(ETH_P_IPV6)) {
endpoint->addr6.sin6_family = AF_INET6;
endpoint->addr6.sin6_port = udp_hdr(skb)->source;
@@ -223,12 +229,14 @@ void socket_set_peer_endpoint(struct wireguard_peer *peer, struct endpoint *endp
if (likely(peer->endpoint.addr4.sin_family == AF_INET &&
peer->endpoint.addr4.sin_port == endpoint->addr4.sin_port &&
peer->endpoint.addr4.sin_addr.s_addr == endpoint->addr4.sin_addr.s_addr &&
- peer->endpoint.src4.s_addr == endpoint->src4.s_addr))
+ peer->endpoint.src4.s_addr == endpoint->src4.s_addr &&
+ peer->endpoint.src_if4 == endpoint->src_if4))
goto out;
read_unlock_bh(&peer->endpoint_lock);
write_lock_bh(&peer->endpoint_lock);
peer->endpoint.addr4 = endpoint->addr4;
peer->endpoint.src4 = endpoint->src4;
+ peer->endpoint.src_if4 = endpoint->src_if4;
} else if (endpoint->addr.sa_family == AF_INET6) {
read_lock_bh(&peer->endpoint_lock);
if (likely(peer->endpoint.addr6.sin6_family == AF_INET6 &&