diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-06-07 01:39:08 -0500 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-09-18 17:38:16 +0200 |
commit | 0bc7c9d057d137b72c54d2da7fca522d36128f6a (patch) | |
tree | 44b70fb62507849b1e4ef9a2a78bddf9a108165e /src/receive.c | |
parent | compat: ensure we can build without compat.h (diff) | |
download | wireguard-monolithic-historical-0bc7c9d057d137b72c54d2da7fca522d36128f6a.tar.xz wireguard-monolithic-historical-0bc7c9d057d137b72c54d2da7fca522d36128f6a.zip |
queue: entirely rework parallel system
This removes our dependency on padata and moves to a different mode of
multiprocessing that is more efficient.
This began as Samuel Holland's GSoC project and was gradually
reworked/redesigned/rebased into this present commit, which is a
combination of his initial contribution and my subsequent rewriting and
redesigning.
Diffstat (limited to 'src/receive.c')
-rw-r--r-- | src/receive.c | 169 |
1 file changed, 155 insertions, 14 deletions
diff --git a/src/receive.c b/src/receive.c index da229df..a7f6004 100644 --- a/src/receive.c +++ b/src/receive.c @@ -1,11 +1,12 @@ /* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ -#include "packets.h" +#include "queueing.h" #include "device.h" #include "peer.h" #include "timers.h" #include "messages.h" #include "cookie.h" +#include "socket.h" #include <linux/ip.h> #include <linux/ipv6.h> @@ -145,9 +146,9 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff peer_put(peer); } -void packet_process_queued_handshake_packets(struct work_struct *work) +void packet_handshake_receive_worker(struct work_struct *work) { - struct wireguard_device *wg = container_of(work, struct handshake_worker, work)->wg; + struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr; struct sk_buff *skb; while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { @@ -173,10 +174,74 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) if (send) { peer->sent_lastminute_handshake = true; - packet_queue_handshake_initiation(peer, false); + packet_send_queued_handshake_initiation(peer, false); } } +static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key) +{ + struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; + struct sk_buff *trailer; + int num_frags; + + if (unlikely(!key)) + return false; + + if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { + key->is_valid = false; + return false; + } + + PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter); + skb_pull(skb, sizeof(struct message_data)); + num_frags = skb_cow_data(skb, 0, &trailer); + if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg))) + return false; + + sg_init_table(sg, num_frags); + if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0) + return false; + + if 
(!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key)) + return false; + + return !pskb_trim(skb, skb->len - noise_encrypted_len(0)); +} + +/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ +static inline bool counter_validate(union noise_counter *counter, u64 their_counter) +{ + bool ret = false; + unsigned long index, index_current, top, i; + spin_lock_bh(&counter->receive.lock); + + if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES)) + goto out; + + ++their_counter; + + if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter)) + goto out; + + index = their_counter >> ilog2(BITS_PER_LONG); + + if (likely(their_counter > counter->receive.counter)) { + index_current = counter->receive.counter >> ilog2(BITS_PER_LONG); + top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG); + for (i = 1; i <= top; ++i) + counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; + counter->receive.counter = their_counter; + } + + index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; + ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]); + +out: + spin_unlock_bh(&counter->receive.lock); + return ret; +} +#include "selftest/counter.h" + void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key) { struct net_device *dev = peer->device->dev; @@ -187,7 +252,7 @@ void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, if (unlikely(used_new_key)) { timers_handshake_complete(peer); - packet_send_queue(peer); + packet_send_staged_packets(peer); } keep_key_fresh(peer); @@ -262,7 +327,87 @@ packet_processed: continue_processing: timers_any_authenticated_packet_received(peer); timers_any_authenticated_packet_traversal(peer); - peer_put(peer); +} + 
+void packet_rx_worker(struct work_struct *work) +{ + struct crypt_ctx *ctx; + struct crypt_queue *queue = container_of(work, struct crypt_queue, work); + struct sk_buff *skb; + + local_bh_disable(); + while ((ctx = queue_first_per_peer(queue)) != NULL && atomic_read(&ctx->is_finished)) { + queue_dequeue(queue); + if (likely((skb = ctx->skb) != NULL)) { + if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(skb)->nonce))) { + skb_reset(skb); + packet_consume_data_done(skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair)); + } + else { + net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter); + dev_kfree_skb(skb); + } + } + noise_keypair_put(ctx->keypair); + peer_put(ctx->peer); + kmem_cache_free(crypt_ctx_cache, ctx); + } + local_bh_enable(); +} + +void packet_decrypt_worker(struct work_struct *work) +{ + struct crypt_ctx *ctx; + struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; + struct wireguard_peer *peer; + + while ((ctx = queue_dequeue_per_device(queue)) != NULL) { + if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) { + dev_kfree_skb(ctx->skb); + ctx->skb = NULL; + } + /* Dereferencing ctx is unsafe once ctx->is_finished == true, so + * we take a reference here first. 
*/ + peer = peer_rcu_get(ctx->peer); + atomic_set(&ctx->is_finished, true); + queue_work_on(choose_cpu(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->rx_queue.work); + peer_put(peer); + } +} + +static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb) +{ + struct crypt_ctx *ctx; + struct noise_keypair *keypair; + __le32 idx = ((struct message_data *)skb->data)->key_idx; + + rcu_read_lock_bh(); + keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); + rcu_read_unlock_bh(); + if (unlikely(!keypair)) { + dev_kfree_skb(skb); + return; + } + + ctx = kmem_cache_zalloc(crypt_ctx_cache, GFP_ATOMIC); + if (unlikely(!ctx)) { + dev_kfree_skb(skb); + peer_put(ctx->keypair->entry.peer); + noise_keypair_put(keypair); + return; + } + ctx->keypair = keypair; + ctx->skb = skb; + /* We already have a reference to peer from index_hashtable_lookup. */ + ctx->peer = ctx->keypair->entry.peer; + + if (likely(queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &ctx->peer->rx_queue, ctx, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu))) + return; /* Successful. No need to drop references below. 
*/ + + noise_keypair_put(ctx->keypair); + peer_put(ctx->peer); + dev_kfree_skb(ctx->skb); + kmem_cache_free(crypt_ctx_cache, ctx); } void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) @@ -274,24 +419,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) case MESSAGE_HANDSHAKE_INITIATION: case MESSAGE_HANDSHAKE_RESPONSE: case MESSAGE_HANDSHAKE_COOKIE: { - int cpu_index, cpu, target_cpu; + int cpu; if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) { net_dbg_skb_ratelimited("%s: Too many handshakes queued, dropping packet from %pISpfsc\n", wg->dev->name, skb); goto err; } skb_queue_tail(&wg->incoming_handshakes, skb); - /* Select the CPU in a round-robin */ - cpu_index = ((unsigned int)atomic_inc_return(&wg->incoming_handshake_seqnr)) % cpumask_weight(cpu_online_mask); - target_cpu = cpumask_first(cpu_online_mask); - for (cpu = 0; cpu < cpu_index; ++cpu) - target_cpu = cpumask_next(target_cpu, cpu_online_mask); /* Queues up a call to packet_process_queued_handshake_packets(skb): */ - queue_work_on(target_cpu, wg->incoming_handshake_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, target_cpu)->work); + cpu = cpumask_next_online(&wg->incoming_handshake_cpu); + queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); break; } case MESSAGE_DATA: PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb); - packet_consume_data(skb, wg); + packet_consume_data(wg, skb); break; default: net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb); |