From 028c66a45b5017ca5cba0f4ed1e222888d816d8c Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 5 Nov 2016 02:51:05 +0100 Subject: socket: big refactoring --- src/send.c | 2 +- src/socket.c | 355 ++++++++++++++++++++++++++++------------------------------- src/socket.h | 6 +- 3 files changed, 170 insertions(+), 193 deletions(-) (limited to 'src') diff --git a/src/send.c b/src/send.c index 0956c8d..2ea2a0c 100644 --- a/src/send.c +++ b/src/send.c @@ -82,7 +82,7 @@ void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *i net_dbg_ratelimited("Sending cookie response for denied handshake message for %pISpfsc\n", &addr); #endif cookie_message_create(&packet, initiating_skb, data, data_len, sender_index, &wg->cookie_checker); - socket_send_buffer_as_reply_to_skb(initiating_skb, &packet, sizeof(packet), wg); + socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet)); } static inline void keep_key_fresh(struct wireguard_peer *peer) diff --git a/src/socket.c b/src/socket.c index 2d2cf45..4f8c4e8 100644 --- a/src/socket.c +++ b/src/socket.c @@ -13,230 +13,151 @@ #include #include - -union flowi46 { - struct flowi4 fl4; - struct flowi6 fl6; -}; - -int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb) +static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in *addr, uint8_t ds, struct dst_cache *cache) { - struct iphdr *ip4; - struct ipv6hdr *ip6; - struct udphdr *udp; - struct sockaddr_in *addr4; - struct sockaddr_in6 *addr6; + struct flowi4 fl = { + .daddr = addr->sin_addr.s_addr, + .fl4_dport = addr->sin_port, + .fl4_sport = htons(wg->incoming_port), + .flowi4_proto = IPPROTO_UDP + }; + struct rtable *rt = NULL; + struct sock *sock; + int ret = 0; - addr4 = (struct sockaddr_in *)sockaddr; - addr6 = (struct sockaddr_in6 *)sockaddr; - ip4 = ip_hdr(skb); - ip6 = ipv6_hdr(skb); - udp = udp_hdr(skb); - if (ip4->version == 4) { - addr4->sin_family = AF_INET; - addr4->sin_port = udp->source; - addr4->sin_addr.s_addr = ip4->saddr; - } else if (ip4->version == 6) { - addr6->sin6_family = AF_INET6; - addr6->sin6_port = udp->source; - addr6->sin6_addr = ip6->saddr; - addr6->sin6_scope_id = ipv6_iface_scope_id(&ip6->saddr, skb->skb_iif); - /* TODO: addr6->sin6_flowinfo */ - } else - return -EINVAL; - return 0; -} + skb->next = skb->prev = NULL; + skb->dev = netdev_pub(wg); -static inline struct dst_entry *route(struct wireguard_device *wg, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, struct dst_cache *cache) -{ - if (addr->ss_family == AF_INET) { - struct rtable *rt; - struct sockaddr_in *sin4 = (struct sockaddr_in *)addr; - - if (unlikely(!sock4)) - return ERR_PTR(-ENONET); - - memset(&fl->fl4, 0, sizeof(struct flowi4)); - fl->fl4.daddr = sin4->sin_addr.s_addr; - fl->fl4.fl4_dport = sin4->sin_port; - fl->fl4.fl4_sport = htons(wg->incoming_port); - fl->fl4.flowi4_proto = IPPROTO_UDP; - - rt = dst_cache_get_ip4(cache, &fl->fl4.saddr); - if (rt) - return &rt->dst; - - security_sk_classify_flow(sock4, flowi4_to_flowi(&fl->fl4)); - rt = ip_route_output_flow(sock_net(sock4), &fl->fl4, sock4); - if (unlikely(IS_ERR(rt))) - return ERR_PTR(PTR_ERR(rt)); - dst_cache_set_ip4(cache, &rt->dst, fl->fl4.saddr); - return &rt->dst; - } else if (addr->ss_family == AF_INET6) { -#if IS_ENABLED(CONFIG_IPV6) - int ret; - struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)addr; - struct dst_entry *dst; - - if (unlikely(!sock6)) - return ERR_PTR(-ENONET); - - memset(&fl->fl6, 0, sizeof(struct flowi6)); - fl->fl6.daddr = sin6->sin6_addr; - fl->fl6.fl6_dport = sin6->sin6_port; - fl->fl6.fl6_sport = htons(wg->incoming_port); - fl->fl6.flowi6_oif = sin6->sin6_scope_id; - fl->fl6.flowi6_proto = IPPROTO_UDP; - /* TODO: addr6->sin6_flowinfo */ + rcu_read_lock(); + sock = rcu_dereference(wg->sock4); + + if (unlikely(!sock)) { + ret = -ENONET; + goto err; + } - dst = dst_cache_get_ip6(cache, &fl->fl6.saddr); - if (dst) - return dst; + if (cache) + rt = dst_cache_get_ip4(cache, &fl.saddr); - security_sk_classify_flow(sock6, flowi6_to_flowi(&fl->fl6)); - ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, &fl->fl6); - if (unlikely(ret)) - return ERR_PTR(ret); - dst_cache_set_ip6(cache, dst, &fl->fl6.saddr); - return dst; -#endif + if (!rt) { + security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); + rt = ip_route_output_flow(sock_net(sock), &fl, sock); + if (unlikely(IS_ERR(rt))) { + ret = PTR_ERR(rt); + net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret); + goto err; + } else if (unlikely(rt->dst.dev == skb->dev)) { + ret = -ELOOP; + net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr); + goto err; + } + if (cache) + dst_cache_set_ip4(cache, &rt->dst, fl.saddr); } - return ERR_PTR(-EAFNOSUPPORT); + + udp_tunnel_xmit_skb(rt, sock, skb, + fl.saddr, fl.daddr, + ds, ip4_dst_hoplimit(&rt->dst), 0, + fl.fl4_sport, fl.fl4_dport, + false, false); + goto out; + +err: + kfree_skb(skb); +out: + rcu_read_unlock(); + return ret; } -static inline int send(struct net_device *dev, struct sk_buff *skb, struct dst_entry *dst, union flowi46 *fl, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6, u8 dscp) +static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct sockaddr_in6 *addr, uint8_t ds, struct dst_cache *cache) { - int ret = -EAFNOSUPPORT; +#if IS_ENABLED(CONFIG_IPV6) + struct flowi6 fl = { + .daddr = addr->sin6_addr, + .fl6_dport = addr->sin6_port, + .fl6_sport = htons(wg->incoming_port), + .flowi6_oif = addr->sin6_scope_id, + .flowi6_proto = IPPROTO_UDP + /* TODO: addr->sin6_flowinfo */ + }; + struct dst_entry *dst = NULL; + struct sock *sock; + int ret = 0; skb->next = skb->prev = NULL; - skb->dev = dev; + skb->dev = netdev_pub(wg); - if (addr->ss_family == AF_INET) { - if (unlikely(!sock4)) { - ret = -ENONET; + rcu_read_lock(); + sock = rcu_dereference(wg->sock6); + + if (unlikely(!sock)) { + ret = -ENONET; + goto err; + } + + if (cache) + dst = dst_cache_get_ip6(cache, &fl.saddr); + + if (!dst) { + security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); + ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl); + if (unlikely(ret)) { + net_dbg_ratelimited("No route to %pISpfsc, error %d\n", addr, ret); goto err; - } - udp_tunnel_xmit_skb((struct rtable *)dst, sock4, skb, - fl->fl4.saddr, fl->fl4.daddr, - dscp, ip4_dst_hoplimit(dst), 0, - fl->fl4.fl4_sport, fl->fl4.fl4_dport, - false, false); - return 0; - } else if (addr->ss_family == AF_INET6) { - if (unlikely(!sock6)) { - ret = -ENONET; + } else if (unlikely(dst->dev == skb->dev)) { + ret = -ELOOP; + net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr); goto err; } -#if IS_ENABLED(CONFIG_IPV6) - udp_tunnel6_xmit_skb(dst, sock6, skb, dev, - &fl->fl6.saddr, &fl->fl6.daddr, - dscp, ip6_dst_hoplimit(dst), 0, - fl->fl6.fl6_sport, fl->fl6.fl6_dport, - false); - return 0; -#else - goto err; -#endif + if (cache) + dst_cache_set_ip6(cache, dst, &fl.saddr); } + udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, + &fl.saddr, &fl.daddr, + ds, ip6_dst_hoplimit(dst), 0, + fl.fl6_sport, fl.fl6_dport, + false); + goto out; + err: kfree_skb(skb); - dst_release(dst); - return ret; -} - -void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr) -{ - if (sockaddr->ss_family == AF_INET) { - read_lock_bh(&peer->endpoint_lock); - if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in))) - goto out; - read_unlock_bh(&peer->endpoint_lock); - write_lock_bh(&peer->endpoint_lock); - memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in)); - } else if (sockaddr->ss_family == AF_INET6) { - read_lock_bh(&peer->endpoint_lock); - if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in6))) - goto out; - read_unlock_bh(&peer->endpoint_lock); - write_lock_bh(&peer->endpoint_lock); - memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6)); - } else - return; - dst_cache_reset(&peer->endpoint_cache); - write_unlock_bh(&peer->endpoint_lock); - return; out: - read_unlock_bh(&peer->endpoint_lock); + rcu_read_unlock(); + return ret; +#else + return -EAFNOSUPPORT; +#endif } -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds) +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, uint8_t ds) { - struct net_device *dev = netdev_pub(peer->device); - struct dst_entry *dst; - union flowi46 fl; size_t skb_len = skb->len; - int ret = 0; + int ret = -EAFNOSUPPORT; - rcu_read_lock(); read_lock_bh(&peer->endpoint_lock); - - dst = route(peer->device, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), &peer->endpoint_cache); - if (unlikely(IS_ERR(dst))) { - net_dbg_ratelimited("No route to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id); - kfree_skb(skb); - ret = PTR_ERR(dst); - goto out; - } else if (unlikely(dst->dev == dev)) { - net_dbg_ratelimited("Avoiding routing loop to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id); - kfree_skb(skb); - ret = -ELOOP; - goto out; - } - - ret = send(dev, skb, dst, &fl, &peer->endpoint_addr, rcu_dereference(peer->device->sock4), rcu_dereference(peer->device->sock6), ds); - if (!ret) + if (peer->endpoint_addr.ss_family == AF_INET) + ret = send4(peer->device, skb, (struct sockaddr_in *)&peer->endpoint_addr, ds, &peer->endpoint_cache); + else if (peer->endpoint_addr.ss_family == AF_INET6) + ret = send6(peer->device, skb, (struct sockaddr_in6 *)&peer->endpoint_addr, ds, &peer->endpoint_cache); + if (likely(!ret)) peer->tx_bytes += skb_len; - -out: read_unlock_bh(&peer->endpoint_lock); - rcu_read_unlock(); return ret; } -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, u8 ds) +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, uint8_t ds) { struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); - if (!skb) + if (unlikely(!skb)) return -ENOMEM; skb_reserve(skb, SKB_HEADER_LEN); memcpy(skb_put(skb, len), buffer, len); return socket_send_skb_to_peer(peer, skb, ds); } -static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg, struct sockaddr_storage *addr, struct sock *sock4, struct sock *sock6) -{ - struct dst_entry *dst; - struct net_device *dev = netdev_pub(wg); - union flowi46 fl; - - dst = route(wg, &fl, addr, sock4, sock6, NULL); - if (IS_ERR(dst)) { - net_dbg_ratelimited("No route to %pISpfsc\n", addr); - kfree_skb(skb); - return PTR_ERR(dst); - } else if (unlikely(dst->dev == netdev_pub(wg))) { - net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr); - dst_release(dst); - kfree_skb(skb); - return -ELOOP; - } - - return send(dev, skb, dst, &fl, addr, sock4, sock6, 0); -} - -int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, size_t len, struct wireguard_device *wg) +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len) { int ret = 0; struct sk_buff *skb; @@ -245,22 +166,78 @@ int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, if (unlikely(!in_skb)) return -EINVAL; ret = socket_addr_from_skb(&addr, in_skb); - if (ret < 0) + if (unlikely(ret < 0)) return ret; skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); - if (!skb) + if (unlikely(!skb)) return -ENOMEM; skb_reserve(skb, SKB_HEADER_LEN); memcpy(skb_put(skb, len), out_buffer, len); - rcu_read_lock(); - ret = send_to_sockaddr(skb, wg, &addr, rcu_dereference(wg->sock4), rcu_dereference(wg->sock6)); - rcu_read_unlock(); + if (addr.ss_family == AF_INET) + ret = send4(wg, skb, (struct sockaddr_in *)&addr, 0, NULL); + else if (addr.ss_family == AF_INET6) + ret = send6(wg, skb, (struct sockaddr_in6 *)&addr, 0, NULL); + else + ret = -EAFNOSUPPORT; return ret; } +int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb) +{ + struct iphdr *ip4; + struct ipv6hdr *ip6; + struct udphdr *udp; + struct sockaddr_in *addr4; + struct sockaddr_in6 *addr6; + + addr4 = (struct sockaddr_in *)sockaddr; + addr6 = (struct sockaddr_in6 *)sockaddr; + ip4 = ip_hdr(skb); + ip6 = ipv6_hdr(skb); + udp = udp_hdr(skb); + if (ip4->version == 4) { + addr4->sin_family = AF_INET; + addr4->sin_port = udp->source; + addr4->sin_addr.s_addr = ip4->saddr; + } else if (ip4->version == 6) { + addr6->sin6_family = AF_INET6; + addr6->sin6_port = udp->source; + addr6->sin6_addr = ip6->saddr; + addr6->sin6_scope_id = ipv6_iface_scope_id(&ip6->saddr, skb->skb_iif); + /* TODO: addr6->sin6_flowinfo */ + } else + return -EINVAL; + return 0; +} + +void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr) +{ + if (sockaddr->ss_family == AF_INET) { + read_lock_bh(&peer->endpoint_lock); + if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in))) + goto out; + read_unlock_bh(&peer->endpoint_lock); + write_lock_bh(&peer->endpoint_lock); + memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in)); + } else if (sockaddr->ss_family == AF_INET6) { + read_lock_bh(&peer->endpoint_lock); + if (!memcmp(sockaddr, &peer->endpoint_addr, sizeof(struct sockaddr_in6))) + goto out; + read_unlock_bh(&peer->endpoint_lock); + write_lock_bh(&peer->endpoint_lock); + memcpy(&peer->endpoint_addr, sockaddr, sizeof(struct sockaddr_in6)); + } else + return; + dst_cache_reset(&peer->endpoint_cache); + write_unlock_bh(&peer->endpoint_lock); + return; +out: + read_unlock_bh(&peer->endpoint_lock); +} + static int receive(struct sock *sk, struct sk_buff *skb) { struct wireguard_device *wg; diff --git a/src/socket.h b/src/socket.h index 5ab1365..d7c9df4 100644 --- a/src/socket.h +++ b/src/socket.h @@ -14,9 +14,9 @@ struct wireguard_device; int socket_init(struct wireguard_device *wg); void socket_uninit(struct wireguard_device *wg); -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds); -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds); -int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, size_t len, struct wireguard_device *wg); +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, uint8_t ds); +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, uint8_t ds); +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len); int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb); void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr); -- cgit v1.2.3-59-g8ed1b