From 9fe1150c42988ff788e8a51425736b20d24833b9 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Tue, 23 May 2017 14:14:21 +0200 Subject: handshake: process in parallel --- src/device.c | 55 ++++++++++++++++++++++++++++++++++++------------------- src/device.h | 11 +++++++++-- src/messages.h | 1 - src/peer.c | 4 ++-- src/receive.c | 21 ++++++++++++--------- src/send.c | 2 +- src/timers.c | 2 +- 7 files changed, 61 insertions(+), 35 deletions(-) diff --git a/src/device.c b/src/device.c index 5d65277..e10aeed 100644 --- a/src/device.c +++ b/src/device.c @@ -237,7 +237,8 @@ static void destruct(struct net_device *dev) mutex_lock(&wg->device_update_lock); peer_remove_all(wg); wg->incoming_port = 0; - destroy_workqueue(wg->handshake_wq); + destroy_workqueue(wg->incoming_handshake_wq); + destroy_workqueue(wg->peer_wq); #ifdef CONFIG_WIREGUARD_PARALLEL padata_free(wg->encrypt_pd); padata_free(wg->decrypt_pd); @@ -253,7 +254,7 @@ static void destruct(struct net_device *dev) #endif mutex_unlock(&wg->device_update_lock); free_percpu(dev->tstats); - + free_percpu(wg->incoming_handshakes_worker); put_net(wg->creating_net); pr_debug("Device %s has been deleted\n", dev->name); @@ -292,7 +293,7 @@ static void setup(struct net_device *dev) static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[]) { - int ret = -ENOMEM; + int ret = -ENOMEM, cpu; struct wireguard_device *wg = netdev_priv(dev); wg->creating_net = get_net(src_net); @@ -300,7 +301,6 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); skb_queue_head_init(&wg->incoming_handshakes); - INIT_WORK(&wg->incoming_handshakes_work, packet_process_queued_handshake_packets); pubkey_hashtable_init(&wg->peer_hashtable); index_hashtable_init(&wg->index_hashtable); routing_table_init(&wg->peer_routing_table); @@ -310,61 +310,78 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t if (!dev->tstats) goto error_1; - wg->handshake_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); - if (!wg->handshake_wq) + wg->incoming_handshakes_worker = alloc_percpu(struct handshake_worker); + if (!wg->incoming_handshakes_worker) goto error_2; + for_each_possible_cpu(cpu) { + per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->wg = wg; + INIT_WORK(&per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work, packet_process_queued_handshake_packets); + } + atomic_set(&wg->incoming_handshake_seqnr, 0); + + wg->incoming_handshake_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); + if (!wg->incoming_handshake_wq) + goto error_3; + + wg->peer_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); + if (!wg->peer_wq) + goto error_4; #ifdef CONFIG_WIREGUARD_PARALLEL wg->crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 2, dev->name); if (!wg->crypt_wq) - goto error_3; + goto error_5; wg->encrypt_pd = padata_alloc_possible(wg->crypt_wq); if (!wg->encrypt_pd) - goto error_4; + goto error_6; padata_start(wg->encrypt_pd); wg->decrypt_pd = padata_alloc_possible(wg->crypt_wq); if (!wg->decrypt_pd) - goto error_5; + goto error_7; padata_start(wg->decrypt_pd); #endif ret = cookie_checker_init(&wg->cookie_checker, wg); if (ret < 0) - goto error_6; + goto error_8; #ifdef CONFIG_PM_SLEEP wg->clear_peers_on_suspend.notifier_call = suspending_clear_noise_peers; ret = register_pm_notifier(&wg->clear_peers_on_suspend); if (ret < 0) - goto error_7; + goto error_9; #endif ret = register_netdevice(dev); if (ret < 0) - goto error_8; + goto error_10; pr_debug("Device %s has been created\n", dev->name); return 0; -error_8: +error_10: #ifdef CONFIG_PM_SLEEP unregister_pm_notifier(&wg->clear_peers_on_suspend); -error_7: +error_9: #endif cookie_checker_uninit(&wg->cookie_checker); -error_6: +error_8: #ifdef CONFIG_WIREGUARD_PARALLEL padata_free(wg->decrypt_pd); -error_5: +error_7: padata_free(wg->encrypt_pd); -error_4: +error_6: destroy_workqueue(wg->crypt_wq); -error_3: +error_5: #endif - destroy_workqueue(wg->handshake_wq); + destroy_workqueue(wg->peer_wq); +error_4: + destroy_workqueue(wg->incoming_handshake_wq); +error_3: + free_percpu(wg->incoming_handshakes_worker); error_2: free_percpu(dev->tstats); error_1: diff --git a/src/device.h b/src/device.h index cdfb5f7..f443191 100644 --- a/src/device.h +++ b/src/device.h @@ -16,15 +16,22 @@ #include #include +struct wireguard_device; +struct handshake_worker { + struct wireguard_device *wg; + struct work_struct work; +}; + struct wireguard_device { struct sock __rcu *sock4, *sock6; u16 incoming_port; u32 fwmark; struct net *creating_net; - struct workqueue_struct *handshake_wq; struct noise_static_identity static_identity; + struct workqueue_struct *incoming_handshake_wq, *peer_wq; struct sk_buff_head incoming_handshakes; - struct work_struct incoming_handshakes_work; + atomic_t incoming_handshake_seqnr; + struct handshake_worker __percpu *incoming_handshakes_worker; struct cookie_checker cookie_checker; struct pubkey_hashtable peer_hashtable; struct index_hashtable index_hashtable; diff --git a/src/messages.h b/src/messages.h index defc831..6119cd5 100644 --- a/src/messages.h +++ b/src/messages.h @@ -50,7 +50,6 @@ enum limits { KEEPALIVE_TIMEOUT = 10 * HZ, MAX_TIMER_HANDSHAKES = (90 * HZ) / REKEY_TIMEOUT, MAX_QUEUED_INCOMING_HANDSHAKES = 4096, - MAX_BURST_INCOMING_HANDSHAKES = 16, MAX_QUEUED_OUTGOING_PACKETS = 1024 }; diff --git a/src/peer.c b/src/peer.c index b6ead5e..cc84ce6 100644 --- a/src/peer.c +++ b/src/peer.c @@ -77,8 +77,8 @@ void peer_remove(struct wireguard_peer *peer) timers_uninit_peer(peer); routing_table_remove_by_peer(&peer->device->peer_routing_table, peer); pubkey_hashtable_remove(&peer->device->peer_hashtable, peer); - if (peer->device->handshake_wq) - flush_workqueue(peer->device->handshake_wq); + if (peer->device->peer_wq) + flush_workqueue(peer->device->peer_wq); skb_queue_purge(&peer->tx_packet_queue); peer_put(peer); } diff --git a/src/receive.c b/src/receive.c index ed146d1..c5d0b12 100644 --- a/src/receive.c +++ b/src/receive.c @@ -103,7 +103,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff return; } - under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 2; + under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 8; mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load); if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) packet_needs_cookie = false; @@ -171,17 +171,13 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff void packet_process_queued_handshake_packets(struct work_struct *work) { - struct wireguard_device *wg = container_of(work, struct wireguard_device, incoming_handshakes_work); + struct wireguard_device *wg = container_of(work, struct handshake_worker, work)->wg; struct sk_buff *skb; - size_t num_processed = 0; while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { receive_handshake_packet(wg, skb); dev_kfree_skb(skb); - if (++num_processed == MAX_BURST_INCOMING_HANDSHAKES) { - queue_work(wg->handshake_wq, &wg->incoming_handshakes_work); - return; - } + cond_resched(); } } @@ -291,15 +287,22 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) switch (message_type) { case MESSAGE_HANDSHAKE_INITIATION: case MESSAGE_HANDSHAKE_RESPONSE: - case MESSAGE_HANDSHAKE_COOKIE: + case MESSAGE_HANDSHAKE_COOKIE: { + int cpu_index, cpu, target_cpu; if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) { net_dbg_skb_ratelimited("Too many handshakes queued, dropping packet from %pISpfsc\n", skb); goto err; } skb_queue_tail(&wg->incoming_handshakes, skb); + /* Select the CPU in a round-robin */ + cpu_index = ((unsigned int)atomic_inc_return(&wg->incoming_handshake_seqnr)) % cpumask_weight(cpu_online_mask); + target_cpu = cpumask_first(cpu_online_mask); + for (cpu = 0; cpu < cpu_index; ++cpu) + target_cpu = cpumask_next(target_cpu, cpu_online_mask); /* Queues up a call to packet_process_queued_handshake_packets(skb): */ - queue_work(wg->handshake_wq, &wg->incoming_handshakes_work); + queue_work_on(target_cpu, wg->incoming_handshake_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, target_cpu)->work); break; + } case MESSAGE_DATA: PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb); packet_consume_data(skb, wg); diff --git a/src/send.c b/src/send.c index a952da7..72b2854 100644 --- a/src/send.c +++ b/src/send.c @@ -56,7 +56,7 @@ void packet_queue_handshake_initiation(struct wireguard_peer *peer) return; /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */ - if (!queue_work(peer->device->handshake_wq, &peer->transmit_handshake_work)) + if (!queue_work(peer->device->peer_wq, &peer->transmit_handshake_work)) peer_put(peer); /* If the work was already queued, we want to drop the extra reference */ } diff --git a/src/timers.c b/src/timers.c index 92c7c78..fda7b93 100644 --- a/src/timers.c +++ b/src/timers.c @@ -74,7 +74,7 @@ static void expired_new_handshake(unsigned long ptr) static void expired_kill_ephemerals(unsigned long ptr) { peer_get_from_ptr(ptr); - if (!queue_work(peer->device->handshake_wq, &peer->clear_peer_work)) /* Takes our reference. */ + if (!queue_work(peer->device->peer_wq, &peer->clear_peer_work)) /* Takes our reference. */ peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */ } static void queued_expired_kill_ephemerals(struct work_struct *work) -- cgit v1.2.3-59-g8ed1b