From 46de6f340ae179741fb475ad6858a7a56a34af4a Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 10 Nov 2016 23:28:35 +0100 Subject: socket: keep track of src address in sending packets --- src/socket.c | 104 +++++++++++++++++++++++++++++++---------------------------- 1 file changed, 54 insertions(+), 50 deletions(-) (limited to 'src/socket.c') diff --git a/src/socket.c b/src/socket.c index 069d45d..20e82fd 100644 --- a/src/socket.c +++ b/src/socket.c @@ -13,11 +13,12 @@ #include #include -static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in *addr, uint8_t ds, struct dst_cache *cache) +static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, uint8_t ds, struct dst_cache *cache) { struct flowi4 fl = { - .daddr = addr->sin_addr.s_addr, - .fl4_dport = addr->sin_port, + .saddr = endpoint->src4.s_addr, + .daddr = endpoint->addr4.sin_addr.s_addr, + .fl4_dport = endpoint->addr4.sin_port, .fl4_sport = htons(wg->incoming_port), .flowi4_proto = IPPROTO_UDP }; @@ -44,12 +45,12 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct rt = ip_route_output_flow(sock_net(sock), &fl, sock); if (unlikely(IS_ERR(rt))) { ret = PTR_ERR(rt); - net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret); + net_dbg_ratelimited("No route to %pISpfsc, error %d\n", &endpoint->addr_storage, ret); goto err; } else if (unlikely(rt->dst.dev == skb->dev)) { dst_release(&rt->dst); ret = -ELOOP; - net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr); + net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", &endpoint->addr_storage); goto err; } if (cache) @@ -70,14 +71,15 @@ out: return ret; } -static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in6 *addr, uint8_t ds, struct dst_cache *cache) +static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, uint8_t ds, struct dst_cache *cache) { #if IS_ENABLED(CONFIG_IPV6) struct flowi6 fl = { - .daddr = addr->sin6_addr, - .fl6_dport = addr->sin6_port, + .saddr = endpoint->src6, + .daddr = endpoint->addr6.sin6_addr, + .fl6_dport = endpoint->addr6.sin6_port, .fl6_sport = htons(wg->incoming_port), - .flowi6_oif = addr->sin6_scope_id, + .flowi6_oif = endpoint->addr6.sin6_scope_id, .flowi6_proto = IPPROTO_UDP /* TODO: addr->sin6_flowinfo */ }; @@ -103,12 +105,12 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl); if (unlikely(ret)) { - net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret); + net_dbg_ratelimited("No route to %pISpfsc, error %d\n", &endpoint->addr_storage, ret); goto err; } else if (unlikely(dst->dev == skb->dev)) { dst_release(dst); ret = -ELOOP; - net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr); + net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", &endpoint->addr_storage); goto err; } if (cache) @@ -138,10 +140,10 @@ int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, ui int ret = -EAFNOSUPPORT; read_lock_bh(&peer->endpoint_lock); - if (peer->endpoint_addr.ss_family == AF_INET) - ret = send4(peer->device, skb, (struct sockaddr_in *)&peer->endpoint_addr, ds, &peer->endpoint_cache); - else if (peer->endpoint_addr.ss_family == AF_INET6) - ret = send6(peer->device, skb, (struct sockaddr_in6 *)&peer->endpoint_addr, ds, &peer->endpoint_cache); + if (peer->endpoint.addr_storage.ss_family == AF_INET) + ret = send4(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); + else if (peer->endpoint.addr_storage.ss_family == AF_INET6) + ret = send6(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); if (likely(!ret)) peer->tx_bytes += skb_len; read_unlock_bh(&peer->endpoint_lock); @@ -163,11 +165,11 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu { int ret = 0; struct sk_buff *skb; - struct sockaddr_storage addr = { 0 }; + struct endpoint endpoint; if (unlikely(!in_skb)) return -EINVAL; - ret = socket_addr_from_skb(&addr, in_skb); + ret = socket_endpoint_from_skb(&endpoint, in_skb); if (unlikely(ret < 0)) return ret; @@ -177,60 +179,62 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu skb_reserve(skb, SKB_HEADER_LEN); memcpy(skb_put(skb, len), out_buffer, len); - if (addr.ss_family == AF_INET) - ret = send4(wg, skb, (struct sockaddr_in *)&addr, 0, NULL); - else if (addr.ss_family == AF_INET6) - ret = send6(wg, skb, (struct sockaddr_in6 *)&addr, 0, NULL); + if (endpoint.addr_storage.ss_family == AF_INET) + ret = send4(wg, skb, &endpoint, 0, NULL); + else if (endpoint.addr_storage.ss_family == AF_INET6) + ret = send6(wg, skb, &endpoint, 0, NULL); else ret = -EAFNOSUPPORT; return ret; } -int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb) +int socket_endpoint_from_skb(struct endpoint *endpoint, struct sk_buff *skb) { - struct iphdr *ip4; - struct ipv6hdr *ip6; - struct udphdr *udp; - struct sockaddr_in *addr4; - struct sockaddr_in6 *addr6; - - addr4 = (struct sockaddr_in *)sockaddr; - addr6 = (struct sockaddr_in6 *)sockaddr; - ip4 = ip_hdr(skb); - ip6 = ipv6_hdr(skb); - udp = udp_hdr(skb); - if (ip4->version == 4) { - addr4->sin_family = AF_INET; - addr4->sin_port = udp->source; - addr4->sin_addr.s_addr = ip4->saddr; - } else if (ip4->version == 6) { - addr6->sin6_family = AF_INET6; - addr6->sin6_port = udp->source; - addr6->sin6_addr = ip6->saddr; - addr6->sin6_scope_id = ipv6_iface_scope_id(&ip6->saddr, skb->skb_iif); - /* TODO: addr6->sin6_flowinfo */ + memset(endpoint, 0, sizeof(struct endpoint)); + if (ip_hdr(skb)->version == 4) { + endpoint->addr4.sin_family = AF_INET; + endpoint->addr4.sin_port = udp_hdr(skb)->source; + endpoint->addr4.sin_addr.s_addr = ip_hdr(skb)->saddr; + endpoint->src4.s_addr = ip_hdr(skb)->daddr; + } else if (ip_hdr(skb)->version == 6) { + endpoint->addr6.sin6_family = AF_INET6; + endpoint->addr6.sin6_port = udp_hdr(skb)->source; + endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; + endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif); + /* TODO: endpoint->addr6.sin6_flowinfo */ + endpoint->src6 = ipv6_hdr(skb)->daddr; } else return -EINVAL; return 0; } -void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr) +void socket_set_peer_endpoint(struct wireguard_peer *peer, struct endpoint *endpoint) { - if (sockaddr->ss_family == AF_INET) { + if (endpoint->addr_storage.ss_family == AF_INET) { read_lock_bh(&peer->endpoint_lock); - if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in))) + if (likely(peer->endpoint.addr4.sin_family == AF_INET && + peer->endpoint.addr4.sin_port == endpoint->addr4.sin_port && + peer->endpoint.addr4.sin_addr.s_addr == endpoint->addr4.sin_addr.s_addr && + peer->endpoint.src4.s_addr == endpoint->src4.s_addr)) goto out; read_unlock_bh(&peer->endpoint_lock); write_lock_bh(&peer->endpoint_lock); - memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in)); - } else if (sockaddr->ss_family == AF_INET6) { + peer->endpoint.addr4 = endpoint->addr4; + peer->endpoint.src4 = endpoint->src4; + } else if (endpoint->addr_storage.ss_family == AF_INET6) { read_lock_bh(&peer->endpoint_lock); - if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in6))) + if (likely(peer->endpoint.addr6.sin6_family == AF_INET6 && + peer->endpoint.addr6.sin6_port == endpoint->addr6.sin6_port && + /* TODO: peer->endpoint.addr6.sin6_flowinfo == endpoint->addr6.sin6_flowinfo && */ + ipv6_addr_equal(&peer->endpoint.addr6.sin6_addr, &endpoint->addr6.sin6_addr) && + peer->endpoint.addr6.sin6_scope_id == endpoint->addr6.sin6_scope_id && + ipv6_addr_equal(&peer->endpoint.src6, &endpoint->src6))) goto out; read_unlock_bh(&peer->endpoint_lock); write_lock_bh(&peer->endpoint_lock); - memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6)); + peer->endpoint.addr6 = endpoint->addr6; + peer->endpoint.src6 = endpoint->src6; } else return; dst_cache_reset(&peer->endpoint_cache); -- cgit v1.2.3-59-g8ed1b