author	Jason A. Donenfeld <Jason@zx2c4.com>	2017-06-07 01:39:08 -0500
committer	Jason A. Donenfeld <Jason@zx2c4.com>	2017-09-18 17:38:16 +0200
commit	0bc7c9d057d137b72c54d2da7fca522d36128f6a (patch)
tree	44b70fb62507849b1e4ef9a2a78bddf9a108165e /src/receive.c
parent	compat: ensure we can build without compat.h (diff)
queue: entirely rework parallel system
This removes our dependency on padata and moves to a different mode of multiprocessing that is more efficient. This began as Samuel Holland's GSoC project and was gradually reworked/redesigned/rebased into this present commit, which is a combination of his initial contribution and my subsequent rewriting and redesigning.
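The ordering guarantee the new queues provide can be sketched in plain userspace C. The following is a minimal illustration of the pattern only, not code from the tree; all names are invented. Entries finish their parallel stage in arbitrary order, but a serialized per-peer drain stops at the first unfinished entry, so delivery order always equals submission order:

/* Userspace sketch of the two-stage pipeline: a parallel stage marks
 * entries finished in any order; the serialized drain delivers from
 * the head only while the head is finished. */
#include <stdbool.h>
#include <stdio.h>

#define QUEUE_LEN 8

struct ctx {
	int seq;	/* submission order */
	bool finished;	/* set by the parallel (decrypt) stage */
};

static struct ctx ring[QUEUE_LEN];
static int head, tail;

static void submit(int seq)
{
	ring[tail++ % QUEUE_LEN] = (struct ctx){ .seq = seq };
}

/* Analogous to the per-peer rx worker: never delivers out of order. */
static void drain(void)
{
	while (head != tail && ring[head % QUEUE_LEN].finished)
		printf("delivered %d\n", ring[head++ % QUEUE_LEN].seq);
}

int main(void)
{
	for (int i = 0; i < 4; ++i)
		submit(i);
	ring[1].finished = true;	/* packet 1 finishes first... */
	drain();			/* ...but nothing is delivered yet */
	ring[0].finished = true;
	drain();			/* delivers 0, then 1 */
	ring[2].finished = ring[3].finished = true;
	drain();			/* delivers 2, then 3 */
	return 0;
}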
Diffstat (limited to 'src/receive.c')
-rw-r--r--	src/receive.c	| 169
1 file changed, 155 insertions(+), 14 deletions(-)
diff --git a/src/receive.c b/src/receive.c
index da229df..a7f6004 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -1,11 +1,12 @@
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-#include "packets.h"
+#include "queueing.h"
#include "device.h"
#include "peer.h"
#include "timers.h"
#include "messages.h"
#include "cookie.h"
+#include "socket.h"
#include <linux/ip.h>
#include <linux/ipv6.h>
@@ -145,9 +146,9 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
peer_put(peer);
}
-void packet_process_queued_handshake_packets(struct work_struct *work)
+void packet_handshake_receive_worker(struct work_struct *work)
{
- struct wireguard_device *wg = container_of(work, struct handshake_worker, work)->wg;
+ struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr;
struct sk_buff *skb;
while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
@@ -173,10 +174,74 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
if (send) {
peer->sent_lastminute_handshake = true;
- packet_queue_handshake_initiation(peer, false);
+ packet_send_queued_handshake_initiation(peer, false);
}
}
+static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key)
+{
+ struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
+ struct sk_buff *trailer;
+ int num_frags;
+
+ if (unlikely(!key))
+ return false;
+
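+ /* Enforce key lifetime limits: a key older than REJECT_AFTER_TIME, or
+ * one that has already received REJECT_AFTER_MESSAGES packets, is
+ * marked invalid so later lookups fail fast. */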
+ if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
+ key->is_valid = false;
+ return false;
+ }
+
+ PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
+ skb_pull(skb, sizeof(struct message_data));
+ num_frags = skb_cow_data(skb, 0, &trailer);
+ if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
+ return false;
+
+ sg_init_table(sg, num_frags);
+ if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
+ return false;
+
+ if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
+ return false;
+
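+ /* Decryption happened in place; strip the trailing authentication tag
+ * (noise_encrypted_len(0) bytes) from the skb. */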
+ return !pskb_trim(skb, skb->len - noise_encrypted_len(0));
+}
+
+/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
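+/* Worked example: if the highest counter accepted so far is 1000, then 1001
+ * slides the window forward, 990 is accepted exactly once (its bit lies
+ * inside the window and is not yet set), a replayed 990 fails
+ * test_and_set_bit, and anything at or below 1000 - COUNTER_WINDOW_SIZE is
+ * rejected outright. */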
+static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
+{
+ bool ret = false;
+ unsigned long index, index_current, top, i;
+ spin_lock_bh(&counter->receive.lock);
+
+ if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES))
+ goto out;
+
+ ++their_counter;
+
+ if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter))
+ goto out;
+
+ index = their_counter >> ilog2(BITS_PER_LONG);
+
+ if (likely(their_counter > counter->receive.counter)) {
+ index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+ top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG);
+ for (i = 1; i <= top; ++i)
+ counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
+ counter->receive.counter = their_counter;
+ }
+
+ index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
+ ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]);
+
+out:
+ spin_unlock_bh(&counter->receive.lock);
+ return ret;
+}
+#include "selftest/counter.h"
+
void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key)
{
struct net_device *dev = peer->device->dev;
@@ -187,7 +252,7 @@ void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer,
if (unlikely(used_new_key)) {
timers_handshake_complete(peer);
- packet_send_queue(peer);
+ packet_send_staged_packets(peer);
}
keep_key_fresh(peer);
@@ -262,7 +327,87 @@ packet_processed:
continue_processing:
timers_any_authenticated_packet_received(peer);
timers_any_authenticated_packet_traversal(peer);
- peer_put(peer);
+}
+
+void packet_rx_worker(struct work_struct *work)
+{
+ struct crypt_ctx *ctx;
+ struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+ struct sk_buff *skb;
+
+ local_bh_disable();
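+ /* Deliver strictly in submission order: stop at the first ctx whose
+ * parallel decryption has not finished yet. */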
+ while ((ctx = queue_first_per_peer(queue)) != NULL && atomic_read(&ctx->is_finished)) {
+ queue_dequeue(queue);
+ if (likely((skb = ctx->skb) != NULL)) {
+ if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(skb)->nonce))) {
+ skb_reset(skb);
+ packet_consume_data_done(skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
+ } else {
+ net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
+ dev_kfree_skb(skb);
+ }
+ }
+ noise_keypair_put(ctx->keypair);
+ peer_put(ctx->peer);
+ kmem_cache_free(crypt_ctx_cache, ctx);
+ }
+ local_bh_enable();
+}
+
+void packet_decrypt_worker(struct work_struct *work)
+{
+ struct crypt_ctx *ctx;
+ struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
+ struct wireguard_peer *peer;
+
+ while ((ctx = queue_dequeue_per_device(queue)) != NULL) {
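+ /* If endpoint parsing or decryption fails, free the skb but keep the
+ * ctx: it must still flow through the peer's rx queue (with skb set
+ * to NULL) so that per-peer completion order is preserved. */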
+ if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
+ dev_kfree_skb(ctx->skb);
+ ctx->skb = NULL;
+ }
+ /* Dereferencing ctx is unsafe once ctx->is_finished == true, so
+ * we take a reference here first. */
+ peer = peer_rcu_get(ctx->peer);
+ atomic_set(&ctx->is_finished, true);
+ queue_work_on(choose_cpu(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->rx_queue.work);
+ peer_put(peer);
+ }
+}
+
+static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
+{
+ struct crypt_ctx *ctx;
+ struct noise_keypair *keypair;
+ __le32 idx = ((struct message_data *)skb->data)->key_idx;
+
+ rcu_read_lock_bh();
+ keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
+ rcu_read_unlock_bh();
+ if (unlikely(!keypair)) {
+ dev_kfree_skb(skb);
+ return;
+ }
+
+ ctx = kmem_cache_zalloc(crypt_ctx_cache, GFP_ATOMIC);
+ if (unlikely(!ctx)) {
+ dev_kfree_skb(skb);
+ peer_put(keypair->entry.peer);
+ noise_keypair_put(keypair);
+ return;
+ }
+ ctx->keypair = keypair;
+ ctx->skb = skb;
+ /* We already have a reference to peer from index_hashtable_lookup. */
+ ctx->peer = ctx->keypair->entry.peer;
+
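+ /* Enqueue on both queues at once: the device-wide decrypt queue feeds
+ * the parallel workers, while the peer's rx queue preserves per-peer
+ * delivery order. */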
+ if (likely(queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &ctx->peer->rx_queue, ctx, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu)))
+ return; /* Successful. No need to drop references below. */
+
+ noise_keypair_put(ctx->keypair);
+ peer_put(ctx->peer);
+ dev_kfree_skb(ctx->skb);
+ kmem_cache_free(crypt_ctx_cache, ctx);
}
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
@@ -274,24 +419,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
case MESSAGE_HANDSHAKE_INITIATION:
case MESSAGE_HANDSHAKE_RESPONSE:
case MESSAGE_HANDSHAKE_COOKIE: {
- int cpu_index, cpu, target_cpu;
+ int cpu;
if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) {
net_dbg_skb_ratelimited("%s: Too many handshakes queued, dropping packet from %pISpfsc\n", wg->dev->name, skb);
goto err;
}
skb_queue_tail(&wg->incoming_handshakes, skb);
- /* Select the CPU in a round-robin */
- cpu_index = ((unsigned int)atomic_inc_return(&wg->incoming_handshake_seqnr)) % cpumask_weight(cpu_online_mask);
- target_cpu = cpumask_first(cpu_online_mask);
- for (cpu = 0; cpu < cpu_index; ++cpu)
- target_cpu = cpumask_next(target_cpu, cpu_online_mask);
/* Queues up a call to packet_handshake_receive_worker(skb): */
- queue_work_on(target_cpu, wg->incoming_handshake_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, target_cpu)->work);
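+ /* Round-robin over online CPUs with a persistent per-device cursor,
+ * replacing the atomic-seqnr cpumask walk removed above. */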
+ cpu = cpumask_next_online(&wg->incoming_handshake_cpu);
+ queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
break;
}
case MESSAGE_DATA:
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
- packet_consume_data(skb, wg);
+ packet_consume_data(wg, skb);
break;
default:
net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb);