/* Copyright 2015-2016 Jason A. Donenfeld . All Rights Reserved. */

#include "wireguard.h"
#include "socket.h"
#include "packets.h"
#include "messages.h"

/* NOTE(review): the arguments of the following five system includes were
 * stripped from the copy of this file that was reviewed (only bare "#include"
 * tokens remained). The header names below are reconstructed from the APIs
 * this file uses -- confirm against version control before merging.
 */
#include <linux/net.h>
#include <linux/if_vlan.h>
#include <linux/if_ether.h>
#include <net/udp_tunnel.h>
#include <net/ipv6.h>

/* Fills @sockaddr with the source address and source UDP port found in the
 * IP and UDP headers of @skb.
 *
 * Returns 0 on success, or -EINVAL if the IP version field is neither 4 nor 6.
 */
int socket_addr_from_skb(struct sockaddr_storage *sockaddr, struct sk_buff *skb)
{
	struct iphdr *ip4;
	struct ipv6hdr *ip6;
	struct udphdr *udp;
	struct sockaddr_in *addr4;
	struct sockaddr_in6 *addr6;

	addr4 = (struct sockaddr_in *)sockaddr;
	addr6 = (struct sockaddr_in6 *)sockaddr;
	ip4 = ip_hdr(skb);
	ip6 = ipv6_hdr(skb);
	udp = udp_hdr(skb);

	/* The version field is read through the IPv4 header view for both
	 * families, exactly as the original code did.
	 */
	if (ip4->version == 4) {
		addr4->sin_family = AF_INET;
		addr4->sin_port = udp->source;
		addr4->sin_addr.s_addr = ip4->saddr;
	} else if (ip4->version == 6) {
		addr6->sin6_family = AF_INET6;
		addr6->sin6_port = udp->source;
		addr6->sin6_addr = ip6->saddr;
		addr6->sin6_scope_id = ipv6_iface_scope_id(&ip6->saddr, skb->skb_iif);
		/* TODO: addr6->sin6_flowinfo */
	} else
		return -EINVAL;
	return 0;
}

/* Looks up a route to @addr, filling in @fl4 or @fl6 (whichever matches the
 * address family) as a side effect.
 *
 * Returns the dst entry on success, or an ERR_PTR():
 *   -EAFNOSUPPORT if the family is neither AF_INET nor AF_INET6,
 *   -ENONET       if the socket for the requested family is NULL,
 *   or whatever error the routing lookup reported.
 */
static inline struct dst_entry *route(struct wireguard_device *wg,
				      struct flowi4 *fl4, struct flowi6 *fl6,
				      struct sockaddr_storage *addr,
				      struct sock *sock4, struct sock *sock6)
{
	struct dst_entry *dst = ERR_PTR(-EAFNOSUPPORT);

	if (addr->ss_family == AF_INET) {
		struct rtable *rt;
		struct sockaddr_in *sin4 = (struct sockaddr_in *)addr;

		if (unlikely(!sock4))
			return ERR_PTR(-ENONET);

		memset(fl4, 0, sizeof(struct flowi4));
		fl4->daddr = sin4->sin_addr.s_addr;
		fl4->fl4_dport = sin4->sin_port;
		fl4->fl4_sport = htons(wg->incoming_port);
		fl4->flowi4_proto = IPPROTO_UDP;

		security_sk_classify_flow(sock4, flowi4_to_flowi(fl4));
		rt = ip_route_output_flow(sock_net(sock4), fl4, sock4);
		/* Only take &rt->dst when rt is a real pointer. The previous
		 * code assigned it unconditionally after the error check,
		 * overwriting the error pointer it had just stored.
		 */
		if (unlikely(IS_ERR(rt)))
			dst = ERR_PTR(PTR_ERR(rt));
		else
			dst = &rt->dst;
	} else if (addr->ss_family == AF_INET6) {
		int ret;
		struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)addr;

		if (unlikely(!sock6))
			return ERR_PTR(-ENONET);

		memset(fl6, 0, sizeof(struct flowi6));
		fl6->daddr = sin6->sin6_addr;
		fl6->fl6_dport = sin6->sin6_port;
		fl6->fl6_sport = htons(wg->incoming_port);
		fl6->flowi6_oif = sin6->sin6_scope_id;
		fl6->flowi6_proto = IPPROTO_UDP;
		/* TODO: addr6->sin6_flowinfo */

		security_sk_classify_flow(sock6, flowi6_to_flowi(fl6));
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 3, 0)
		ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock6), sock6, &dst, fl6);
#else
		ret = ipv6_stub->ipv6_dst_lookup(sock6, &dst, fl6);
#endif
		if (unlikely(ret))
			dst = ERR_PTR(ret);
	}
	return dst;
}

/* Transmits @skb over the UDP tunnel socket matching @addr's family, using
 * the previously looked-up @dst and flow.
 *
 * On the error path (unsupported family, or the needed socket is NULL) the
 * skb is freed and the dst reference is released here. On the transmit paths
 * both are handed to the udp_tunnel*_xmit_skb() helpers.
 *
 * Returns 0 on success, -ENONET / -EAFNOSUPPORT on the error path, or -ECOMM
 * when an older kernel's xmit helper reports failure.
 */
static inline int send(struct net_device *dev, struct sk_buff *skb,
		       struct dst_entry *dst, struct flowi4 *fl4,
		       struct flowi6 *fl6, struct sockaddr_storage *addr,
		       struct sock *sock4, struct sock *sock6, u8 dscp)
{
	int ret = -EAFNOSUPPORT;

	skb->next = skb->prev = NULL;
	skb->dev = dev;

	if (addr->ss_family == AF_INET) {
		if (unlikely(!sock4)) {
			ret = -ENONET;
			goto err;
		}
#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0)
		ret = udp_tunnel_xmit_skb((struct rtable *)dst, sock4, skb,
					  fl4->saddr, fl4->daddr, dscp,
					  ip4_dst_hoplimit(dst), 0,
					  fl4->fl4_sport, fl4->fl4_dport,
					  false, false);
		iptunnel_xmit_stats(ret, &dev->stats, dev->tstats);
		return ret > 0 ? 0 : -ECOMM;
#else
		udp_tunnel_xmit_skb((struct rtable *)dst, sock4, skb,
				    fl4->saddr, fl4->daddr, dscp,
				    ip4_dst_hoplimit(dst), 0,
				    fl4->fl4_sport, fl4->fl4_dport,
				    false, false);
		return 0;
#endif
	} else if (addr->ss_family == AF_INET6) {
		if (unlikely(!sock6)) {
			ret = -ENONET;
			goto err;
		}
#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 5, 0)
		return udp_tunnel6_xmit_skb(dst, sock6, skb, dev,
					    &fl6->saddr, &fl6->daddr, dscp,
					    ip6_dst_hoplimit(dst),
					    fl6->fl6_sport, fl6->fl6_dport,
					    false) == 0 ? 0 : -ECOMM;
#elif LINUX_VERSION_CODE < KERNEL_VERSION(4, 6, 0)
		udp_tunnel6_xmit_skb(dst, sock6, skb, dev,
				     &fl6->saddr, &fl6->daddr, dscp,
				     ip6_dst_hoplimit(dst),
				     fl6->fl6_sport, fl6->fl6_dport, false);
		return 0;
#else
		udp_tunnel6_xmit_skb(dst, sock6, skb, dev,
				     &fl6->saddr, &fl6->daddr, dscp,
				     ip6_dst_hoplimit(dst), 0,
				     fl6->fl6_sport, fl6->fl6_dport, false);
		return 0;
#endif
	}

err:
	kfree_skb(skb);
	dst_release(dst);
	return ret;
}

/* Drops the peer's cached dst and replaces it with a fresh route lookup for
 * peer->endpoint_addr. If the lookup fails, endpoint_dst is left NULL.
 *
 * Caller must hold peer->endpoint_lock (checked with lockdep).
 */
static inline void __socket_set_peer_dst(struct wireguard_peer *peer)
{
	struct dst_entry *old_dst, *new_dst;

	lockdep_assert_held(&peer->endpoint_lock);

	/* Clear the cached pointer before releasing the old entry. */
	old_dst = peer->endpoint_dst;
	peer->endpoint_dst = NULL;
	wmb();
	if (old_dst)
		dst_release(old_dst);

	rcu_read_lock();
	new_dst = route(peer->device, &peer->endpoint_flow.fl4,
			&peer->endpoint_flow.fl6, &peer->endpoint_addr,
			rcu_dereference(peer->device->sock4),
			rcu_dereference(peer->device->sock6));
	rcu_read_unlock();

	if (likely(!IS_ERR(new_dst))) {
		peer->endpoint_dst = new_dst;
		wmb();
	}
}

/* Locked wrapper: refreshes the peer's cached dst under the write lock. */
void socket_set_peer_dst(struct wireguard_peer *peer)
{
	write_lock_bh(&peer->endpoint_lock);
	__socket_set_peer_dst(peer);
	write_unlock_bh(&peer->endpoint_lock);
}

/* Updates the peer's endpoint address to @sockaddr and refreshes its cached
 * dst, unless the address is unchanged. Addresses whose family is neither
 * AF_INET nor AF_INET6 are ignored.
 */
void socket_set_peer_addr(struct wireguard_peer *peer, struct sockaddr_storage *sockaddr)
{
	size_t addr_len;

	/* Only the family-specific prefix of the storage is compared/copied. */
	if (sockaddr->ss_family == AF_INET)
		addr_len = sizeof(struct sockaddr_in);
	else if (sockaddr->ss_family == AF_INET6)
		addr_len = sizeof(struct sockaddr_in6);
	else
		return;

	/* Cheap check under the read lock first; nothing to do if unchanged. */
	read_lock_bh(&peer->endpoint_lock);
	if (!memcmp(sockaddr, &peer->endpoint_addr, addr_len)) {
		read_unlock_bh(&peer->endpoint_lock);
		return;
	}
	read_unlock_bh(&peer->endpoint_lock);

	write_lock_bh(&peer->endpoint_lock);
	memcpy(&peer->endpoint_addr, sockaddr, addr_len);
	__socket_set_peer_dst(peer);
	write_unlock_bh(&peer->endpoint_lock);
}

/* Returns the peer's cached dst with an extra reference taken, refreshing the
 * cache first if it is empty or fails its obsolescence check.
 *
 * Returns NULL if no dst is available or its refcount had already hit zero.
 * The caller owns the returned reference and must release it (or pass it to
 * something that does).
 */
static inline struct dst_entry *peer_dst_get(struct wireguard_peer *peer)
{
	struct dst_entry *dst = NULL;

	read_lock_bh(&peer->endpoint_lock);
	if (!peer->endpoint_dst ||
	    (peer->endpoint_dst->obsolete &&
	     !peer->endpoint_dst->ops->check(peer->endpoint_dst, 0))) {
		/* socket_set_peer_dst() takes the write lock itself, so the
		 * read lock has to be dropped around the call.
		 */
		read_unlock_bh(&peer->endpoint_lock);
		socket_set_peer_dst(peer);
		read_lock_bh(&peer->endpoint_lock);
		if (!peer->endpoint_dst)
			goto out;
	}
	if (!atomic_inc_not_zero(&peer->endpoint_dst->__refcnt))
		goto out;
	dst = peer->endpoint_dst;
out:
	read_unlock_bh(&peer->endpoint_lock);
	return dst;
}

/* Sends @skb to the peer's current endpoint with the given @dscp value.
 *
 * The skb is freed here when no route exists or when the route points back
 * at this device; otherwise it is handed to send(). On success the skb's
 * original length is added to peer->tx_bytes.
 *
 * Returns 0 on success, -EHOSTUNREACH if there is no route, -ELOOP if the
 * route would loop through this device, or the error from send().
 */
int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 dscp)
{
	struct net_device *dev = netdev_pub(peer->device);
	struct dst_entry *dst;
	size_t skb_len = skb->len;
	int ret = 0;

	dst = peer_dst_get(peer);
	if (unlikely(!dst)) {
		net_dbg_ratelimited("No route to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id);
		kfree_skb(skb);
		return -EHOSTUNREACH;
	} else if (unlikely(dst->dev == dev)) {
		net_dbg_ratelimited("Avoiding routing loop to %pISpfsc for peer %Lu\n", &peer->endpoint_addr, peer->internal_id);
		/* peer_dst_get() took a reference for us; since dst is not
		 * being handed to send(), it must be dropped here or it leaks.
		 */
		dst_release(dst);
		kfree_skb(skb);
		return -ELOOP;
	}

	rcu_read_lock();
	read_lock_bh(&peer->endpoint_lock);

	ret = send(dev, skb, dst, &peer->endpoint_flow.fl4,
		   &peer->endpoint_flow.fl6, &peer->endpoint_addr,
		   rcu_dereference(peer->device->sock4),
		   rcu_dereference(peer->device->sock6), dscp);
	if (!ret)
		peer->tx_bytes += skb_len;

	read_unlock_bh(&peer->endpoint_lock);
	rcu_read_unlock();

	return ret;
}

/* Copies @len bytes from @buffer into a freshly allocated skb (with
 * SKB_HEADER_LEN of headroom) and sends it to the peer.
 *
 * Returns -ENOMEM if allocation fails, otherwise the result of
 * socket_send_skb_to_peer().
 */
int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, u8 dscp)
{
	struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);

	if (!skb)
		return -ENOMEM;
	skb_reserve(skb, SKB_HEADER_LEN);
	memcpy(skb_put(skb, len), buffer, len);
	return socket_send_skb_to_peer(peer, skb, dscp);
}

/* Routes and sends @skb to an arbitrary @addr (not tied to a peer), using
 * flow storage on the stack and a dscp of 0.
 *
 * The skb is freed here when routing fails or would loop through this device.
 *
 * Returns 0 on success, the routing error, -ELOOP, or the error from send().
 */
static int send_to_sockaddr(struct sk_buff *skb, struct wireguard_device *wg,
			    struct sockaddr_storage *addr,
			    struct sock *sock4, struct sock *sock6)
{
	struct dst_entry *dst;
	struct net_device *dev = netdev_pub(wg);
	union {
		struct flowi4 fl4;
		struct flowi6 fl6;
	} fl;

	dst = route(wg, &fl.fl4, &fl.fl6, addr, sock4, sock6);
	if (IS_ERR(dst)) {
		net_dbg_ratelimited("No route to %pISpfsc\n", addr);
		kfree_skb(skb);
		return PTR_ERR(dst);
	} else if (unlikely(dst->dev == dev)) {
		net_dbg_ratelimited("Avoiding routing loop to %pISpfsc\n", addr);
		dst_release(dst);
		kfree_skb(skb);
		return -ELOOP;
	}

	return send(dev, skb, dst, &fl.fl4, &fl.fl6, addr, sock4, sock6, 0);
}

/* Sends @len bytes from @out_buffer back to the source address and port of
 * @in_skb.
 *
 * Returns -EINVAL if @in_skb is NULL or its IP version is unrecognized,
 * -ENOMEM if the reply skb cannot be allocated, otherwise the result of
 * send_to_sockaddr().
 */
int socket_send_buffer_as_reply_to_skb(struct sk_buff *in_skb, void *out_buffer, size_t len, struct wireguard_device *wg)
{
	int ret = 0;
	struct sk_buff *skb;
	struct sockaddr_storage addr = { 0 };

	if (unlikely(!in_skb))
		return -EINVAL;
	ret = socket_addr_from_skb(&addr, in_skb);
	if (ret < 0)
		return ret;

	skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
	if (!skb)
		return -ENOMEM;
	skb_reserve(skb, SKB_HEADER_LEN);
	memcpy(skb_put(skb, len), out_buffer, len);

	rcu_read_lock();
	ret = send_to_sockaddr(skb, wg, &addr,
			       rcu_dereference(wg->sock4),
			       rcu_dereference(wg->sock6));
	rcu_read_unlock();

	return ret;
}

/* encap_rcv callback installed on the tunnel sockets by socket_init().
 * Passes the skb to packet_receive() for the device stored in sk_user_data;
 * frees the skb if the socket or the device pointer is missing.
 * Always returns 0.
 */
static int receive(struct sock *sk, struct sk_buff *skb)
{
	struct wireguard_device *wg;

	if (unlikely(!sk))
		goto err;
	wg = sk->sk_user_data;
	if (unlikely(!wg))
		goto err;
	packet_receive(wg, skb);
	return 0;

err:
	kfree_skb(skb);
	return 0;
}

/* Generates a default port from the interface name: 51820 plus the decimal
 * number formed by the name's trailing digits, wrapping modulo 65536
 * (and bumped to 1 if the result would be 0).
 *
 * wg0     --> 51820
 * wg1     --> 51821
 * wg2     --> 51822
 * wg100   --> 51920
 * wg60000 --> 46284
 * blahbla --> 51820
 * 50      --> 51870
 */
static uint16_t generate_default_incoming_port(struct wireguard_device *wg)
{
	uint16_t port = 51820;
	unsigned long parsed;
	const char *name;
	size_t len, digits_at;

	ASSERT_RTNL();

	name = netdev_pub(wg)->name;
	len = strlen(name);
	if (!len)
		return port;

	/* Walk back over the trailing digits by index, so that no pointer
	 * before the start of the name is ever formed.
	 */
	digits_at = len;
	while (digits_at > 0 &&
	       name[digits_at - 1] >= '0' && name[digits_at - 1] <= '9')
		--digits_at;

	/* No trailing digits at all: keep the base port. */
	if (digits_at == len)
		return port;

	/* If parsing fails, the base port is kept. */
	if (!kstrtoul(name + digits_at, 10, &parsed))
		port += parsed;
	if (!port)
		++port;
	return port;
}

/* Releases a tunnel socket previously published in wg->sock4/sock6.
 * A NULL @sock is ignored.
 */
static inline void sock_free(struct sock *sock)
{
	if (unlikely(!sock))
		return;
	sk_clear_memalloc(sock);
	udp_tunnel_sock_release(sock->sk_socket);
}

/* Applies the socket options used for tunnel sockets: atomic allocations,
 * maximum send buffer size, and the memalloc flag.
 */
static inline void set_sock_opts(struct socket *sock)
{
	sock->sk->sk_allocation = GFP_ATOMIC;
	sock->sk->sk_sndbuf = INT_MAX;
	sk_set_memalloc(sock->sk);
}

/* Creates the IPv4 and IPv6 UDP tunnel sockets for @wg, bound to
 * wg->incoming_port (generated from the interface name if it is 0), and
 * publishes them in wg->sock4 / wg->sock6.
 *
 * Returns 0 on success, -EADDRINUSE if either socket is already set, or the
 * error from udp_sock_create().
 */
int socket_init(struct wireguard_device *wg)
{
	struct udp_port_cfg port4 = {
		.family = AF_INET,
		.local_ip.s_addr = htonl(INADDR_ANY),
		.use_udp_checksums = true
	};
	struct udp_port_cfg port6 = {
		.family = AF_INET6,
		.local_ip6 = IN6ADDR_ANY_INIT,
		.use_udp6_tx_checksums = true,
		.use_udp6_rx_checksums = true,
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 3, 0)
		.ipv6_v6only = true
#endif
	};
	struct udp_tunnel_sock_cfg cfg = {
		.sk_user_data = wg,
		.encap_type = 1,
		.encap_rcv = receive
	};
	int ret = 0;
	struct socket *new4 = NULL, *new6 = NULL;
#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0)
	int old_bindv6only;
#endif

	mutex_lock(&wg->socket_update_lock);

	if (rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)) ||
	    rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock))) {
		ret = -EADDRINUSE;
		goto out;
	}

	if (!wg->incoming_port)
		wg->incoming_port = generate_default_incoming_port(wg);

	port4.local_udp_port = port6.local_udp_port = htons(wg->incoming_port);

	ret = udp_sock_create(wg->creating_net, &port4, &new4);
	if (ret < 0) {
		pr_err("Could not create outgoing IPv4 socket\n");
		goto out;
	}

#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0)
	/* Since udp_port_cfg only learned of ipv6_v6only in 4.3, we do this horrible
	 * hack here and set the sysctl variable temporarily to something that will
	 * set the right option for us in sock_create. It's super racy! */
	old_bindv6only = wg->creating_net->ipv6.sysctl.bindv6only;
	wg->creating_net->ipv6.sysctl.bindv6only = 1;
#endif
	ret = udp_sock_create(wg->creating_net, &port6, &new6);
#if LINUX_VERSION_CODE < KERNEL_VERSION(4, 3, 0)
	wg->creating_net->ipv6.sysctl.bindv6only = old_bindv6only;
#endif
	if (ret < 0) {
		pr_err("Could not create outgoing IPv6 socket\n");
		/* Don't leave a half-initialized pair behind. */
		udp_tunnel_sock_release(new4);
		goto out;
	}

	set_sock_opts(new4);
	set_sock_opts(new6);
	setup_udp_tunnel_sock(wg->creating_net, new4, &cfg);
	setup_udp_tunnel_sock(wg->creating_net, new6, &cfg);
	rcu_assign_pointer(wg->sock4, new4->sk);
	rcu_assign_pointer(wg->sock6, new6->sk);

out:
	mutex_unlock(&wg->socket_update_lock);
	return ret;
}

/* Unpublishes both tunnel sockets, waits for an RCU grace period so that
 * readers using rcu_dereference() on them have finished, then releases them.
 */
void socket_uninit(struct wireguard_device *wg)
{
	struct sock *old4, *old6;

	mutex_lock(&wg->socket_update_lock);
	old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock));
	old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock));
	rcu_assign_pointer(wg->sock4, NULL);
	rcu_assign_pointer(wg->sock6, NULL);
	mutex_unlock(&wg->socket_update_lock);

	synchronize_rcu();
	sock_free(old4);
	sock_free(old6);
}