diff options
Diffstat (limited to 'src/device.c')
-rw-r--r-- | src/device.c | 352 |
1 files changed, 352 insertions, 0 deletions
diff --git a/src/device.c b/src/device.c new file mode 100644 index 0000000..4076e58 --- /dev/null +++ b/src/device.c @@ -0,0 +1,352 @@ +/* Copyright 2015-2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ + +#include "wireguard.h" +#include "packets.h" +#include "socket.h" +#include "timers.h" +#include "device.h" +#include "config.h" +#include "peer.h" +#include "uapi.h" +#include "messages.h" +#include <linux/module.h> +#include <linux/rtnetlink.h> +#include <linux/inet.h> +#include <linux/netdevice.h> +#include <linux/if_arp.h> +#include <linux/icmp.h> +#include <net/icmp.h> +#include <net/rtnetlink.h> +#include <net/ip_tunnels.h> +#include <net/netfilter/nf_conntrack.h> +#include <net/netfilter/nf_nat_core.h> + +#define MAX_QUEUED_PACKETS 1024 + +static int init(struct net_device *dev) +{ + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + return 0; +} +static void uninit(struct net_device *dev) +{ + free_percpu(dev->tstats); +} + +static int open_peer(struct wireguard_peer *peer, void *data) +{ + socket_set_peer_dst(peer); + timers_init_peer(peer); + packet_send_queue(peer); + return 0; +} + +static int open(struct net_device *dev) +{ + struct wireguard_device *wg = netdev_priv(dev); + int rc = socket_init(wg); + if (rc < 0) + return rc; + peer_for_each(wg, open_peer, NULL); + return 0; +} + +static int stop_peer(struct wireguard_peer *peer, void *data) +{ + timers_uninit_peer_wait(peer); + noise_handshake_clear(&peer->handshake); + noise_keypairs_clear(&peer->keypairs); + return 0; +} + +static int stop(struct net_device *dev) +{ + struct wireguard_device *wg = netdev_priv(dev); + peer_for_each(wg, stop_peer, NULL); + skb_queue_purge(&wg->incoming_handshakes); + socket_uninit(wg); + return 0; +} + +static void skb_unsendable(struct sk_buff *skb, struct net_device *dev) +{ + /* This conntrack stuff is because the rate limiting needs to be applied + * to the original src IP, so we have to restore saddr in the IP header. */ + struct nf_conn *ct = NULL; +#if defined(CONFIG_NF_CONNTRACK) || defined(CONFIG_NF_CONNTRACK_MODULE) + enum ip_conntrack_info ctinfo; + ct = nf_ct_get(skb, &ctinfo); +#endif + ++dev->stats.tx_errors; + + if (skb->len < sizeof(struct iphdr)) + goto free; + + if (ip_hdr(skb)->version == 4) { + if (ct) + ip_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.ip; + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0); + } else if (ip_hdr(skb)->version == 6) { + if (ct) + ipv6_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.in6; + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0); + } +free: + kfree_skb(skb); +} + +static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct wireguard_device *wg = netdev_priv(dev); + struct wireguard_peer *peer; + int ret; + + if (unlikely(dev_recursion_level() > 4)) { + net_dbg_ratelimited("Routing loop detected\n"); + skb_unsendable(skb, dev); + return -ELOOP; + } + + dev->trans_start = jiffies; + + peer = routing_table_lookup_dst(&wg->peer_routing_table, skb); + if (unlikely(!peer)) { + skb_unsendable(skb, dev); + return -ENOKEY; + } + + read_lock_bh(&peer->endpoint_lock); + ret = unlikely(peer->endpoint_addr.ss_family != AF_INET && peer->endpoint_addr.ss_family != AF_INET6); + read_unlock_bh(&peer->endpoint_lock); + if (ret) { + net_dbg_ratelimited("No valid endpoint has been configured or discovered for device\n"); + peer_put(peer); + skb_unsendable(skb, dev); + return -EHOSTUNREACH; + } + + /* If the queue is getting too big, we start removing the oldest packets until it's small again. + * We do this before adding the new packet, so we don't remove GSO segments that are in excess. */ + while (skb_queue_len(&peer->tx_packet_queue) > MAX_QUEUED_PACKETS) + dev_kfree_skb(skb_dequeue(&peer->tx_packet_queue)); + + if (!skb_is_gso(skb)) + skb->next = NULL; + else { + struct sk_buff *segs = skb_gso_segment(skb, 0); + if (unlikely(IS_ERR(segs))) { + skb_unsendable(skb, dev); + peer_put(peer); + return PTR_ERR(segs); + } + dev_kfree_skb(skb); + skb = segs; + } + while (skb) { + struct sk_buff *next = skb->next; + skb->next = skb->prev = NULL; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + continue; + + /* We only need to keep the original dst around for icmp, + * so at this point we're in a position to drop it. */ + skb_dst_drop(skb); + + skb_queue_tail(&peer->tx_packet_queue, skb); + skb = next; + } + + ret = packet_send_queue(peer); + peer_put(peer); + return ret; +} + + +static int ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) +{ + struct wireguard_device *wg = netdev_priv(dev); + + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case WG_GET_DEVICE: + return config_get_device(wg, ifr->ifr_ifru.ifru_data); + case WG_SET_DEVICE: + return config_set_device(wg, ifr->ifr_ifru.ifru_data); + } + return -EINVAL; +} + +static const struct net_device_ops netdev_ops = { + .ndo_init = init, + .ndo_uninit = uninit, + .ndo_open = open, + .ndo_stop = stop, + .ndo_start_xmit = xmit, + .ndo_get_stats64 = ip_tunnel_get_stats64, + .ndo_do_ioctl = ioctl +}; + +static void destruct(struct net_device *dev) +{ + struct wireguard_device *wg = netdev_priv(dev); + + mutex_lock(&wg->device_update_lock); + peer_remove_all(wg); + wg->incoming_port = 0; + destroy_workqueue(wg->workqueue); +#ifdef CONFIG_WIREGUARD_PARALLEL + destroy_workqueue(wg->parallelqueue); + padata_free(wg->parallel_send); + padata_free(wg->parallel_receive); +#endif + routing_table_free(&wg->peer_routing_table); + memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); + skb_queue_purge(&wg->incoming_handshakes); + socket_uninit(wg); + cookie_checker_uninit(&wg->cookie_checker); + mutex_unlock(&wg->device_update_lock); + + put_net(wg->creating_net); + + pr_debug("Device %s has been deleted\n", dev->name); + free_netdev(dev); +} + +#define WG_FEATURES (NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA) + +static void setup(struct net_device *dev) +{ + struct wireguard_device *wg = netdev_priv(dev); + + dev->netdev_ops = &netdev_ops; + dev->destructor = destruct; + dev->hard_header_len = 0; + dev->addr_len = 0; + dev->needed_headroom = DATA_PACKET_HEAD_ROOM; + dev->needed_tailroom = noise_encrypted_len(MESSAGE_PADDING_MULTIPLE); + dev->type = ARPHRD_NONE; + dev->flags = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST; +#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 3, 0) + dev->flags |= IFF_NO_QUEUE; +#else + dev->tx_queue_len = 0; +#endif + dev->features |= NETIF_F_LLTX; + dev->features |= WG_FEATURES; + dev->hw_features |= WG_FEATURES; + dev->hw_enc_features |= WG_FEATURES; + dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); + + /* We need to keep the dst around in case of icmp replies. */ + netif_keep_dst(dev); + + memset(wg, 0, sizeof(struct wireguard_device)); +} + +static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[]) +{ + int ret = 0; + struct wireguard_device *wg = netdev_priv(dev); + + wg->creating_net = get_net(src_net); + init_rwsem(&wg->static_identity.lock); + mutex_init(&wg->socket_update_lock); + mutex_init(&wg->device_update_lock); + skb_queue_head_init(&wg->incoming_handshakes); + INIT_WORK(&wg->incoming_handshakes_work, packet_process_queued_handshake_packets); + pubkey_hashtable_init(&wg->peer_hashtable); + index_hashtable_init(&wg->index_hashtable); + routing_table_init(&wg->peer_routing_table); + INIT_LIST_HEAD(&wg->peer_list); + + wg->workqueue = alloc_workqueue(KBUILD_MODNAME "-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); + if (!wg->workqueue) { + ret = -ENOMEM; + goto err; + } + +#ifdef CONFIG_WIREGUARD_PARALLEL + wg->parallelqueue = alloc_workqueue(KBUILD_MODNAME "-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 1, dev->name); + if (!wg->parallelqueue) { + ret = -ENOMEM; + goto err; + } + + wg->parallel_send = padata_alloc_possible(wg->parallelqueue); + if (!wg->parallel_send) { + ret = -ENOMEM; + goto err; + } + padata_start(wg->parallel_send); + + wg->parallel_receive = padata_alloc_possible(wg->parallelqueue); + if (!wg->parallel_receive) { + ret = -ENOMEM; + goto err; + } + padata_start(wg->parallel_receive); +#endif + + ret = cookie_checker_init(&wg->cookie_checker, wg); + if (ret < 0) + goto err; + + ret = register_netdevice(dev); + if (ret < 0) + goto err; + + pr_debug("Device %s has been created\n", dev->name); + + return 0; + +err: + put_net(src_net); + if (wg->workqueue) + destroy_workqueue(wg->workqueue); +#ifdef CONFIG_WIREGUARD_PARALLEL + if (wg->parallelqueue) + destroy_workqueue(wg->parallelqueue); + if (wg->parallel_send) + padata_free(wg->parallel_send); + if (wg->parallel_receive) + padata_free(wg->parallel_receive); +#endif + if (wg->cookie_checker.device) + cookie_checker_uninit(&wg->cookie_checker); + return ret; +} + +static void dellink(struct net_device *dev, struct list_head *head) +{ + unregister_netdevice_queue(dev, head); +} + +static struct rtnl_link_ops link_ops __read_mostly = { + .kind = KBUILD_MODNAME, + .priv_size = sizeof(struct wireguard_device), + .setup = setup, + .newlink = newlink, + .dellink = dellink +}; + +int device_init(void) +{ + int ret = rtnl_link_register(&link_ops); + if (ret < 0) { + pr_err("Cannot register link_ops\n"); + return ret; + } + return ret; +} + +void device_uninit(void) +{ + rtnl_link_unregister(&link_ops); + rcu_barrier(); +} |