/* Copyright 2015-2016 Jason A. Donenfeld . All Rights Reserved. */

#include "wireguard.h"
#include "packets.h"
#include "socket.h"
#include "timers.h"
#include "device.h"
#include "config.h"
#include "peer.h"
#include "uapi.h"
#include "messages.h"

/* NOTE(review): the names of the eleven system headers were stripped from
 * this copy of the file ("#include #include ..."). The list below is
 * reconstructed from the kernel symbols used in this file -- confirm it
 * against the upstream source.
 */
#include <linux/module.h>
#include <linux/version.h>
#include <linux/netdevice.h>
#include <linux/rtnetlink.h>
#include <linux/if_arp.h>
#include <linux/icmp.h>
#include <linux/icmpv6.h>
#include <net/icmp.h>
#include <net/rtnetlink.h>
#include <net/ip_tunnels.h>
#include <net/netfilter/nf_conntrack.h>

/* Soft cap on a peer's tx_packet_queue; see the trimming loop in xmit(). */
#define MAX_QUEUED_PACKETS 1024

/* ndo_init: allocate the per-CPU software netstats stored in dev->tstats.
 *
 * Returns 0 on success or -ENOMEM if the allocation fails.
 */
static int init(struct net_device *dev)
{
	dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
	if (!dev->tstats)
		return -ENOMEM;
	return 0;
}

/* ndo_uninit: release the per-CPU stats allocated by init(). */
static void uninit(struct net_device *dev)
{
	free_percpu(dev->tstats);
}

/* peer_for_each() callback run when the interface is brought up.
 *
 * Refreshes the peer's cached destination, starts its timers, flushes any
 * packets queued while the interface was down, and sends an immediate
 * keepalive when a persistent-keepalive interval is configured.
 * Always returns 0.
 */
static int open_peer(struct wireguard_peer *peer, void *data)
{
	socket_set_peer_dst(peer);
	timers_init_peer(peer);
	packet_send_queue(peer);
	if (peer->persistent_keepalive_interval)
		packet_send_keepalive(peer);
	return 0;
}

/* ndo_open: create the device's socket(s), then bring every peer up.
 *
 * Returns the negative error from socket_init() on failure, else 0.
 */
static int open(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);
	int rc = socket_init(wg);

	if (rc < 0)
		return rc;
	peer_for_each(wg, open_peer, NULL);
	return 0;
}

/* peer_for_each() callback run when the interface is brought down.
 *
 * Stops the peer's timers (waiting for them) and clears its handshake
 * state and keypairs. Always returns 0.
 */
static int stop_peer(struct wireguard_peer *peer, void *data)
{
	timers_uninit_peer_wait(peer);
	noise_handshake_clear(&peer->handshake);
	noise_keypairs_clear(&peer->keypairs);
	return 0;
}

/* ndo_stop: tear down every peer, drop queued handshake packets, and
 * close the device's socket(s). Always returns 0.
 */
static int stop(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);

	peer_for_each(wg, stop_peer, NULL);
	skb_queue_purge(&wg->incoming_handshakes);
	socket_uninit(wg);
	return 0;
}

/* Account a transmit error, send an ICMP/ICMPv6 "unreachable" back to the
 * sender when the packet carries a usable IP header, and free @skb.
 * The caller gives up its reference to @skb.
 */
static void skb_unsendable(struct sk_buff *skb, struct net_device *dev)
{
	/* This conntrack stuff is because the rate limiting needs to be applied
	 * to the original src IP, so we have to restore saddr in the IP header.
	 */
	struct nf_conn *ct = NULL;
#if defined(CONFIG_NF_CONNTRACK) || defined(CONFIG_NF_CONNTRACK_MODULE)
	enum ip_conntrack_info ctinfo;
	ct = nf_ct_get(skb, &ctinfo);
#endif
	++dev->stats.tx_errors;

	if (skb->len < sizeof(struct iphdr))
		goto free;

	if (ip_hdr(skb)->version == 4) {
		if (ct)
			ip_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.ip;
		icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0);
	} else if (ip_hdr(skb)->version == 6) {
		/* The check above only guarantees an IPv4-sized header; the IPv6
		 * header is twice that, and both the saddr rewrite below and
		 * icmpv6_send() read/write past the first 20 bytes.
		 */
		if (skb->len < sizeof(struct ipv6hdr))
			goto free;
		if (ct)
			ipv6_hdr(skb)->saddr = ct->tuplehash[0].tuple.src.u3.in6;
		icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0);
	}

free:
	kfree_skb(skb);
}

/* ndo_start_xmit: route @skb to a peer by destination address, split it
 * into segments if it is GSO, queue the segments on the peer, and ask the
 * peer to send its queue.
 *
 * Every path consumes @skb. On the drop paths a negative errno is
 * returned (-ELOOP, -ENOKEY, -EHOSTUNREACH, or the skb_gso_segment()
 * error); otherwise the return value of packet_send_queue() is passed
 * through.
 */
static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);
	struct wireguard_peer *peer;
	bool has_endpoint;
	int ret;

	if (unlikely(dev_recursion_level() > 4)) {
		net_dbg_ratelimited("Routing loop detected\n");
		skb_unsendable(skb, dev);
		return -ELOOP;
	}

	/* The lookup hands back a referenced peer, released with peer_put(). */
	peer = routing_table_lookup_dst(&wg->peer_routing_table, skb);
	if (unlikely(!peer)) {
		skb_unsendable(skb, dev);
		return -ENOKEY;
	}

	read_lock_bh(&peer->endpoint_lock);
	has_endpoint = peer->endpoint_addr.ss_family == AF_INET ||
		       peer->endpoint_addr.ss_family == AF_INET6;
	read_unlock_bh(&peer->endpoint_lock);
	if (unlikely(!has_endpoint)) {
		net_dbg_ratelimited("No valid endpoint has been configured or discovered for device\n");
		peer_put(peer);
		skb_unsendable(skb, dev);
		return -EHOSTUNREACH;
	}

	/* If the queue is getting too big, we start removing the oldest packets
	 * until it's small again. We do this before adding the new packet, so we
	 * don't remove GSO segments that are in excess.
	 */
	while (skb_queue_len(&peer->tx_packet_queue) > MAX_QUEUED_PACKETS)
		dev_kfree_skb(skb_dequeue(&peer->tx_packet_queue));

	if (!skb_is_gso(skb)) {
		/* Treat a lone packet as a one-element segment list. */
		skb->next = NULL;
	} else {
		struct sk_buff *segs = skb_gso_segment(skb, 0);

		if (unlikely(IS_ERR(segs))) {
			skb_unsendable(skb, dev);
			peer_put(peer);
			return PTR_ERR(segs);
		}
		/* The segments are independent skbs; the GSO super-packet is done. */
		dev_kfree_skb(skb);
		skb = segs;
	}

	while (skb) {
		struct sk_buff *next = skb->next;

		skb->next = skb->prev = NULL;

		skb = skb_share_check(skb, GFP_ATOMIC);
		if (unlikely(!skb)) {
			/* skb_share_check() has already dropped our reference to
			 * this segment. Advance to the next one rather than
			 * letting the NULL end the loop, which would leak every
			 * remaining segment.
			 */
			skb = next;
			continue;
		}

		/* We only need to keep the original dst around for icmp,
		 * so at this point we're in a position to drop it.
		 */
		skb_dst_drop(skb);

		skb_queue_tail(&peer->tx_packet_queue, skb);
		skb = next;
	}

	ret = packet_send_queue(peer);
	peer_put(peer);
	return ret;
}

/* ndo_do_ioctl: WG_GET_DEVICE / WG_SET_DEVICE configuration entry point.
 *
 * Requires CAP_NET_ADMIN in the device's network namespace (-EPERM
 * otherwise); unknown commands return -EINVAL.
 */
static int ioctl(struct net_device *dev, struct ifreq *ifr, int cmd)
{
	struct wireguard_device *wg = netdev_priv(dev);

	if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN))
		return -EPERM;

	switch (cmd) {
	case WG_GET_DEVICE:
		return config_get_device(wg, ifr->ifr_ifru.ifru_data);
	case WG_SET_DEVICE:
		return config_set_device(wg, ifr->ifr_ifru.ifru_data);
	default:
		return -EINVAL;
	}
}

static const struct net_device_ops netdev_ops = {
	.ndo_init		= init,
	.ndo_uninit		= uninit,
	.ndo_open		= open,
	.ndo_stop		= stop,
	.ndo_start_xmit		= xmit,
	.ndo_get_stats64	= ip_tunnel_get_stats64,
	.ndo_do_ioctl		= ioctl
};

/* net_device destructor (installed in setup()): undo everything newlink()
 * set up, then free the netdev itself. Teardown runs under
 * device_update_lock.
 */
static void destruct(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);

	mutex_lock(&wg->device_update_lock);
	peer_remove_all(wg);
	wg->incoming_port = 0;
	destroy_workqueue(wg->workqueue);
#ifdef CONFIG_WIREGUARD_PARALLEL
	padata_free(wg->parallel_send);
	padata_free(wg->parallel_receive);
	destroy_workqueue(wg->parallelqueue);
#endif
	routing_table_free(&wg->peer_routing_table);
	/* Wipe the static identity with a memset the compiler may not elide. */
	memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
	skb_queue_purge(&wg->incoming_handshakes);
	socket_uninit(wg);
	cookie_checker_uninit(&wg->cookie_checker);
	mutex_unlock(&wg->device_update_lock);
	/* Drops the reference taken with get_net() in newlink(). */
	put_net(wg->creating_net);

	pr_debug("Device %s has been deleted\n", dev->name);
	free_netdev(dev);
}

#define WG_FEATURES (NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA)

/* rtnl_link_ops.setup: configure @dev as a headerless, ARP-less
 * point-to-point device and zero the private wireguard_device area
 * (newlink()'s error path relies on those zeroed fields).
 */
static void setup(struct net_device *dev)
{
	struct wireguard_device *wg = netdev_priv(dev);

	dev->netdev_ops = &netdev_ops;
	dev->destructor = destruct;
	dev->hard_header_len = 0;
	dev->addr_len = 0;
	dev->needed_headroom = DATA_PACKET_HEAD_ROOM;
	dev->needed_tailroom = noise_encrypted_len(MESSAGE_PADDING_MULTIPLE);
	dev->type = ARPHRD_NONE;
	dev->flags = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST;
#if LINUX_VERSION_CODE >= KERNEL_VERSION(4, 3, 0)
	dev->flags |= IFF_NO_QUEUE;
#else
	dev->tx_queue_len = 0;
#endif
	dev->features |= NETIF_F_LLTX;
	dev->features |= WG_FEATURES;
	dev->hw_features |= WG_FEATURES;
	dev->hw_enc_features |= WG_FEATURES;
	/* Leave room in a standard Ethernet MTU for the WireGuard message
	 * overhead plus UDP and the larger of the two outer IP headers.
	 */
	dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr));

	/* We need to keep the dst around in case of icmp replies. */
	netif_keep_dst(dev);

	memset(wg, 0, sizeof(struct wireguard_device));
}

/* rtnl_link_ops.newlink: initialize the device's locks, queues, tables and
 * workqueues (plus padata instances with CONFIG_WIREGUARD_PARALLEL) and the
 * cookie checker, then register the netdev.
 *
 * Returns 0 on success or a negative errno. On failure everything
 * allocated here is released; the NULL checks in the error path depend on
 * setup() having zeroed the private area.
 */
static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[])
{
	int ret = 0;
	struct wireguard_device *wg = netdev_priv(dev);

	wg->creating_net = get_net(src_net);
	init_rwsem(&wg->static_identity.lock);
	mutex_init(&wg->socket_update_lock);
	mutex_init(&wg->device_update_lock);
	skb_queue_head_init(&wg->incoming_handshakes);
	INIT_WORK(&wg->incoming_handshakes_work, packet_process_queued_handshake_packets);
	pubkey_hashtable_init(&wg->peer_hashtable);
	index_hashtable_init(&wg->index_hashtable);
	routing_table_init(&wg->peer_routing_table);
	INIT_LIST_HEAD(&wg->peer_list);

	wg->workqueue = alloc_workqueue(KBUILD_MODNAME "-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
	if (!wg->workqueue) {
		ret = -ENOMEM;
		goto err;
	}

#ifdef CONFIG_WIREGUARD_PARALLEL
	wg->parallelqueue = alloc_workqueue(KBUILD_MODNAME "-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 1, dev->name);
	if (!wg->parallelqueue) {
		ret = -ENOMEM;
		goto err;
	}

	wg->parallel_send = padata_alloc_possible(wg->parallelqueue);
	if (!wg->parallel_send) {
		ret = -ENOMEM;
		goto err;
	}
	padata_start(wg->parallel_send);

	wg->parallel_receive = padata_alloc_possible(wg->parallelqueue);
	if (!wg->parallel_receive) {
		ret = -ENOMEM;
		goto err;
	}
	padata_start(wg->parallel_receive);
#endif

	ret = cookie_checker_init(&wg->cookie_checker, wg);
	if (ret < 0)
		goto err;

	ret = register_netdevice(dev);
	if (ret < 0)
		goto err;

	pr_debug("Device %s has been created\n", dev->name);

	return 0;

err:
	/* Balances the get_net() at the top of this function. */
	put_net(src_net);
	if (wg->workqueue)
		destroy_workqueue(wg->workqueue);
#ifdef CONFIG_WIREGUARD_PARALLEL
	if (wg->parallel_send)
		padata_free(wg->parallel_send);
	if (wg->parallel_receive)
		padata_free(wg->parallel_receive);
	if (wg->parallelqueue)
		destroy_workqueue(wg->parallelqueue);
#endif
	/* NOTE(review): assumes cookie_checker_init() sets ->device only once
	 * it has fully succeeded -- confirm in the cookie checker.
	 */
	if (wg->cookie_checker.device)
		cookie_checker_uninit(&wg->cookie_checker);
	return ret;
}

/* rtnl_link_ops.dellink: queue @dev for unregistration; the remaining
 * teardown happens in destruct().
 */
static void dellink(struct net_device *dev, struct list_head *head)
{
	unregister_netdevice_queue(dev, head);
}

static struct rtnl_link_ops link_ops __read_mostly = {
	.kind		= KBUILD_MODNAME,
	.priv_size	= sizeof(struct wireguard_device),
	.setup		= setup,
	.newlink	= newlink,
	.dellink	= dellink
};

/* Register the link type so devices can be created via rtnetlink.
 * Returns 0 or the negative error from rtnl_link_register().
 */
int device_init(void)
{
	int ret = rtnl_link_register(&link_ops);

	if (ret < 0)
		pr_err("Cannot register link_ops\n");
	return ret;
}

/* Unregister the link type, then wait for pending RCU callbacks to finish. */
void device_uninit(void)
{
	rtnl_link_unregister(&link_ops);
	rcu_barrier();
}