author     Jason A. Donenfeld <Jason@zx2c4.com>    2017-06-07 01:39:08 -0500
committer  Jason A. Donenfeld <Jason@zx2c4.com>    2017-09-18 17:38:16 +0200
commit     0bc7c9d057d137b72c54d2da7fca522d36128f6a (patch)
tree       44b70fb62507849b1e4ef9a2a78bddf9a108165e /src
parent     compat: ensure we can build without compat.h (diff)
download   wireguard-monolithic-historical-0bc7c9d057d137b72c54d2da7fca522d36128f6a.tar.xz
           wireguard-monolithic-historical-0bc7c9d057d137b72c54d2da7fca522d36128f6a.zip
queue: entirely rework parallel system
This removes our dependency on padata and moves to a different mode of multiprocessing that is more efficient.

This began as Samuel Holland's GSoC project and was gradually reworked/redesigned/rebased into this present commit, which is a combination of his initial contribution and my subsequent rewriting and redesigning.
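For orientation, the following is a condensed sketch of the new queueing primitives, pieced together from the device.h and queueing.c hunks further down in this diff. The comments are editorial annotations, not part of the committed sources, and the worker functions named in them (packet_encrypt_worker, packet_decrypt_worker, packet_tx_worker, packet_rx_worker) are declared in queueing.h and wired up in the device.c and peer.c hunks.

struct multicore_worker {
	void *ptr;			/* back-pointer to the owning device or queue */
	struct work_struct work;	/* one work item per possible CPU */
};

struct crypt_queue {
	spinlock_t lock;
	struct list_head queue;		/* pending crypt work */
	int len;
	union {
		struct {		/* multicore queues: wg->encrypt_queue, wg->decrypt_queue */
			struct multicore_worker __percpu *worker;
			int last_cpu;	/* cursor used when picking the next CPU to queue onto */
		};
		struct work_struct work;	/* single-consumer queues: peer->tx_queue, peer->rx_queue */
	};
};

int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore)
{
	INIT_LIST_HEAD(&queue->queue);
	queue->len = 0;
	spin_lock_init(&queue->lock);
	if (multicore) {
		/* Parallel stage: every possible CPU gets its own worker running
		 * @function (the encrypt/decrypt step) on packet_crypt_wq. */
		queue->worker = packet_alloc_percpu_multicore_worker(function, queue);
		if (!queue->worker)
			return -ENOMEM;
	} else {
		/* Serial stage: a single work item handles the in-order
		 * send/receive completion step for one peer. */
		INIT_WORK(&queue->work, function);
	}
	return 0;
}

As the hunks below show, newlink() initializes the device-wide queues with packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true), while peer_create() initializes the per-peer serial queues with packet_queue_init(&peer->tx_queue, packet_tx_worker, false).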
Diffstat (limited to 'src')
-rw-r--r--  src/Kbuild                       5
-rw-r--r--  src/Kconfig                     12
-rw-r--r--  src/compat/Kbuild.include        8
-rw-r--r--  src/compat/padata/padata.c     903
-rw-r--r--  src/config.c                     4
-rw-r--r--  src/data.c                     430
-rw-r--r--  src/device.c                    84
-rw-r--r--  src/device.h                    33
-rw-r--r--  src/main.c                      14
-rw-r--r--  src/messages.h                   5
-rw-r--r--  src/noise.c                      2
-rw-r--r--  src/packets.h                   63
-rw-r--r--  src/peer.c                      21
-rw-r--r--  src/peer.h                      15
-rw-r--r--  src/queueing.c                  46
-rw-r--r--  src/queueing.h                 196
-rw-r--r--  src/receive.c                  169
-rw-r--r--  src/send.c                     234
-rw-r--r--  src/socket.c                     2
-rw-r--r--  src/tests/qemu/kernel.config     1
-rw-r--r--  src/timers.c                    19
21 files changed, 669 insertions, 1597 deletions
diff --git a/src/Kbuild b/src/Kbuild
index c5b8718..e6c7bc9 100644
--- a/src/Kbuild
+++ b/src/Kbuild
@@ -2,7 +2,7 @@ ccflags-y := -O3 -fvisibility=hidden
ccflags-$(CONFIG_WIREGUARD_DEBUG) += -DDEBUG -g
ccflags-y += -Wframe-larger-than=8192
ccflags-y += -D'pr_fmt(fmt)=KBUILD_MODNAME ": " fmt'
-wireguard-y := main.o noise.o device.o peer.o timers.o data.o send.o receive.o socket.o config.o hashtables.o routingtable.o ratelimiter.o cookie.o
+wireguard-y := main.o noise.o device.o peer.o timers.o queueing.o send.o receive.o socket.o config.o hashtables.o routingtable.o ratelimiter.o cookie.o
wireguard-y += crypto/curve25519.o crypto/chacha20poly1305.o crypto/blake2s.o
ifeq ($(CONFIG_X86_64),y)
@@ -26,9 +26,6 @@ endif
ifneq ($(KBUILD_EXTMOD),)
CONFIG_WIREGUARD := m
-ifneq ($(CONFIG_SMP),)
-ccflags-y += -DCONFIG_WIREGUARD_PARALLEL=y
-endif
endif
include $(src)/compat/Kbuild.include
diff --git a/src/Kconfig b/src/Kconfig
index e86a935..e84aebb 100644
--- a/src/Kconfig
+++ b/src/Kconfig
@@ -16,18 +16,6 @@ config WIREGUARD
It's safe to say Y or M here, as the driver is very lightweight and
is only in use when an administrator chooses to add an interface.
-config WIREGUARD_PARALLEL
- bool "Enable parallel engine"
- depends on SMP && WIREGUARD
- select PADATA
- default y
- ---help---
- This will allow WireGuard to utilize all CPU cores when encrypting
- and decrypting packets.
-
- It's safe to say Y here, and you probably should, as the performance
- improvements are substantial.
-
config WIREGUARD_DEBUG
bool "Debugging checks and verbose messages"
depends on WIREGUARD
diff --git a/src/compat/Kbuild.include b/src/compat/Kbuild.include
index 688a573..aacc9f6 100644
--- a/src/compat/Kbuild.include
+++ b/src/compat/Kbuild.include
@@ -31,11 +31,3 @@ ifeq ($(shell grep -F "int crypto_memneq" "$(srctree)/include/crypto/algapi.h"),
ccflags-y += -include $(src)/compat/memneq/include.h
wireguard-y += compat/memneq/memneq.o
endif
-
-ifneq ($(KBUILD_EXTMOD),)
-ifneq ($(CONFIG_SMP),)
-ifeq (,$(filter $(CONFIG_PADATA),y m))
-wireguard-y += compat/padata/padata.o
-endif
-endif
-endif
diff --git a/src/compat/padata/padata.c b/src/compat/padata/padata.c
deleted file mode 100644
index fa6acac..0000000
--- a/src/compat/padata/padata.c
+++ /dev/null
@@ -1,903 +0,0 @@
-/*
- * padata.c - generic interface to process data streams in parallel
- *
- * See Documentation/padata.txt for an api documentation.
- *
- * Copyright (C) 2008, 2009 secunet Security Networks AG
- * Copyright (C) 2008, 2009 Steffen Klassert <steffen.klassert@secunet.com>
- *
- * This program is free software; you can redistribute it and/or modify it
- * under the terms and conditions of the GNU General Public License,
- * version 2, as published by the Free Software Foundation.
- *
- * This program is distributed in the hope it will be useful, but WITHOUT
- * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- * more details.
- *
- * You should have received a copy of the GNU General Public License along with
- * this program; if not, write to the Free Software Foundation, Inc.,
- * 51 Franklin St - Fifth Floor, Boston, MA 02110-1301 USA.
- */
-
-#include <linux/export.h>
-#include <linux/cpumask.h>
-#include <linux/err.h>
-#include <linux/cpu.h>
-#include <linux/padata.h>
-#include <linux/mutex.h>
-#include <linux/sched.h>
-#include <linux/slab.h>
-#include <linux/sysfs.h>
-#include <linux/rcupdate.h>
-#include <linux/module.h>
-#include <linux/version.h>
-
-#define MAX_OBJ_NUM 1000
-
-static int padata_index_to_cpu(struct parallel_data *pd, int cpu_index)
-{
- int cpu, target_cpu;
-
- target_cpu = cpumask_first(pd->cpumask.pcpu);
- for (cpu = 0; cpu < cpu_index; cpu++)
- target_cpu = cpumask_next(target_cpu, pd->cpumask.pcpu);
-
- return target_cpu;
-}
-
-static int padata_cpu_hash(struct parallel_data *pd)
-{
- int cpu_index;
- /*
- * Hash the sequence numbers to the cpus by taking
- * seq_nr mod. number of cpus in use.
- */
-#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 13, 0)
- spin_lock(&pd->seq_lock);
- cpu_index = pd->seq_nr % cpumask_weight(pd->cpumask.pcpu);
- pd->seq_nr++;
- spin_unlock(&pd->seq_lock);
-#else
-#ifdef CONFIG_PAX_REFCOUNT
- unsigned int seq_nr = atomic_inc_return_unchecked(&pd->seq_nr);
-#else
- unsigned int seq_nr = atomic_inc_return(&pd->seq_nr);
-#endif
- cpu_index = seq_nr % cpumask_weight(pd->cpumask.pcpu);
-#endif
-
- return padata_index_to_cpu(pd, cpu_index);
-}
-
-static void padata_parallel_worker(struct work_struct *parallel_work)
-{
- struct padata_parallel_queue *pqueue;
- LIST_HEAD(local_list);
-
- local_bh_disable();
- pqueue = container_of(parallel_work,
- struct padata_parallel_queue, work);
-
- spin_lock(&pqueue->parallel.lock);
- list_replace_init(&pqueue->parallel.list, &local_list);
- spin_unlock(&pqueue->parallel.lock);
-
- while (!list_empty(&local_list)) {
- struct padata_priv *padata;
-
- padata = list_entry(local_list.next,
- struct padata_priv, list);
-
- list_del_init(&padata->list);
-
- padata->parallel(padata);
- }
-
- local_bh_enable();
-}
-
-/**
- * padata_do_parallel - padata parallelization function
- *
- * @pinst: padata instance
- * @padata: object to be parallelized
- * @cb_cpu: cpu the serialization callback function will run on,
- * must be in the serial cpumask of padata(i.e. cpumask.cbcpu).
- *
- * The parallelization callback function will run with BHs off.
- * Note: Every object which is parallelized by padata_do_parallel
- * must be seen by padata_do_serial.
- */
-int padata_do_parallel(struct padata_instance *pinst,
- struct padata_priv *padata, int cb_cpu)
-{
- int target_cpu, err;
- struct padata_parallel_queue *queue;
- struct parallel_data *pd;
-
- rcu_read_lock_bh();
-
- pd = rcu_dereference_bh(pinst->pd);
-
- err = -EINVAL;
- if (!(pinst->flags & PADATA_INIT) || pinst->flags & PADATA_INVALID)
- goto out;
-
- if (!cpumask_test_cpu(cb_cpu, pd->cpumask.cbcpu))
- goto out;
-
- err = -EBUSY;
- if ((pinst->flags & PADATA_RESET))
- goto out;
-
- if (atomic_read(&pd->refcnt) >= MAX_OBJ_NUM)
- goto out;
-
- err = 0;
- atomic_inc(&pd->refcnt);
- padata->pd = pd;
- padata->cb_cpu = cb_cpu;
-
- target_cpu = padata_cpu_hash(pd);
- queue = per_cpu_ptr(pd->pqueue, target_cpu);
-
- spin_lock(&queue->parallel.lock);
- list_add_tail(&padata->list, &queue->parallel.list);
- spin_unlock(&queue->parallel.lock);
-
- queue_work_on(target_cpu, pinst->wq, &queue->work);
-
-out:
- rcu_read_unlock_bh();
-
- return err;
-}
-
-/*
- * padata_get_next - Get the next object that needs serialization.
- *
- * Return values are:
- *
- * A pointer to the control struct of the next object that needs
- * serialization, if present in one of the percpu reorder queues.
- *
- * -EINPROGRESS, if the next object that needs serialization will
- * be parallel processed by another cpu and is not yet present in
- * the cpu's reorder queue.
- *
- * -ENODATA, if this cpu has to do the parallel processing for
- * the next object.
- */
-static struct padata_priv *padata_get_next(struct parallel_data *pd)
-{
- int cpu, num_cpus;
- unsigned int next_nr, next_index;
- struct padata_parallel_queue *next_queue;
- struct padata_priv *padata;
- struct padata_list *reorder;
-
- num_cpus = cpumask_weight(pd->cpumask.pcpu);
-
- /*
- * Calculate the percpu reorder queue and the sequence
- * number of the next object.
- */
- next_nr = pd->processed;
- next_index = next_nr % num_cpus;
- cpu = padata_index_to_cpu(pd, next_index);
- next_queue = per_cpu_ptr(pd->pqueue, cpu);
-
- reorder = &next_queue->reorder;
-
- spin_lock(&reorder->lock);
- if (!list_empty(&reorder->list)) {
- padata = list_entry(reorder->list.next,
- struct padata_priv, list);
-
- list_del_init(&padata->list);
- atomic_dec(&pd->reorder_objects);
-
- pd->processed++;
-
- spin_unlock(&reorder->lock);
- goto out;
- }
- spin_unlock(&reorder->lock);
-
- if (__this_cpu_read(pd->pqueue->cpu_index) == next_queue->cpu_index) {
- padata = ERR_PTR(-ENODATA);
- goto out;
- }
-
- padata = ERR_PTR(-EINPROGRESS);
-out:
- return padata;
-}
-
-static void padata_reorder(struct parallel_data *pd)
-{
- int cb_cpu;
- struct padata_priv *padata;
- struct padata_serial_queue *squeue;
- struct padata_instance *pinst = pd->pinst;
-
- /*
- * We need to ensure that only one cpu can work on dequeueing of
- * the reorder queue the time. Calculating in which percpu reorder
- * queue the next object will arrive takes some time. A spinlock
- * would be highly contended. Also it is not clear in which order
- * the objects arrive to the reorder queues. So a cpu could wait to
- * get the lock just to notice that there is nothing to do at the
- * moment. Therefore we use a trylock and let the holder of the lock
- * care for all the objects enqueued during the holdtime of the lock.
- */
- if (!spin_trylock_bh(&pd->lock))
- return;
-
- while (1) {
- padata = padata_get_next(pd);
-
- /*
- * If the next object that needs serialization is parallel
- * processed by another cpu and is still on it's way to the
- * cpu's reorder queue, nothing to do for now.
- */
- if (PTR_ERR(padata) == -EINPROGRESS)
- break;
-
- /*
- * This cpu has to do the parallel processing of the next
- * object. It's waiting in the cpu's parallelization queue,
- * so exit immediately.
- */
- if (PTR_ERR(padata) == -ENODATA) {
- del_timer(&pd->timer);
- spin_unlock_bh(&pd->lock);
- return;
- }
-
- cb_cpu = padata->cb_cpu;
- squeue = per_cpu_ptr(pd->squeue, cb_cpu);
-
- spin_lock(&squeue->serial.lock);
- list_add_tail(&padata->list, &squeue->serial.list);
- spin_unlock(&squeue->serial.lock);
-
- queue_work_on(cb_cpu, pinst->wq, &squeue->work);
- }
-
- spin_unlock_bh(&pd->lock);
-
- /*
- * The next object that needs serialization might have arrived to
- * the reorder queues in the meantime, we will be called again
- * from the timer function if no one else cares for it.
- */
- if (atomic_read(&pd->reorder_objects)
- && !(pinst->flags & PADATA_RESET))
- mod_timer(&pd->timer, jiffies + HZ);
- else
- del_timer(&pd->timer);
-
- return;
-}
-
-static void padata_reorder_timer(unsigned long arg)
-{
- struct parallel_data *pd = (struct parallel_data *)arg;
-
- padata_reorder(pd);
-}
-
-static void padata_serial_worker(struct work_struct *serial_work)
-{
- struct padata_serial_queue *squeue;
- struct parallel_data *pd;
- LIST_HEAD(local_list);
-
- local_bh_disable();
- squeue = container_of(serial_work, struct padata_serial_queue, work);
- pd = squeue->pd;
-
- spin_lock(&squeue->serial.lock);
- list_replace_init(&squeue->serial.list, &local_list);
- spin_unlock(&squeue->serial.lock);
-
- while (!list_empty(&local_list)) {
- struct padata_priv *padata;
-
- padata = list_entry(local_list.next,
- struct padata_priv, list);
-
- list_del_init(&padata->list);
-
- padata->serial(padata);
- atomic_dec(&pd->refcnt);
- }
- local_bh_enable();
-}
-
-/**
- * padata_do_serial - padata serialization function
- *
- * @padata: object to be serialized.
- *
- * padata_do_serial must be called for every parallelized object.
- * The serialization callback function will run with BHs off.
- */
-void padata_do_serial(struct padata_priv *padata)
-{
- int cpu;
- struct padata_parallel_queue *pqueue;
- struct parallel_data *pd;
-
- pd = padata->pd;
-
- cpu = get_cpu();
- pqueue = per_cpu_ptr(pd->pqueue, cpu);
-
- spin_lock(&pqueue->reorder.lock);
- atomic_inc(&pd->reorder_objects);
- list_add_tail(&padata->list, &pqueue->reorder.list);
- spin_unlock(&pqueue->reorder.lock);
-
- put_cpu();
-
- padata_reorder(pd);
-}
-
-static int padata_setup_cpumasks(struct parallel_data *pd,
- const struct cpumask *pcpumask,
- const struct cpumask *cbcpumask)
-{
- if (!alloc_cpumask_var(&pd->cpumask.pcpu, GFP_KERNEL))
- return -ENOMEM;
-
- cpumask_and(pd->cpumask.pcpu, pcpumask, cpu_online_mask);
- if (!alloc_cpumask_var(&pd->cpumask.cbcpu, GFP_KERNEL)) {
- free_cpumask_var(pd->cpumask.pcpu);
- return -ENOMEM;
- }
-
- cpumask_and(pd->cpumask.cbcpu, cbcpumask, cpu_online_mask);
- return 0;
-}
-
-static void __padata_list_init(struct padata_list *pd_list)
-{
- INIT_LIST_HEAD(&pd_list->list);
- spin_lock_init(&pd_list->lock);
-}
-
-/* Initialize all percpu queues used by serial workers */
-static void padata_init_squeues(struct parallel_data *pd)
-{
- int cpu;
- struct padata_serial_queue *squeue;
-
- for_each_cpu(cpu, pd->cpumask.cbcpu) {
- squeue = per_cpu_ptr(pd->squeue, cpu);
- squeue->pd = pd;
- __padata_list_init(&squeue->serial);
- INIT_WORK(&squeue->work, padata_serial_worker);
- }
-}
-
-/* Initialize all percpu queues used by parallel workers */
-static void padata_init_pqueues(struct parallel_data *pd)
-{
- int cpu_index, cpu;
- struct padata_parallel_queue *pqueue;
-
- cpu_index = 0;
- for_each_cpu(cpu, pd->cpumask.pcpu) {
- pqueue = per_cpu_ptr(pd->pqueue, cpu);
- pqueue->pd = pd;
- pqueue->cpu_index = cpu_index;
- cpu_index++;
-
- __padata_list_init(&pqueue->reorder);
- __padata_list_init(&pqueue->parallel);
- INIT_WORK(&pqueue->work, padata_parallel_worker);
- atomic_set(&pqueue->num_obj, 0);
- }
-}
-
-/* Allocate and initialize the internal cpumask dependend resources. */
-static struct parallel_data *padata_alloc_pd(struct padata_instance *pinst,
- const struct cpumask *pcpumask,
- const struct cpumask *cbcpumask)
-{
- struct parallel_data *pd;
-
- pd = kzalloc(sizeof(struct parallel_data), GFP_KERNEL);
- if (!pd)
- goto err;
-
- pd->pqueue = alloc_percpu(struct padata_parallel_queue);
- if (!pd->pqueue)
- goto err_free_pd;
-
- pd->squeue = alloc_percpu(struct padata_serial_queue);
- if (!pd->squeue)
- goto err_free_pqueue;
- if (padata_setup_cpumasks(pd, pcpumask, cbcpumask) < 0)
- goto err_free_squeue;
-
- padata_init_pqueues(pd);
- padata_init_squeues(pd);
- setup_timer(&pd->timer, padata_reorder_timer, (unsigned long)pd);
-#if LINUX_VERSION_CODE < KERNEL_VERSION(3, 13, 0)
- pd->seq_nr = 0;
-#else
-#ifdef CONFIG_PAX_REFCOUNT
- atomic_set_unchecked(&pd->seq_nr, -1);
-#else
- atomic_set(&pd->seq_nr, -1);
-#endif
-#endif
- atomic_set(&pd->reorder_objects, 0);
- atomic_set(&pd->refcnt, 0);
- pd->pinst = pinst;
- spin_lock_init(&pd->lock);
-
- return pd;
-
-err_free_squeue:
- free_percpu(pd->squeue);
-err_free_pqueue:
- free_percpu(pd->pqueue);
-err_free_pd:
- kfree(pd);
-err:
- return NULL;
-}
-
-static void padata_free_pd(struct parallel_data *pd)
-{
- free_cpumask_var(pd->cpumask.pcpu);
- free_cpumask_var(pd->cpumask.cbcpu);
- free_percpu(pd->pqueue);
- free_percpu(pd->squeue);
- kfree(pd);
-}
-
-/* Flush all objects out of the padata queues. */
-static void padata_flush_queues(struct parallel_data *pd)
-{
- int cpu;
- struct padata_parallel_queue *pqueue;
- struct padata_serial_queue *squeue;
-
- for_each_cpu(cpu, pd->cpumask.pcpu) {
- pqueue = per_cpu_ptr(pd->pqueue, cpu);
- flush_work(&pqueue->work);
- }
-
- del_timer_sync(&pd->timer);
-
- if (atomic_read(&pd->reorder_objects))
- padata_reorder(pd);
-
- for_each_cpu(cpu, pd->cpumask.cbcpu) {
- squeue = per_cpu_ptr(pd->squeue, cpu);
- flush_work(&squeue->work);
- }
-
- BUG_ON(atomic_read(&pd->refcnt) != 0);
-}
-
-static void __padata_start(struct padata_instance *pinst)
-{
- pinst->flags |= PADATA_INIT;
-}
-
-static void __padata_stop(struct padata_instance *pinst)
-{
- if (!(pinst->flags & PADATA_INIT))
- return;
-
- pinst->flags &= ~PADATA_INIT;
-
- synchronize_rcu();
-
- get_online_cpus();
- padata_flush_queues(pinst->pd);
- put_online_cpus();
-}
-
-/* Replace the internal control structure with a new one. */
-static void padata_replace(struct padata_instance *pinst,
- struct parallel_data *pd_new)
-{
- struct parallel_data *pd_old = pinst->pd;
- int notification_mask = 0;
-
- pinst->flags |= PADATA_RESET;
-
- rcu_assign_pointer(pinst->pd, pd_new);
-
- synchronize_rcu();
-
- if (!cpumask_equal(pd_old->cpumask.pcpu, pd_new->cpumask.pcpu))
- notification_mask |= PADATA_CPU_PARALLEL;
- if (!cpumask_equal(pd_old->cpumask.cbcpu, pd_new->cpumask.cbcpu))
- notification_mask |= PADATA_CPU_SERIAL;
-
- padata_flush_queues(pd_old);
- padata_free_pd(pd_old);
-
- if (notification_mask)
- blocking_notifier_call_chain(&pinst->cpumask_change_notifier,
- notification_mask,
- &pd_new->cpumask);
-
- pinst->flags &= ~PADATA_RESET;
-}
-
-/**
- * padata_register_cpumask_notifier - Registers a notifier that will be called
- * if either pcpu or cbcpu or both cpumasks change.
- *
- * @pinst: A poineter to padata instance
- * @nblock: A pointer to notifier block.
- */
-int padata_register_cpumask_notifier(struct padata_instance *pinst,
- struct notifier_block *nblock)
-{
- return blocking_notifier_chain_register(&pinst->cpumask_change_notifier,
- nblock);
-}
-
-/**
- * padata_unregister_cpumask_notifier - Unregisters cpumask notifier
- * registered earlier using padata_register_cpumask_notifier
- *
- * @pinst: A pointer to data instance.
- * @nlock: A pointer to notifier block.
- */
-int padata_unregister_cpumask_notifier(struct padata_instance *pinst,
- struct notifier_block *nblock)
-{
- return blocking_notifier_chain_unregister(
- &pinst->cpumask_change_notifier,
- nblock);
-}
-
-
-/* If cpumask contains no active cpu, we mark the instance as invalid. */
-static bool padata_validate_cpumask(struct padata_instance *pinst,
- const struct cpumask *cpumask)
-{
- if (!cpumask_intersects(cpumask, cpu_online_mask)) {
- pinst->flags |= PADATA_INVALID;
- return false;
- }
-
- pinst->flags &= ~PADATA_INVALID;
- return true;
-}
-
-static int __padata_set_cpumasks(struct padata_instance *pinst,
- cpumask_var_t pcpumask,
- cpumask_var_t cbcpumask)
-{
- int valid;
- struct parallel_data *pd;
-
- valid = padata_validate_cpumask(pinst, pcpumask);
- if (!valid) {
- __padata_stop(pinst);
- goto out_replace;
- }
-
- valid = padata_validate_cpumask(pinst, cbcpumask);
- if (!valid)
- __padata_stop(pinst);
-
-out_replace:
- pd = padata_alloc_pd(pinst, pcpumask, cbcpumask);
- if (!pd)
- return -ENOMEM;
-
- cpumask_copy(pinst->cpumask.pcpu, pcpumask);
- cpumask_copy(pinst->cpumask.cbcpu, cbcpumask);
-
- padata_replace(pinst, pd);
-
- if (valid)
- __padata_start(pinst);
-
- return 0;
-}
-
-/**
- * padata_set_cpumask: Sets specified by @cpumask_type cpumask to the value
- * equivalent to @cpumask.
- *
- * @pinst: padata instance
- * @cpumask_type: PADATA_CPU_SERIAL or PADATA_CPU_PARALLEL corresponding
- * to parallel and serial cpumasks respectively.
- * @cpumask: the cpumask to use
- */
-int padata_set_cpumask(struct padata_instance *pinst, int cpumask_type,
- cpumask_var_t cpumask)
-{
- struct cpumask *serial_mask, *parallel_mask;
- int err = -EINVAL;
-
- mutex_lock(&pinst->lock);
- get_online_cpus();
-
- switch (cpumask_type) {
- case PADATA_CPU_PARALLEL:
- serial_mask = pinst->cpumask.cbcpu;
- parallel_mask = cpumask;
- break;
- case PADATA_CPU_SERIAL:
- parallel_mask = pinst->cpumask.pcpu;
- serial_mask = cpumask;
- break;
- default:
- goto out;
- }
-
- err = __padata_set_cpumasks(pinst, parallel_mask, serial_mask);
-
-out:
- put_online_cpus();
- mutex_unlock(&pinst->lock);
-
- return err;
-}
-
-/**
- * padata_start - start the parallel processing
- *
- * @pinst: padata instance to start
- */
-int padata_start(struct padata_instance *pinst)
-{
- int err = 0;
-
- mutex_lock(&pinst->lock);
-
- if (pinst->flags & PADATA_INVALID)
- err = -EINVAL;
-
- __padata_start(pinst);
-
- mutex_unlock(&pinst->lock);
-
- return err;
-}
-
-/**
- * padata_stop - stop the parallel processing
- *
- * @pinst: padata instance to stop
- */
-void padata_stop(struct padata_instance *pinst)
-{
- mutex_lock(&pinst->lock);
- __padata_stop(pinst);
- mutex_unlock(&pinst->lock);
-}
-
-static void __padata_free(struct padata_instance *pinst)
-{
- padata_stop(pinst);
- padata_free_pd(pinst->pd);
- free_cpumask_var(pinst->cpumask.pcpu);
- free_cpumask_var(pinst->cpumask.cbcpu);
- kfree(pinst);
-}
-
-#define kobj2pinst(_kobj) \
- container_of(_kobj, struct padata_instance, kobj)
-#define attr2pentry(_attr) \
- container_of(_attr, struct padata_sysfs_entry, attr)
-
-static void padata_sysfs_release(struct kobject *kobj)
-{
- struct padata_instance *pinst = kobj2pinst(kobj);
- __padata_free(pinst);
-}
-
-struct padata_sysfs_entry {
- struct attribute attr;
- ssize_t (*show)(struct padata_instance *, struct attribute *, char *);
- ssize_t (*store)(struct padata_instance *, struct attribute *,
- const char *, size_t);
-};
-
-static ssize_t show_cpumask(struct padata_instance *pinst,
- struct attribute *attr, char *buf)
-{
- struct cpumask *cpumask;
- ssize_t len;
-
- mutex_lock(&pinst->lock);
- if (!strcmp(attr->name, "serial_cpumask"))
- cpumask = pinst->cpumask.cbcpu;
- else
- cpumask = pinst->cpumask.pcpu;
-
- len = snprintf(buf, PAGE_SIZE, "%*pb\n",
- nr_cpu_ids, cpumask_bits(cpumask));
- mutex_unlock(&pinst->lock);
- return len < PAGE_SIZE ? len : -EINVAL;
-}
-
-static ssize_t store_cpumask(struct padata_instance *pinst,
- struct attribute *attr,
- const char *buf, size_t count)
-{
- cpumask_var_t new_cpumask;
- ssize_t ret;
- int mask_type;
-
- if (!alloc_cpumask_var(&new_cpumask, GFP_KERNEL))
- return -ENOMEM;
-
- ret = bitmap_parse(buf, count, cpumask_bits(new_cpumask),
- nr_cpumask_bits);
- if (ret < 0)
- goto out;
-
- mask_type = !strcmp(attr->name, "serial_cpumask") ?
- PADATA_CPU_SERIAL : PADATA_CPU_PARALLEL;
- ret = padata_set_cpumask(pinst, mask_type, new_cpumask);
- if (!ret)
- ret = count;
-
-out:
- free_cpumask_var(new_cpumask);
- return ret;
-}
-
-#define PADATA_ATTR_RW(_name, _show_name, _store_name) \
- static struct padata_sysfs_entry _name##_attr = \
- __ATTR(_name, 0644, _show_name, _store_name)
-#define PADATA_ATTR_RO(_name, _show_name) \
- static struct padata_sysfs_entry _name##_attr = \
- __ATTR(_name, 0400, _show_name, NULL)
-
-PADATA_ATTR_RW(serial_cpumask, show_cpumask, store_cpumask);
-PADATA_ATTR_RW(parallel_cpumask, show_cpumask, store_cpumask);
-
-/*
- * Padata sysfs provides the following objects:
- * serial_cpumask [RW] - cpumask for serial workers
- * parallel_cpumask [RW] - cpumask for parallel workers
- */
-static struct attribute *padata_default_attrs[] = {
- &serial_cpumask_attr.attr,
- &parallel_cpumask_attr.attr,
- NULL,
-};
-
-static ssize_t padata_sysfs_show(struct kobject *kobj,
- struct attribute *attr, char *buf)
-{
- struct padata_instance *pinst;
- struct padata_sysfs_entry *pentry;
- ssize_t ret = -EIO;
-
- pinst = kobj2pinst(kobj);
- pentry = attr2pentry(attr);
- if (pentry->show)
- ret = pentry->show(pinst, attr, buf);
-
- return ret;
-}
-
-static ssize_t padata_sysfs_store(struct kobject *kobj, struct attribute *attr,
- const char *buf, size_t count)
-{
- struct padata_instance *pinst;
- struct padata_sysfs_entry *pentry;
- ssize_t ret = -EIO;
-
- pinst = kobj2pinst(kobj);
- pentry = attr2pentry(attr);
- if (pentry->show)
- ret = pentry->store(pinst, attr, buf, count);
-
- return ret;
-}
-
-static const struct sysfs_ops padata_sysfs_ops = {
- .show = padata_sysfs_show,
- .store = padata_sysfs_store,
-};
-
-static struct kobj_type padata_attr_type = {
- .sysfs_ops = &padata_sysfs_ops,
- .default_attrs = padata_default_attrs,
- .release = padata_sysfs_release,
-};
-
-/**
- * padata_alloc - allocate and initialize a padata instance and specify
- * cpumasks for serial and parallel workers.
- *
- * @wq: workqueue to use for the allocated padata instance
- * @pcpumask: cpumask that will be used for padata parallelization
- * @cbcpumask: cpumask that will be used for padata serialization
- */
-struct padata_instance *padata_alloc(struct workqueue_struct *wq,
- const struct cpumask *pcpumask,
- const struct cpumask *cbcpumask)
-{
- struct padata_instance *pinst;
- struct parallel_data *pd = NULL;
-
- pinst = kzalloc(sizeof(struct padata_instance), GFP_KERNEL);
- if (!pinst)
- goto err;
-
- get_online_cpus();
- if (!alloc_cpumask_var(&pinst->cpumask.pcpu, GFP_KERNEL))
- goto err_free_inst;
- if (!alloc_cpumask_var(&pinst->cpumask.cbcpu, GFP_KERNEL)) {
- free_cpumask_var(pinst->cpumask.pcpu);
- goto err_free_inst;
- }
- if (!padata_validate_cpumask(pinst, pcpumask) ||
- !padata_validate_cpumask(pinst, cbcpumask))
- goto err_free_masks;
-
- pd = padata_alloc_pd(pinst, pcpumask, cbcpumask);
- if (!pd)
- goto err_free_masks;
-
- rcu_assign_pointer(pinst->pd, pd);
-
- pinst->wq = wq;
-
- cpumask_copy(pinst->cpumask.pcpu, pcpumask);
- cpumask_copy(pinst->cpumask.cbcpu, cbcpumask);
-
- pinst->flags = 0;
-
- put_online_cpus();
-
- BLOCKING_INIT_NOTIFIER_HEAD(&pinst->cpumask_change_notifier);
- kobject_init(&pinst->kobj, &padata_attr_type);
- mutex_init(&pinst->lock);
-
- return pinst;
-
-err_free_masks:
- free_cpumask_var(pinst->cpumask.pcpu);
- free_cpumask_var(pinst->cpumask.cbcpu);
-err_free_inst:
- kfree(pinst);
- put_online_cpus();
-err:
- return NULL;
-}
-
-/**
- * padata_alloc_possible - Allocate and initialize padata instance.
- * Use the cpu_possible_mask for serial and
- * parallel workers.
- *
- * @wq: workqueue to use for the allocated padata instance
- */
-struct padata_instance *padata_alloc_possible(struct workqueue_struct *wq)
-{
- return padata_alloc(wq, cpu_possible_mask, cpu_possible_mask);
-}
-
-/**
- * padata_free - free a padata instance
- *
- * @padata_inst: padata instance to free
- */
-void padata_free(struct padata_instance *pinst)
-{
- kobject_put(&pinst->kobj);
-}
diff --git a/src/config.c b/src/config.c
index 7ffc529..bf8557c 100644
--- a/src/config.c
+++ b/src/config.c
@@ -3,7 +3,7 @@
#include "config.h"
#include "device.h"
#include "socket.h"
-#include "packets.h"
+#include "queueing.h"
#include "timers.h"
#include "hashtables.h"
#include "peer.h"
@@ -114,7 +114,7 @@ static int set_peer(struct wireguard_device *wg, void __user *user_peer, size_t
}
if (wg->dev->flags & IFF_UP)
- packet_send_queue(peer);
+ packet_send_staged_packets(peer);
peer_put(peer);
diff --git a/src/data.c b/src/data.c
deleted file mode 100644
index fb91861..0000000
--- a/src/data.c
+++ /dev/null
@@ -1,430 +0,0 @@
-/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-
-#include "noise.h"
-#include "device.h"
-#include "peer.h"
-#include "messages.h"
-#include "packets.h"
-#include "hashtables.h"
-
-#include <linux/rcupdate.h>
-#include <linux/slab.h>
-#include <linux/bitmap.h>
-#include <linux/scatterlist.h>
-#include <net/ip_tunnels.h>
-#include <net/xfrm.h>
-#include <crypto/algapi.h>
-
-struct encryption_ctx {
- struct padata_priv padata;
- struct sk_buff_head queue;
- struct wireguard_peer *peer;
- struct noise_keypair *keypair;
-};
-
-struct decryption_ctx {
- struct padata_priv padata;
- struct endpoint endpoint;
- struct sk_buff *skb;
- struct noise_keypair *keypair;
-};
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static struct kmem_cache *encryption_ctx_cache __read_mostly;
-static struct kmem_cache *decryption_ctx_cache __read_mostly;
-
-int __init packet_init_data_caches(void)
-{
- encryption_ctx_cache = KMEM_CACHE(encryption_ctx, 0);
- if (!encryption_ctx_cache)
- return -ENOMEM;
- decryption_ctx_cache = KMEM_CACHE(decryption_ctx, 0);
- if (!decryption_ctx_cache) {
- kmem_cache_destroy(encryption_ctx_cache);
- return -ENOMEM;
- }
- return 0;
-}
-
-void packet_deinit_data_caches(void)
-{
- kmem_cache_destroy(encryption_ctx_cache);
- kmem_cache_destroy(decryption_ctx_cache);
-}
-#endif
-
-/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
-{
- bool ret = false;
- unsigned long index, index_current, top, i;
- spin_lock_bh(&counter->receive.lock);
-
- if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES))
- goto out;
-
- ++their_counter;
-
- if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter))
- goto out;
-
- index = their_counter >> ilog2(BITS_PER_LONG);
-
- if (likely(their_counter > counter->receive.counter)) {
- index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
- top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG);
- for (i = 1; i <= top; ++i)
- counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
- counter->receive.counter = their_counter;
- }
-
- index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
- ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]);
-
-out:
- spin_unlock_bh(&counter->receive.lock);
- return ret;
-}
-#include "selftest/counter.h"
-
-static inline unsigned int skb_padding(struct sk_buff *skb)
-{
- /* We do this modulo business with the MTU, just in case the networking layer
- * gives us a packet that's bigger than the MTU. Now that we support GSO, this
- * shouldn't be a real problem, and this can likely be removed. But, caution! */
- unsigned int last_unit = skb->len % skb->dev->mtu;
- unsigned int padded_size = (last_unit + MESSAGE_PADDING_MULTIPLE - 1) & ~(MESSAGE_PADDING_MULTIPLE - 1);
- if (padded_size > skb->dev->mtu)
- padded_size = skb->dev->mtu;
- return padded_size - last_unit;
-}
-
-static inline void skb_reset(struct sk_buff *skb)
-{
- skb_scrub_packet(skb, false);
- memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start));
- skb->queue_mapping = 0;
- skb->nohdr = 0;
- skb->peeked = 0;
- skb->mac_len = 0;
- skb->dev = NULL;
-#ifdef CONFIG_NET_SCHED
- skb->tc_index = 0;
- skb_reset_tc(skb);
-#endif
- skb->hdr_len = skb_headroom(skb);
- skb_reset_mac_header(skb);
- skb_reset_network_header(skb);
- skb_probe_transport_header(skb, 0);
- skb_reset_inner_headers(skb);
-}
-
-static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
-{
- struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
- struct message_data *header;
- unsigned int padding_len, plaintext_len, trailer_len;
- int num_frags;
- struct sk_buff *trailer;
-
- /* Store the ds bit in the cb */
- PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
-
- /* Calculate lengths */
- padding_len = skb_padding(skb);
- trailer_len = padding_len + noise_encrypted_len(0);
- plaintext_len = skb->len + padding_len;
-
- /* Expand data section to have room for padding and auth tag */
- num_frags = skb_cow_data(skb, trailer_len, &trailer);
- if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
- return false;
-
- /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
- memset(skb_tail_pointer(trailer), 0, padding_len);
-
- /* Expand head section to have room for our header and the network stack's headers. */
- if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
- return false;
-
- /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */
- if (likely(!skb_checksum_setup(skb, true)))
- skb_checksum_help(skb);
-
- /* Only after checksumming can we safely add on the padding at the end and the header. */
- header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
- header->header.type = cpu_to_le32(MESSAGE_DATA);
- header->key_idx = keypair->remote_index;
- header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
- pskb_put(skb, trailer, trailer_len);
-
- /* Now we can encrypt the scattergather segments */
- sg_init_table(sg, num_frags);
- if (skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)) <= 0)
- return false;
- return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, have_simd);
-}
-
-static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key)
-{
- struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
- struct sk_buff *trailer;
- int num_frags;
-
- if (unlikely(!key))
- return false;
-
- if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
- key->is_valid = false;
- return false;
- }
-
- PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
- skb_pull(skb, sizeof(struct message_data));
- num_frags = skb_cow_data(skb, 0, &trailer);
- if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
- return false;
-
- sg_init_table(sg, num_frags);
- if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
- return false;
-
- if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
- return false;
-
- return !pskb_trim(skb, skb->len - noise_encrypted_len(0));
-}
-
-static inline bool get_encryption_nonce(u64 *nonce, struct noise_symmetric_key *key)
-{
- if (unlikely(!key))
- return false;
-
- if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME))) {
- key->is_valid = false;
- return false;
- }
-
- *nonce = atomic64_inc_return(&key->counter.counter) - 1;
- if (*nonce >= REJECT_AFTER_MESSAGES) {
- key->is_valid = false;
- return false;
- }
-
- return true;
-}
-
-static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair)
-{
- struct sk_buff *skb, *tmp;
- bool have_simd = chacha20poly1305_init_simd();
- skb_queue_walk_safe (queue, skb, tmp) {
- if (unlikely(!skb_encrypt(skb, keypair, have_simd))) {
- __skb_unlink(skb, queue);
- kfree_skb(skb);
- continue;
- }
- skb_reset(skb);
- }
- chacha20poly1305_deinit_simd(have_simd);
- noise_keypair_put(keypair);
-}
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static void begin_parallel_encryption(struct padata_priv *padata)
-{
- struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
- local_bh_enable();
-#endif
- queue_encrypt_reset(&ctx->queue, ctx->keypair);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
- local_bh_disable();
-#endif
- padata_do_serial(padata);
-}
-
-static void finish_parallel_encryption(struct padata_priv *padata)
-{
- struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
- packet_create_data_done(&ctx->queue, ctx->peer);
- atomic_dec(&ctx->peer->parallel_encryption_inflight);
- peer_put(ctx->peer);
- kmem_cache_free(encryption_ctx_cache, ctx);
-}
-
-static inline unsigned int choose_cpu(__le32 key)
-{
- unsigned int cpu_index, cpu, cb_cpu;
-
- /* This ensures that packets encrypted to the same key are sent in-order. */
- cpu_index = ((__force unsigned int)key) % cpumask_weight(cpu_online_mask);
- cb_cpu = cpumask_first(cpu_online_mask);
- for (cpu = 0; cpu < cpu_index; ++cpu)
- cb_cpu = cpumask_next(cb_cpu, cpu_online_mask);
-
- return cb_cpu;
-}
-#endif
-
-int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer)
-{
- int ret = -ENOKEY;
- struct noise_keypair *keypair;
- struct sk_buff *skb;
-
- rcu_read_lock_bh();
- keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair));
- rcu_read_unlock_bh();
- if (unlikely(!keypair))
- return ret;
-
- skb_queue_walk (queue, skb) {
- if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending)))
- goto err;
-
- /* After the first time through the loop, if we've suceeded with a legitimate nonce,
- * then we don't want a -ENOKEY error if subsequent nonces fail. Rather, if this
- * condition arises, we simply want error out hard, and drop the entire queue. This
- * is partially lazy programming and TODO: this could be made to only requeue the
- * ones that had no nonce. But I'm not sure it's worth the added complexity, given
- * how rarely that condition should arise. */
- ret = -EPIPE;
- }
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
- if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || atomic_read(&peer->parallel_encryption_inflight) > 0) && cpumask_weight(cpu_online_mask) > 1) {
- struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC);
- if (!ctx)
- goto serial_encrypt;
- skb_queue_head_init(&ctx->queue);
- skb_queue_splice_init(queue, &ctx->queue);
- memset(&ctx->padata, 0, sizeof(ctx->padata));
- ctx->padata.parallel = begin_parallel_encryption;
- ctx->padata.serial = finish_parallel_encryption;
- ctx->keypair = keypair;
- ctx->peer = peer_rcu_get(peer);
- ret = -EBUSY;
- if (unlikely(!ctx->peer))
- goto err_parallel;
- atomic_inc(&peer->parallel_encryption_inflight);
- if (unlikely(padata_do_parallel(peer->device->encrypt_pd, &ctx->padata, choose_cpu(keypair->remote_index)))) {
- atomic_dec(&peer->parallel_encryption_inflight);
- peer_put(ctx->peer);
-err_parallel:
- skb_queue_splice(&ctx->queue, queue);
- kmem_cache_free(encryption_ctx_cache, ctx);
- goto err;
- }
- } else
-serial_encrypt:
-#endif
- {
- queue_encrypt_reset(queue, keypair);
- packet_create_data_done(queue, peer);
- }
- return 0;
-
-err:
- noise_keypair_put(keypair);
- return ret;
-}
-
-static void begin_decrypt_packet(struct decryption_ctx *ctx)
-{
- if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
- peer_put(ctx->keypair->entry.peer);
- noise_keypair_put(ctx->keypair);
- dev_kfree_skb(ctx->skb);
- ctx->skb = NULL;
- }
-}
-
-static void finish_decrypt_packet(struct decryption_ctx *ctx)
-{
- bool used_new_key;
-
- if (!ctx->skb)
- return;
-
- if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) {
- net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->keypair->entry.peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
- peer_put(ctx->keypair->entry.peer);
- noise_keypair_put(ctx->keypair);
- dev_kfree_skb(ctx->skb);
- return;
- }
-
- used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
- skb_reset(ctx->skb);
- packet_consume_data_done(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key);
- noise_keypair_put(ctx->keypair);
-}
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static void begin_parallel_decryption(struct padata_priv *padata)
-{
- struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
- local_bh_enable();
-#endif
- begin_decrypt_packet(ctx);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
- local_bh_disable();
-#endif
- padata_do_serial(padata);
-}
-
-static void finish_parallel_decryption(struct padata_priv *padata)
-{
- struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
- finish_decrypt_packet(ctx);
- kmem_cache_free(decryption_ctx_cache, ctx);
-}
-#endif
-
-void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg)
-{
- struct noise_keypair *keypair;
- __le32 idx = ((struct message_data *)skb->data)->key_idx;
-
- rcu_read_lock_bh();
- keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
- rcu_read_unlock_bh();
- if (unlikely(!keypair))
- goto err;
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
- if (cpumask_weight(cpu_online_mask) > 1) {
- struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
- if (unlikely(!ctx))
- goto err_peer;
- ctx->skb = skb;
- ctx->keypair = keypair;
- memset(&ctx->padata, 0, sizeof(ctx->padata));
- ctx->padata.parallel = begin_parallel_decryption;
- ctx->padata.serial = finish_parallel_decryption;
- if (unlikely(padata_do_parallel(wg->decrypt_pd, &ctx->padata, choose_cpu(idx)))) {
- kmem_cache_free(decryption_ctx_cache, ctx);
- goto err_peer;
- }
- } else
-#endif
- {
- struct decryption_ctx ctx = {
- .skb = skb,
- .keypair = keypair
- };
- begin_decrypt_packet(&ctx);
- finish_decrypt_packet(&ctx);
- }
- return;
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-err_peer:
- peer_put(keypair->entry.peer);
- noise_keypair_put(keypair);
-#endif
-err:
- dev_kfree_skb(skb);
-}
diff --git a/src/device.c b/src/device.c
index 2514822..eb1d59c 100644
--- a/src/device.c
+++ b/src/device.c
@@ -1,6 +1,6 @@
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-#include "packets.h"
+#include "queueing.h"
#include "socket.h"
#include "timers.h"
#include "device.h"
@@ -57,7 +57,7 @@ static int open(struct net_device *dev)
return ret;
peer_for_each (wg, peer, temp, true) {
timers_init_peer(peer);
- packet_send_queue(peer);
+ packet_send_staged_packets(peer);
if (peer->persistent_keepalive_interval)
packet_send_keepalive(peer);
}
@@ -95,11 +95,10 @@ static int stop(struct net_device *dev)
struct wireguard_device *wg = netdev_priv(dev);
struct wireguard_peer *peer, *temp;
peer_for_each (wg, peer, temp, true) {
+ skb_queue_purge(&peer->staged_packet_queue);
timers_uninit_peer(peer);
noise_handshake_clear(&peer->handshake);
noise_keypairs_clear(&peer->keypairs);
- if (peer->timers_enabled)
- del_timer(&peer->timer_zero_key_material);
}
skb_queue_purge(&wg->incoming_handshakes);
socket_uninit(wg);
@@ -111,6 +110,7 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
struct wireguard_device *wg = netdev_priv(dev);
struct wireguard_peer *peer;
struct sk_buff *next;
+ struct sk_buff_head packets;
int ret;
if (unlikely(dev_recursion_level() > 4)) {
@@ -141,11 +141,7 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
goto err_peer;
}
- /* If the queue is getting too big, we start removing the oldest packets until it's small again.
- * We do this before adding the new packet, so we don't remove GSO segments that are in excess. */
- while (skb_queue_len(&peer->tx_packet_queue) > MAX_QUEUED_OUTGOING_PACKETS)
- dev_kfree_skb(skb_dequeue(&peer->tx_packet_queue));
-
+ __skb_queue_head_init(&packets);
if (!skb_is_gso(skb))
skb->next = NULL;
else {
@@ -169,10 +165,19 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
* so at this point we're in a position to drop it. */
skb_dst_drop(skb);
- skb_queue_tail(&peer->tx_packet_queue, skb);
+ __skb_queue_tail(&packets, skb);
} while ((skb = next) != NULL);
- packet_send_queue(peer);
+ spin_lock_bh(&peer->staged_packet_queue.lock);
+ /* If the queue is getting too big, we start removing the oldest packets until it's small again.
+ * We do this before adding the new packet, so we don't remove GSO segments that are in excess. */
+ while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS)
+ dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue));
+ skb_queue_splice_tail(&packets, &peer->staged_packet_queue);
+ spin_unlock_bh(&peer->staged_packet_queue.lock);
+
+ packet_send_staged_packets(peer);
+
peer_put(peer);
return NETDEV_TX_OK;
@@ -220,15 +225,13 @@ static void destruct(struct net_device *dev)
list_del(&wg->device_list);
rtnl_unlock();
mutex_lock(&wg->device_update_lock);
- peer_remove_all(wg);
+ peer_remove_all(wg); /* The final references are cleared in the below calls to destroy_workqueue. */
wg->incoming_port = 0;
- destroy_workqueue(wg->incoming_handshake_wq);
- destroy_workqueue(wg->peer_wq);
-#ifdef CONFIG_WIREGUARD_PARALLEL
- padata_free(wg->encrypt_pd);
- padata_free(wg->decrypt_pd);
- destroy_workqueue(wg->crypt_wq);
-#endif
+ destroy_workqueue(wg->handshake_receive_wq);
+ destroy_workqueue(wg->handshake_send_wq);
+ free_percpu(wg->decrypt_queue.worker);
+ free_percpu(wg->encrypt_queue.worker);
+ destroy_workqueue(wg->packet_crypt_wq);
routing_table_free(&wg->peer_routing_table);
ratelimiter_uninit();
memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
@@ -275,7 +278,7 @@ static void setup(struct net_device *dev)
static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[], struct netlink_ext_ack *extack)
{
- int ret = -ENOMEM, cpu;
+ int ret = -ENOMEM;
struct wireguard_device *wg = netdev_priv(dev);
wg->creating_net = get_net(src_net);
@@ -293,38 +296,27 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
if (!dev->tstats)
goto error_1;
- wg->incoming_handshakes_worker = alloc_percpu(struct handshake_worker);
+ wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker(packet_handshake_receive_worker, wg);
if (!wg->incoming_handshakes_worker)
goto error_2;
- for_each_possible_cpu (cpu) {
- per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->wg = wg;
- INIT_WORK(&per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work, packet_process_queued_handshake_packets);
- }
- atomic_set(&wg->incoming_handshake_seqnr, 0);
- wg->incoming_handshake_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
- if (!wg->incoming_handshake_wq)
+ wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
+ if (!wg->handshake_receive_wq)
goto error_3;
- wg->peer_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
- if (!wg->peer_wq)
+ wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
+ if (!wg->handshake_send_wq)
goto error_4;
-#ifdef CONFIG_WIREGUARD_PARALLEL
- wg->crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 2, dev->name);
- if (!wg->crypt_wq)
+ wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name);
+ if (!wg->packet_crypt_wq)
goto error_5;
- wg->encrypt_pd = padata_alloc_possible(wg->crypt_wq);
- if (!wg->encrypt_pd)
+ if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true) < 0)
goto error_6;
- padata_start(wg->encrypt_pd);
- wg->decrypt_pd = padata_alloc_possible(wg->crypt_wq);
- if (!wg->decrypt_pd)
+ if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true) < 0)
goto error_7;
- padata_start(wg->decrypt_pd);
-#endif
ret = ratelimiter_init();
if (ret < 0)
@@ -346,17 +338,15 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
error_9:
ratelimiter_uninit();
error_8:
-#ifdef CONFIG_WIREGUARD_PARALLEL
- padata_free(wg->decrypt_pd);
+ free_percpu(wg->decrypt_queue.worker);
error_7:
- padata_free(wg->encrypt_pd);
+ free_percpu(wg->encrypt_queue.worker);
error_6:
- destroy_workqueue(wg->crypt_wq);
+ destroy_workqueue(wg->packet_crypt_wq);
error_5:
-#endif
- destroy_workqueue(wg->peer_wq);
+ destroy_workqueue(wg->handshake_send_wq);
error_4:
- destroy_workqueue(wg->incoming_handshake_wq);
+ destroy_workqueue(wg->handshake_receive_wq);
error_3:
free_percpu(wg->incoming_handshakes_worker);
error_2:
diff --git a/src/device.h b/src/device.h
index 77f1b2e..66090e7 100644
--- a/src/device.h
+++ b/src/device.h
@@ -13,14 +13,27 @@
#include <linux/workqueue.h>
#include <linux/mutex.h>
#include <linux/net.h>
-#include <linux/padata.h>
struct wireguard_device;
-struct handshake_worker {
- struct wireguard_device *wg;
+
+struct multicore_worker {
+ void *ptr;
struct work_struct work;
};
+struct crypt_queue {
+ spinlock_t lock;
+ struct list_head queue;
+ union {
+ struct {
+ struct multicore_worker __percpu *worker;
+ int last_cpu;
+ };
+ struct work_struct work;
+ };
+ int len;
+};
+
struct wireguard_device {
struct net_device *dev;
struct list_head device_list;
@@ -29,21 +42,17 @@ struct wireguard_device {
u32 fwmark;
struct net *creating_net;
struct noise_static_identity static_identity;
- struct workqueue_struct *incoming_handshake_wq, *peer_wq;
+ struct workqueue_struct *handshake_receive_wq, *handshake_send_wq, *packet_crypt_wq;
struct sk_buff_head incoming_handshakes;
- atomic_t incoming_handshake_seqnr;
- struct handshake_worker __percpu *incoming_handshakes_worker;
+ struct crypt_queue encrypt_queue, decrypt_queue;
+ int incoming_handshake_cpu;
+ struct multicore_worker __percpu *incoming_handshakes_worker;
struct cookie_checker cookie_checker;
struct pubkey_hashtable peer_hashtable;
struct index_hashtable index_hashtable;
struct routing_table peer_routing_table;
struct list_head peer_list;
- struct mutex device_update_lock;
- struct mutex socket_update_lock;
-#ifdef CONFIG_WIREGUARD_PARALLEL
- struct workqueue_struct *crypt_wq;
- struct padata_instance *encrypt_pd, *decrypt_pd;
-#endif
+ struct mutex device_update_lock, socket_update_lock;
};
int device_init(void);
diff --git a/src/main.c b/src/main.c
index 0697741..e2686a2 100644
--- a/src/main.c
+++ b/src/main.c
@@ -3,7 +3,7 @@
#include "version.h"
#include "device.h"
#include "noise.h"
-#include "packets.h"
+#include "queueing.h"
#include "ratelimiter.h"
#include "crypto/chacha20poly1305.h"
#include "crypto/blake2s.h"
@@ -27,11 +27,9 @@ static int __init mod_init(void)
#endif
noise_init();
-#ifdef CONFIG_WIREGUARD_PARALLEL
- ret = packet_init_data_caches();
+ ret = init_crypt_ctx_cache();
if (ret < 0)
goto err_packet;
-#endif
ret = device_init();
if (ret < 0)
@@ -43,19 +41,15 @@ static int __init mod_init(void)
return 0;
err_device:
-#ifdef CONFIG_WIREGUARD_PARALLEL
- packet_deinit_data_caches();
+ deinit_crypt_ctx_cache();
err_packet:
-#endif
return ret;
}
static void __exit mod_exit(void)
{
device_uninit();
-#ifdef CONFIG_WIREGUARD_PARALLEL
- packet_deinit_data_caches();
-#endif
+ deinit_crypt_ctx_cache();
pr_debug("WireGuard unloaded\n");
}
diff --git a/src/messages.h b/src/messages.h
index 2c0658d..490a773 100644
--- a/src/messages.h
+++ b/src/messages.h
@@ -49,8 +49,9 @@ enum limits {
MAX_PEERS_PER_DEVICE = 1 << 20,
KEEPALIVE_TIMEOUT = 10 * HZ,
MAX_TIMER_HANDSHAKES = (90 * HZ) / REKEY_TIMEOUT,
- MAX_QUEUED_INCOMING_HANDSHAKES = 4096,
- MAX_QUEUED_OUTGOING_PACKETS = 1024
+ MAX_QUEUED_INCOMING_HANDSHAKES = 4096, /* TODO: replace this with DQL */
+ MAX_STAGED_PACKETS = 1024,
+ MAX_QUEUED_PACKETS = 1024 /* TODO: replace this with DQL */
};
enum message_type {
diff --git a/src/noise.c b/src/noise.c
index 199c9d5..3b02148 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -4,7 +4,7 @@
#include "device.h"
#include "peer.h"
#include "messages.h"
-#include "packets.h"
+#include "queueing.h"
#include "hashtables.h"
#include <linux/rcupdate.h>
diff --git a/src/packets.h b/src/packets.h
deleted file mode 100644
index c956c7a..0000000
--- a/src/packets.h
+++ /dev/null
@@ -1,63 +0,0 @@
-/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-
-#ifndef PACKETS_H
-#define PACKETS_H
-
-#include "noise.h"
-#include "messages.h"
-#include "socket.h"
-
-#include <linux/types.h>
-#include <linux/padata.h>
-#include <linux/skbuff.h>
-#include <linux/ip.h>
-#include <linux/ipv6.h>
-
-struct wireguard_device;
-struct wireguard_peer;
-struct sk_buff;
-
-struct packet_cb {
- u64 nonce;
- u8 ds;
-};
-#define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
-
-/* receive.c */
-void packet_receive(struct wireguard_device *wg, struct sk_buff *skb);
-void packet_process_queued_handshake_packets(struct work_struct *work);
-void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key);
-
-/* send.c */
-void packet_send_queue(struct wireguard_peer *peer);
-void packet_send_keepalive(struct wireguard_peer *peer);
-void packet_queue_handshake_initiation(struct wireguard_peer *peer, bool is_retry);
-void packet_send_queued_handshakes(struct work_struct *work);
-void packet_send_handshake_response(struct wireguard_peer *peer);
-void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index);
-void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer);
-
-/* data.c */
-int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer);
-void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg);
-
-/* Returns either the correct skb->protocol value, or 0 if invalid. */
-static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
-{
- if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && ip_hdr(skb)->version == 4)
- return htons(ETH_P_IP);
- if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && ipv6_hdr(skb)->version == 6)
- return htons(ETH_P_IPV6);
- return 0;
-}
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-int packet_init_data_caches(void);
-void packet_deinit_data_caches(void);
-#endif
-
-#ifdef DEBUG
-bool packet_counter_selftest(void);
-#endif
-
-#endif
diff --git a/src/peer.c b/src/peer.c
index f539d99..cebda70 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -2,7 +2,7 @@
#include "peer.h"
#include "device.h"
-#include "packets.h"
+#include "queueing.h"
#include "timers.h"
#include "hashtables.h"
#include "noise.h"
@@ -32,6 +32,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
}
peer->internal_id = atomic64_inc_return(&peer_counter);
+ peer->serial_work_cpu = nr_cpumask_bits;
peer->device = wg;
cookie_init(&peer->latest_cookie);
if (!noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, preshared_key, peer)) {
@@ -40,15 +41,14 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
}
cookie_checker_precompute_peer_keys(peer);
mutex_init(&peer->keypairs.keypair_update_lock);
- INIT_WORK(&peer->transmit_handshake_work, packet_send_queued_handshakes);
+ INIT_WORK(&peer->transmit_handshake_work, packet_handshake_send_worker);
rwlock_init(&peer->endpoint_lock);
- skb_queue_head_init(&peer->tx_packet_queue);
kref_init(&peer->refcount);
pubkey_hashtable_add(&wg->peer_hashtable, peer);
list_add_tail(&peer->peer_list, &wg->peer_list);
-#ifdef CONFIG_WIREGUARD_PARALLEL
- atomic_set(&peer->parallel_encryption_inflight, 0);
-#endif
+ packet_queue_init(&peer->tx_queue, packet_tx_worker, false);
+ packet_queue_init(&peer->rx_queue, packet_rx_worker, false);
+ skb_queue_head_init(&peer->staged_packet_queue);
pr_debug("%s: Peer %Lu created\n", wg->dev->name, peer->internal_id);
return peer;
}
@@ -83,9 +83,10 @@ void peer_remove(struct wireguard_peer *peer)
timers_uninit_peer(peer);
routing_table_remove_by_peer(&peer->device->peer_routing_table, peer);
pubkey_hashtable_remove(&peer->device->peer_hashtable, peer);
- if (peer->device->peer_wq)
- flush_workqueue(peer->device->peer_wq);
- skb_queue_purge(&peer->tx_packet_queue);
+ skb_queue_purge(&peer->staged_packet_queue);
+ flush_workqueue(peer->device->packet_crypt_wq); /* The first flush is for encrypt/decrypt step. */
+ flush_workqueue(peer->device->packet_crypt_wq); /* The second flush is for send/receive step. */
+ flush_workqueue(peer->device->handshake_send_wq);
peer_put(peer);
}
@@ -93,7 +94,7 @@ static void rcu_release(struct rcu_head *rcu)
{
struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu);
pr_debug("%s: Peer %Lu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
- skb_queue_purge(&peer->tx_packet_queue);
+ skb_queue_purge(&peer->staged_packet_queue);
dst_cache_destroy(&peer->endpoint_cache);
kzfree(peer);
}
diff --git a/src/peer.h b/src/peer.h
index c10406b..edd7c06 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -3,6 +3,7 @@
#ifndef PEER_H
#define PEER_H
+#include "device.h"
#include "noise.h"
#include "cookie.h"
@@ -34,8 +35,10 @@ struct wireguard_peer {
struct endpoint endpoint;
struct dst_cache endpoint_cache;
rwlock_t endpoint_lock;
- struct noise_handshake handshake;
+ struct crypt_queue tx_queue, rx_queue;
+ int serial_work_cpu;
struct noise_keypairs keypairs;
+ struct noise_handshake handshake;
u64 last_sent_handshake;
struct work_struct transmit_handshake_work, clear_peer_work;
struct cookie latest_cookie;
@@ -44,19 +47,13 @@ struct wireguard_peer {
struct timer_list timer_retransmit_handshake, timer_send_keepalive, timer_new_handshake, timer_zero_key_material, timer_persistent_keepalive;
unsigned int timer_handshake_attempts;
unsigned long persistent_keepalive_interval;
- bool timers_enabled;
- bool timer_need_another_keepalive;
- bool need_resend_queue;
- bool sent_lastminute_handshake;
+ bool timers_enabled, timer_need_another_keepalive, sent_lastminute_handshake;
struct timeval walltime_last_handshake;
- struct sk_buff_head tx_packet_queue;
+ struct sk_buff_head staged_packet_queue;
struct kref refcount;
struct rcu_head rcu;
struct list_head peer_list;
u64 internal_id;
-#ifdef CONFIG_WIREGUARD_PARALLEL
- atomic_t parallel_encryption_inflight;
-#endif
};
struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
diff --git a/src/queueing.c b/src/queueing.c
new file mode 100644
index 0000000..86e1324
--- /dev/null
+++ b/src/queueing.c
@@ -0,0 +1,46 @@
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+#include "queueing.h"
+#include <linux/slab.h>
+
+struct kmem_cache *crypt_ctx_cache __read_mostly;
+
+struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr)
+{
+ int cpu;
+ struct multicore_worker __percpu *worker = alloc_percpu(struct multicore_worker);
+ if (!worker)
+ return NULL;
+ for_each_possible_cpu (cpu) {
+ per_cpu_ptr(worker, cpu)->ptr = ptr;
+ INIT_WORK(&per_cpu_ptr(worker, cpu)->work, function);
+ }
+ return worker;
+}
+
+int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore)
+{
+ INIT_LIST_HEAD(&queue->queue);
+ queue->len = 0;
+ spin_lock_init(&queue->lock);
+ if (multicore) {
+ queue->worker = packet_alloc_percpu_multicore_worker(function, queue);
+ if (!queue->worker)
+ return -ENOMEM;
+ } else
+ INIT_WORK(&queue->work, function);
+ return 0;
+}
+
+int __init init_crypt_ctx_cache(void)
+{
+ crypt_ctx_cache = KMEM_CACHE(crypt_ctx, 0);
+ if (!crypt_ctx_cache)
+ return -ENOMEM;
+ return 0;
+}
+
+void deinit_crypt_ctx_cache(void)
+{
+ kmem_cache_destroy(crypt_ctx_cache);
+}
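
For orientation, the two modes of packet_queue_init() map onto the two kinds of queues used elsewhere in this patch: per-peer queues get a single work item (the serial, order-preserving stage), while per-device queues get one multicore_worker per possible CPU (the parallel crypto stage). A minimal sketch of both calls; the peer-side one mirrors peer_create() above, while the device-side one is an assumption, since device.c is not part of this section:

	/* Serial, order-preserving stage: one work item per peer (this call cannot
	 * fail when multicore is false); this is what peer_create() above does. */
	packet_queue_init(&peer->tx_queue, packet_tx_worker, false);

	/* Parallel stage shared by all peers: one multicore_worker per possible CPU,
	 * fed round-robin. This device-side call is an assumption (device.c is not
	 * shown here), but it matches how wg->encrypt_queue is drained by
	 * packet_encrypt_worker() in send.c below. */
	if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true) < 0)
		return -ENOMEM;
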
diff --git a/src/queueing.h b/src/queueing.h
new file mode 100644
index 0000000..30df3a8
--- /dev/null
+++ b/src/queueing.h
@@ -0,0 +1,196 @@
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+#ifndef QUEUEING_H
+#define QUEUEING_H
+
+#include "peer.h"
+#include <linux/types.h>
+#include <linux/skbuff.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+
+struct wireguard_device;
+struct wireguard_peer;
+struct multicore_worker;
+struct crypt_queue;
+struct sk_buff;
+
+/* queueing.c APIs: */
+extern struct kmem_cache *crypt_ctx_cache __read_mostly;
+int init_crypt_ctx_cache(void);
+void deinit_crypt_ctx_cache(void);
+int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore);
+struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr);
+
+/* receive.c APIs: */
+void packet_receive(struct wireguard_device *wg, struct sk_buff *skb);
+void packet_handshake_receive_worker(struct work_struct *work);
+/* Workqueue workers: */
+void packet_rx_worker(struct work_struct *work);
+void packet_decrypt_worker(struct work_struct *work);
+
+/* send.c APIs: */
+void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry);
+void packet_send_handshake_response(struct wireguard_peer *peer);
+void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index);
+void packet_send_keepalive(struct wireguard_peer *peer);
+void packet_send_staged_packets(struct wireguard_peer *peer);
+/* Workqueue workers: */
+void packet_handshake_send_worker(struct work_struct *work);
+void packet_tx_worker(struct work_struct *work);
+void packet_encrypt_worker(struct work_struct *work);
+
+struct packet_cb {
+ u64 nonce;
+ u8 ds;
+};
+#define PACKET_CB(skb) ((struct packet_cb *)((skb)->cb))
+
+struct crypt_ctx {
+ struct list_head per_peer_node, per_device_node;
+ union {
+ struct sk_buff_head packets;
+ struct sk_buff *skb;
+ };
+ struct wireguard_peer *peer;
+ struct noise_keypair *keypair;
+ struct endpoint endpoint;
+ atomic_t is_finished;
+};
+
+/* Returns either the correct skb->protocol value, or 0 if invalid. */
+static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
+{
+ if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && ip_hdr(skb)->version == 4)
+ return htons(ETH_P_IP);
+ if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && ipv6_hdr(skb)->version == 6)
+ return htons(ETH_P_IPV6);
+ return 0;
+}
+
+static inline unsigned int skb_padding(struct sk_buff *skb)
+{
+ /* We do this modulo business with the MTU, just in case the networking layer
+ * gives us a packet that's bigger than the MTU. Now that we support GSO, this
+ * shouldn't be a real problem, and this can likely be removed. But, caution! */
+ unsigned int last_unit = skb->len % skb->dev->mtu;
+ unsigned int padded_size = (last_unit + MESSAGE_PADDING_MULTIPLE - 1) & ~(MESSAGE_PADDING_MULTIPLE - 1);
+ if (padded_size > skb->dev->mtu)
+ padded_size = skb->dev->mtu;
+ return padded_size - last_unit;
+}
+
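
A quick worked example of the arithmetic above, as a standalone userspace program (not kernel code); it assumes MESSAGE_PADDING_MULTIPLE is 16, as defined in messages.h, and shows the clamping behaviour when the MTU itself is not 16-aligned:

#include <stdio.h>

#define MESSAGE_PADDING_MULTIPLE 16

static unsigned int padding(unsigned int len, unsigned int mtu)
{
	unsigned int last_unit = len % mtu;
	unsigned int padded_size = (last_unit + MESSAGE_PADDING_MULTIPLE - 1) & ~(MESSAGE_PADDING_MULTIPLE - 1);
	if (padded_size > mtu)
		padded_size = mtu;
	return padded_size - last_unit;
}

int main(void)
{
	printf("%u\n", padding(20, 1420));   /* 20 rounds up to 32: 12 bytes of padding */
	printf("%u\n", padding(1420, 1420)); /* last_unit is 0, already aligned: 0 bytes */
	printf("%u\n", padding(1419, 1420)); /* would round to 1424 > mtu, clamped: 1 byte */
	return 0;
}
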
+static inline void skb_reset(struct sk_buff *skb)
+{
+ skb_scrub_packet(skb, false);
+ memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start));
+ skb->queue_mapping = 0;
+ skb->nohdr = 0;
+ skb->peeked = 0;
+ skb->mac_len = 0;
+ skb->dev = NULL;
+#ifdef CONFIG_NET_SCHED
+ skb->tc_index = 0;
+ skb_reset_tc(skb);
+#endif
+ skb->hdr_len = skb_headroom(skb);
+ skb_reset_mac_header(skb);
+ skb_reset_network_header(skb);
+ skb_probe_transport_header(skb, 0);
+ skb_reset_inner_headers(skb);
+}
+
+static inline int choose_cpu(int *stored_cpu, unsigned int id)
+{
+ unsigned int cpu = *stored_cpu, cpu_index, i;
+ if (unlikely(cpu == nr_cpumask_bits || !cpumask_test_cpu(cpu, cpu_online_mask))) {
+ cpu_index = id % cpumask_weight(cpu_online_mask);
+ cpu = cpumask_first(cpu_online_mask);
+ for (i = 0; i < cpu_index; ++i)
+ cpu = cpumask_next(cpu, cpu_online_mask);
+ *stored_cpu = cpu;
+ }
+ return cpu;
+}
+
+/* This function is racy, in the sense that next is unlocked, so it could return
+ * the same CPU twice. A race-free version of this would be to instead store an
+ * atomic sequence number, do an increment-and-return, and then iterate through
+ * every possible CPU until we get to that index, as choose_cpu() above does.
+ * However, that's a bit slower, and this potential race doesn't seem to introduce
+ * any performance loss in practice, so we live with it. */
+static inline int cpumask_next_online(int *next)
+{
+ int cpu = *next;
+ while (unlikely(!cpumask_test_cpu(cpu, cpu_online_mask)))
+ cpu = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits;
+ *next = cpumask_next(cpu, cpu_online_mask) % nr_cpumask_bits;
+ return cpu;
+}
+
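
To see what cpumask_next_online() actually yields, here is a small userspace simulation (illustrative only: the online mask is faked as a plain bitmask and nr_cpumask_bits as NBITS). With CPUs 0, 1 and 3 online it cycles 0, 1, 3, 0, 1, 3, skipping the offline CPU and wrapping at the end of the mask:

#include <stdio.h>

#define NBITS 8
static unsigned int online_mask = 0x0b;	/* CPUs 0, 1 and 3 are "online" */

static int test_cpu(int cpu) { return (online_mask >> cpu) & 1; }

static int next_cpu(int cpu)
{
	for (int i = cpu + 1; i < NBITS; ++i)
		if (test_cpu(i))
			return i;
	return NBITS;	/* like cpumask_next() returning nr_cpumask_bits */
}

static int next_online(int *next)
{
	int cpu = *next;
	while (!test_cpu(cpu))
		cpu = next_cpu(cpu) % NBITS;
	*next = next_cpu(cpu) % NBITS;
	return cpu;
}

int main(void)
{
	int next = 0;
	for (int i = 0; i < 6; ++i)
		printf("%d ", next_online(&next));
	printf("\n");	/* prints: 0 1 3 0 1 3 */
	return 0;
}
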
+static inline struct list_head *queue_dequeue(struct crypt_queue *queue)
+{
+ struct list_head *node;
+ spin_lock_bh(&queue->lock);
+ node = queue->queue.next;
+ if (&queue->queue == node) {
+ spin_unlock_bh(&queue->lock);
+ return NULL;
+ }
+ list_del(node);
+ --queue->len;
+ spin_unlock_bh(&queue->lock);
+ return node;
+}
+
+static inline bool queue_enqueue(struct crypt_queue *queue, struct list_head *node, int limit)
+{
+ spin_lock_bh(&queue->lock);
+ if (limit && queue->len >= limit) {
+ spin_unlock_bh(&queue->lock);
+ return false;
+ }
+ list_add_tail(node, &queue->queue);
+ ++queue->len;
+ spin_unlock_bh(&queue->lock);
+ return true;
+}
+
+static inline struct crypt_ctx *queue_dequeue_per_peer(struct crypt_queue *queue)
+{
+ struct list_head *node = queue_dequeue(queue);
+ return node ? list_entry(node, struct crypt_ctx, per_peer_node) : NULL;
+}
+
+static inline struct crypt_ctx *queue_dequeue_per_device(struct crypt_queue *queue)
+{
+ struct list_head *node = queue_dequeue(queue);
+ return node ? list_entry(node, struct crypt_ctx, per_device_node) : NULL;
+}
+
+static inline struct crypt_ctx *queue_first_per_peer(struct crypt_queue *queue)
+{
+ return list_first_entry_or_null(&queue->queue, struct crypt_ctx, per_peer_node);
+}
+
+static inline bool queue_enqueue_per_peer(struct crypt_queue *peer_queue, struct crypt_ctx *ctx)
+{
+ return queue_enqueue(peer_queue, &ctx->per_peer_node, MAX_QUEUED_PACKETS);
+}
+
+static inline bool queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct crypt_ctx *ctx, struct workqueue_struct *wq, int *next_cpu)
+{
+ int cpu;
+ if (unlikely(!queue_enqueue_per_peer(peer_queue, ctx)))
+ return false;
+ cpu = cpumask_next_online(next_cpu);
+ queue_enqueue(device_queue, &ctx->per_device_node, 0);
+ queue_work_on(cpu, wq, &per_cpu_ptr(device_queue->worker, cpu)->work);
+ return true;
+}
+
+#ifdef DEBUG
+bool packet_counter_selftest(void);
+#endif
+
+#endif
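
Two properties of these helpers are worth spelling out: queue_enqueue() treats a limit of 0 as unbounded, so only the per-peer queues are capped at MAX_QUEUED_PACKETS, and because the serial workers later in this patch only ever complete a finished head of the per-peer queue, packets complete in submission order even though the crypto runs on arbitrary CPUs. A minimal producer-side sketch, matching the shape of the callers in send.c and receive.c below (error paths trimmed; the real callers also drop their peer and keypair references on failure):

	struct crypt_ctx *ctx = kmem_cache_zalloc(crypt_ctx_cache, GFP_ATOMIC);
	if (unlikely(!ctx))
		return;
	ctx->peer = peer;
	ctx->keypair = keypair;
	/* Bounded per-peer queue keeps ordering; unbounded per-device queue hands the
	 * crypto work to a multicore worker on the next online CPU. */
	if (!queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, ctx,
					       wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu))
		kmem_cache_free(crypt_ctx_cache, ctx);
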
diff --git a/src/receive.c b/src/receive.c
index da229df..a7f6004 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -1,11 +1,12 @@
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-#include "packets.h"
+#include "queueing.h"
#include "device.h"
#include "peer.h"
#include "timers.h"
#include "messages.h"
#include "cookie.h"
+#include "socket.h"
#include <linux/ip.h>
#include <linux/ipv6.h>
@@ -145,9 +146,9 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
peer_put(peer);
}
-void packet_process_queued_handshake_packets(struct work_struct *work)
+void packet_handshake_receive_worker(struct work_struct *work)
{
- struct wireguard_device *wg = container_of(work, struct handshake_worker, work)->wg;
+ struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr;
struct sk_buff *skb;
while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
@@ -173,10 +174,74 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
if (send) {
peer->sent_lastminute_handshake = true;
- packet_queue_handshake_initiation(peer, false);
+ packet_send_queued_handshake_initiation(peer, false);
}
}
+static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key)
+{
+ struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
+ struct sk_buff *trailer;
+ int num_frags;
+
+ if (unlikely(!key))
+ return false;
+
+ if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
+ key->is_valid = false;
+ return false;
+ }
+
+ PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
+ skb_pull(skb, sizeof(struct message_data));
+ num_frags = skb_cow_data(skb, 0, &trailer);
+ if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
+ return false;
+
+ sg_init_table(sg, num_frags);
+ if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
+ return false;
+
+ if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
+ return false;
+
+ return !pskb_trim(skb, skb->len - noise_encrypted_len(0));
+}
+
+/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
+static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
+{
+ bool ret = false;
+ unsigned long index, index_current, top, i;
+ spin_lock_bh(&counter->receive.lock);
+
+ if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES))
+ goto out;
+
+ ++their_counter;
+
+ if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter))
+ goto out;
+
+ index = their_counter >> ilog2(BITS_PER_LONG);
+
+ if (likely(their_counter > counter->receive.counter)) {
+ index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
+ top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG);
+ for (i = 1; i <= top; ++i)
+ counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
+ counter->receive.counter = their_counter;
+ }
+
+ index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
+ ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]);
+
+out:
+ spin_unlock_bh(&counter->receive.lock);
+ return ret;
+}
+#include "selftest/counter.h"
+
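The window logic above is easier to follow with concrete numbers. Below is a standalone userspace re-implementation, shrunk to a 128-bit backtrack bitmap so the sliding window is visible; the constants are illustrative stand-ins, not the values from messages.h. It accepts a new counter, rejects a replay of it, still accepts an older counter that lands inside the window, and permanently rejects anything that has fallen more than a window's width behind the highest counter seen:

#include <stdio.h>
#include <stdbool.h>
#include <stdint.h>

#define BITS 64
#define TOTAL 128			/* bitmap size */
#define WINDOW (TOTAL - BITS)		/* usable replay window */
#define REJECT_AFTER 0xffffffffULL	/* stand-in for REJECT_AFTER_MESSAGES */

static uint64_t greatest, backtrack[TOTAL / BITS];

static bool validate(uint64_t their)
{
	uint64_t index, index_cur, top, i;

	if (greatest >= REJECT_AFTER + 1 || their >= REJECT_AFTER)
		return false;
	++their;
	if (WINDOW + their < greatest)
		return false;
	index = their / BITS;
	if (their > greatest) {
		index_cur = greatest / BITS;
		top = index - index_cur;
		if (top > TOTAL / BITS)
			top = TOTAL / BITS;
		for (i = 1; i <= top; ++i)
			backtrack[(i + index_cur) % (TOTAL / BITS)] = 0;
		greatest = their;
	}
	index %= TOTAL / BITS;
	uint64_t bit = 1ULL << (their % BITS);
	bool seen = backtrack[index] & bit;
	backtrack[index] |= bit;
	return !seen;
}

int main(void)
{
	printf("%d\n", validate(10));  /* 1: new counter */
	printf("%d\n", validate(10));  /* 0: replay */
	printf("%d\n", validate(5));   /* 1: older, but inside the window */
	printf("%d\n", validate(5));   /* 0: replay */
	printf("%d\n", validate(200)); /* 1: jump forward, window slides */
	printf("%d\n", validate(10));  /* 0: now more than WINDOW behind */
	return 0;
}
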
void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key)
{
struct net_device *dev = peer->device->dev;
@@ -187,7 +252,7 @@ void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer,
if (unlikely(used_new_key)) {
timers_handshake_complete(peer);
- packet_send_queue(peer);
+ packet_send_staged_packets(peer);
}
keep_key_fresh(peer);
@@ -262,7 +327,87 @@ packet_processed:
continue_processing:
timers_any_authenticated_packet_received(peer);
timers_any_authenticated_packet_traversal(peer);
- peer_put(peer);
+}
+
+void packet_rx_worker(struct work_struct *work)
+{
+ struct crypt_ctx *ctx;
+ struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+ struct sk_buff *skb;
+
+ local_bh_disable();
+ while ((ctx = queue_first_per_peer(queue)) != NULL && atomic_read(&ctx->is_finished)) {
+ queue_dequeue(queue);
+ if (likely((skb = ctx->skb) != NULL)) {
+ if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(skb)->nonce))) {
+ skb_reset(skb);
+ packet_consume_data_done(skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
+			} else {
+ net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
+ dev_kfree_skb(skb);
+ }
+ }
+ noise_keypair_put(ctx->keypair);
+ peer_put(ctx->peer);
+ kmem_cache_free(crypt_ctx_cache, ctx);
+ }
+ local_bh_enable();
+}
+
+void packet_decrypt_worker(struct work_struct *work)
+{
+ struct crypt_ctx *ctx;
+ struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
+ struct wireguard_peer *peer;
+
+ while ((ctx = queue_dequeue_per_device(queue)) != NULL) {
+ if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
+ dev_kfree_skb(ctx->skb);
+ ctx->skb = NULL;
+ }
+ /* Dereferencing ctx is unsafe once ctx->is_finished == true, so
+ * we take a reference here first. */
+ peer = peer_rcu_get(ctx->peer);
+ atomic_set(&ctx->is_finished, true);
+ queue_work_on(choose_cpu(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->rx_queue.work);
+ peer_put(peer);
+ }
+}
+
+static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
+{
+ struct crypt_ctx *ctx;
+ struct noise_keypair *keypair;
+ __le32 idx = ((struct message_data *)skb->data)->key_idx;
+
+ rcu_read_lock_bh();
+ keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
+ rcu_read_unlock_bh();
+ if (unlikely(!keypair)) {
+ dev_kfree_skb(skb);
+ return;
+ }
+
+ ctx = kmem_cache_zalloc(crypt_ctx_cache, GFP_ATOMIC);
+ if (unlikely(!ctx)) {
+ dev_kfree_skb(skb);
+		peer_put(keypair->entry.peer);
+ noise_keypair_put(keypair);
+ return;
+ }
+ ctx->keypair = keypair;
+ ctx->skb = skb;
+ /* We already have a reference to peer from index_hashtable_lookup. */
+ ctx->peer = ctx->keypair->entry.peer;
+
+ if (likely(queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &ctx->peer->rx_queue, ctx, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu)))
+ return; /* Successful. No need to drop references below. */
+
+ noise_keypair_put(ctx->keypair);
+ peer_put(ctx->peer);
+ dev_kfree_skb(ctx->skb);
+ kmem_cache_free(crypt_ctx_cache, ctx);
}
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
@@ -274,24 +419,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
case MESSAGE_HANDSHAKE_INITIATION:
case MESSAGE_HANDSHAKE_RESPONSE:
case MESSAGE_HANDSHAKE_COOKIE: {
- int cpu_index, cpu, target_cpu;
+ int cpu;
if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) {
net_dbg_skb_ratelimited("%s: Too many handshakes queued, dropping packet from %pISpfsc\n", wg->dev->name, skb);
goto err;
}
skb_queue_tail(&wg->incoming_handshakes, skb);
- /* Select the CPU in a round-robin */
- cpu_index = ((unsigned int)atomic_inc_return(&wg->incoming_handshake_seqnr)) % cpumask_weight(cpu_online_mask);
- target_cpu = cpumask_first(cpu_online_mask);
- for (cpu = 0; cpu < cpu_index; ++cpu)
- target_cpu = cpumask_next(target_cpu, cpu_online_mask);
-		/* Queues up a call to packet_process_queued_handshake_packets(skb): */
+		/* Queues up a call to packet_handshake_receive_worker(skb): */
- queue_work_on(target_cpu, wg->incoming_handshake_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, target_cpu)->work);
+ cpu = cpumask_next_online(&wg->incoming_handshake_cpu);
+ queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
break;
}
case MESSAGE_DATA:
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
- packet_consume_data(skb, wg);
+ packet_consume_data(wg, skb);
break;
default:
net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb);
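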
diff --git a/src/send.c b/src/send.c
index 6390efd..c725317 100644
--- a/src/send.c
+++ b/src/send.c
@@ -1,6 +1,6 @@
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-#include "packets.h"
+#include "queueing.h"
#include "timers.h"
#include "device.h"
#include "peer.h"
@@ -12,6 +12,7 @@
#include <linux/inetdevice.h>
#include <linux/socket.h>
#include <linux/jiffies.h>
+#include <net/ip_tunnels.h>
#include <net/udp.h>
#include <net/sock.h>
@@ -37,14 +38,14 @@ static void packet_send_handshake_initiation(struct wireguard_peer *peer)
}
}
-void packet_send_queued_handshakes(struct work_struct *work)
+void packet_handshake_send_worker(struct work_struct *work)
{
struct wireguard_peer *peer = container_of(work, struct wireguard_peer, transmit_handshake_work);
packet_send_handshake_initiation(peer);
peer_put(peer);
}
-void packet_queue_handshake_initiation(struct wireguard_peer *peer, bool is_retry)
+void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry)
{
if (!is_retry)
peer->timer_handshake_attempts = 0;
@@ -56,7 +57,7 @@ void packet_queue_handshake_initiation(struct wireguard_peer *peer, bool is_retr
peer = peer_rcu_get(peer);
-	/* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */
+	/* Queues up calling packet_handshake_send_worker(peer), where we do a peer_put(peer) after: */
- if (!queue_work(peer->device->peer_wq, &peer->transmit_handshake_work))
+ if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work))
peer_put(peer); /* If the work was already queued, we want to drop the extra reference */
}
@@ -100,25 +101,70 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
rcu_read_unlock_bh();
if (send)
- packet_queue_handshake_initiation(peer, false);
+ packet_send_queued_handshake_initiation(peer, false);
+}
+
+static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
+{
+ struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
+ struct message_data *header;
+ unsigned int padding_len, plaintext_len, trailer_len;
+ int num_frags;
+ struct sk_buff *trailer;
+
+ /* Calculate lengths */
+ padding_len = skb_padding(skb);
+ trailer_len = padding_len + noise_encrypted_len(0);
+ plaintext_len = skb->len + padding_len;
+
+ /* Expand data section to have room for padding and auth tag */
+ num_frags = skb_cow_data(skb, trailer_len, &trailer);
+ if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
+ return false;
+
+ /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
+ memset(skb_tail_pointer(trailer), 0, padding_len);
+
+ /* Expand head section to have room for our header and the network stack's headers. */
+ if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
+ return false;
+
+	/* We have to remember to add the checksum to the inner packet, in case the receiver forwards it. */
+ if (likely(!skb_checksum_setup(skb, true)))
+ skb_checksum_help(skb);
+
+ /* Only after checksumming can we safely add on the padding at the end and the header. */
+ header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
+ header->header.type = cpu_to_le32(MESSAGE_DATA);
+ header->key_idx = keypair->remote_index;
+ header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
+ pskb_put(skb, trailer, trailer_len);
+
+ /* Now we can encrypt the scattergather segments */
+ sg_init_table(sg, num_frags);
+ if (skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)) <= 0)
+ return false;
+ return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, have_simd);
}
void packet_send_keepalive(struct wireguard_peer *peer)
{
struct sk_buff *skb;
- if (skb_queue_empty(&peer->tx_packet_queue)) {
+
+ if (skb_queue_empty(&peer->staged_packet_queue)) {
skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC);
if (unlikely(!skb))
return;
skb_reserve(skb, DATA_PACKET_HEAD_ROOM);
skb->dev = peer->device->dev;
- skb_queue_tail(&peer->tx_packet_queue, skb);
+ skb_queue_tail(&peer->staged_packet_queue, skb);
net_dbg_ratelimited("%s: Sending keepalive packet to peer %Lu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
}
- packet_send_queue(peer);
+
+ packet_send_staged_packets(peer);
}
-void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer)
+static void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer)
{
struct sk_buff *skb, *tmp;
bool is_keepalive, data_sent = false;
@@ -136,65 +182,133 @@ void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *
timers_data_sent(peer);
keep_key_fresh(peer);
+}
- if (unlikely(peer->need_resend_queue))
- packet_send_queue(peer);
+void packet_tx_worker(struct work_struct *work)
+{
+ struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+ struct crypt_ctx *ctx;
+
+ while ((ctx = queue_first_per_peer(queue)) != NULL && atomic_read(&ctx->is_finished)) {
+ queue_dequeue(queue);
+ packet_create_data_done(&ctx->packets, ctx->peer);
+ peer_put(ctx->peer);
+ kmem_cache_free(crypt_ctx_cache, ctx);
+ }
}
-void packet_send_queue(struct wireguard_peer *peer)
+void packet_encrypt_worker(struct work_struct *work)
{
- struct sk_buff_head queue;
- struct sk_buff *skb;
+ struct crypt_ctx *ctx;
+ struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
+ struct sk_buff *skb, *tmp;
+ struct wireguard_peer *peer;
+ bool have_simd = chacha20poly1305_init_simd();
+
+ while ((ctx = queue_dequeue_per_device(queue)) != NULL) {
+ skb_queue_walk_safe(&ctx->packets, skb, tmp) {
+ if (likely(skb_encrypt(skb, ctx->keypair, have_simd))) {
+ skb_reset(skb);
+ } else {
+ __skb_unlink(skb, &ctx->packets);
+ dev_kfree_skb(skb);
+ }
+ }
+ /* Dereferencing ctx is unsafe once ctx->is_finished == true, so
+ * we grab an additional reference to peer. */
+ peer = peer_rcu_get(ctx->peer);
+ atomic_set(&ctx->is_finished, true);
+ queue_work_on(choose_cpu(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->tx_queue.work);
+ peer_put(peer);
+ }
+ chacha20poly1305_deinit_simd(have_simd);
+}
- peer->need_resend_queue = false;
+static void packet_create_data(struct wireguard_peer *peer, struct sk_buff_head *packets, struct noise_keypair *keypair)
+{
+ struct crypt_ctx *ctx;
+ struct wireguard_device *wg = peer->device;
- /* Steal the current queue into our local one. */
- skb_queue_head_init(&queue);
- spin_lock_bh(&peer->tx_packet_queue.lock);
- skb_queue_splice_init(&peer->tx_packet_queue, &queue);
- spin_unlock_bh(&peer->tx_packet_queue.lock);
+ ctx = kmem_cache_zalloc(crypt_ctx_cache, GFP_ATOMIC);
+ if (unlikely(!ctx)) {
+ skb_queue_purge(packets);
+ goto err_drop_refs;
+ }
+ /* This function consumes the passed references to peer and keypair. */
+ ctx->keypair = keypair;
+ ctx->peer = peer;
+ __skb_queue_head_init(&ctx->packets);
+ skb_queue_splice_tail(packets, &ctx->packets);
+ if (likely(queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, ctx, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu)))
+ return; /* Successful. No need to fall through to drop references below. */
+
+ skb_queue_purge(&ctx->packets);
+ kmem_cache_free(crypt_ctx_cache, ctx);
+
+err_drop_refs:
+ noise_keypair_put(keypair);
+ peer_put(peer);
+}
- if (unlikely(skb_queue_empty(&queue)))
+void packet_send_staged_packets(struct wireguard_peer *peer)
+{
+ struct noise_keypair *keypair;
+ struct noise_symmetric_key *key;
+ struct sk_buff_head packets;
+ struct sk_buff *skb;
+
+ /* Steal the current queue into our local one. */
+ __skb_queue_head_init(&packets);
+ spin_lock_bh(&peer->staged_packet_queue.lock);
+ skb_queue_splice_init(&peer->staged_packet_queue, &packets);
+ spin_unlock_bh(&peer->staged_packet_queue.lock);
+ if (unlikely(skb_queue_empty(&packets)))
return;
- /* We submit it for encryption and sending. */
- switch (packet_create_data(&queue, peer)) {
- case 0:
- break;
- case -EBUSY:
- /* EBUSY happens when the parallel workers are all filled up, in which
- * case we should requeue everything. */
-
- /* First, we mark that we should try to do this later, when existing
- * jobs are done. */
- peer->need_resend_queue = true;
-
- /* We stick the remaining skbs from local_queue at the top of the peer's
- * queue again, setting the top of local_queue to be the skb that begins
- * the requeueing. */
- spin_lock_bh(&peer->tx_packet_queue.lock);
- skb_queue_splice(&queue, &peer->tx_packet_queue);
- spin_unlock_bh(&peer->tx_packet_queue.lock);
- break;
- case -ENOKEY:
- /* ENOKEY means that we don't have a valid session for the peer, which
- * means we should initiate a session, but after requeuing like above.
- * Since we'll be queuing these up for potentially a little while, we
- * first make sure they're no longer using up a socket's write buffer. */
-
- skb_queue_walk (&queue, skb)
- skb_orphan(skb);
-
- spin_lock_bh(&peer->tx_packet_queue.lock);
- skb_queue_splice(&queue, &peer->tx_packet_queue);
- spin_unlock_bh(&peer->tx_packet_queue.lock);
-
- packet_queue_handshake_initiation(peer, false);
- break;
- default:
- /* If we failed for any other reason, we want to just free the packets and
- * forget about them. We do this unlocked, since we're the only ones with
- * a reference to the local queue. */
- __skb_queue_purge(&queue);
+ /* First we make sure we have a valid reference to a valid key. */
+ rcu_read_lock_bh();
+ keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair));
+ rcu_read_unlock_bh();
+ if (unlikely(!keypair))
+ goto out_nokey;
+ key = &keypair->sending;
+ if (unlikely(!key || !key->is_valid))
+ goto out_nokey;
+ if (unlikely(time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME)))
+ goto out_invalid;
+
+ /* After we know we have a somewhat valid key, we now try to assign nonces to
+ * all of the packets in the queue. If we can't assign nonces for all of them,
+ * we just consider it a failure and wait for the next handshake. */
+ skb_queue_walk (&packets, skb) {
+ PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
+ PACKET_CB(skb)->nonce = atomic64_inc_return(&key->counter.counter) - 1;
+ if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
+ goto out_invalid;
}
+
+	/* We pass off our peer and keypair references to the data subsystem and return. */
+ packet_create_data(peer_rcu_get(peer), &packets, keypair);
+ return;
+
+out_invalid:
+ key->is_valid = false;
+out_nokey:
+ noise_keypair_put(keypair);
+
+ /* We orphan the packets if we're waiting on a handshake, so that they
+ * don't block a socket's pool. */
+ skb_queue_walk (&packets, skb)
+ skb_orphan(skb);
+ /* Then we put them back on the top of the queue. We're not too concerned about
+	 * accidentally getting things a little out of order if packets are being added
+ * really fast, because this queue is for before packets can even be sent and
+ * it's small anyway. */
+ spin_lock_bh(&peer->staged_packet_queue.lock);
+ skb_queue_splice(&packets, &peer->staged_packet_queue);
+ spin_unlock_bh(&peer->staged_packet_queue.lock);
+
+ /* If we're exiting because there's something wrong with the key, it means
+ * we should initiate a new handshake. */
+ packet_send_queued_handshake_initiation(peer, false);
}
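
packet_send_staged_packets() is the entry point into this whole transmit pipeline; the staging itself happens in device.c, which is outside this section. A heavily simplified sketch of the assumed caller (the real xmit path also looks the peer up in the routing table and handles GSO segmentation before staging anything):

static netdev_tx_t xmit_sketch(struct sk_buff *skb, struct wireguard_peer *peer)
{
	/* Stage the packet, then try to encrypt and send everything staged so far;
	 * if no valid keypair exists, the packets stay staged and a handshake is
	 * initiated instead. */
	skb_queue_tail(&peer->staged_packet_queue, skb);
	packet_send_staged_packets(peer);
	return NETDEV_TX_OK;
}
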
diff --git a/src/socket.c b/src/socket.c
index dce5313..4f78de1 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -3,7 +3,7 @@
#include "device.h"
#include "peer.h"
#include "socket.h"
-#include "packets.h"
+#include "queueing.h"
#include "messages.h"
#include <linux/ctype.h>
diff --git a/src/tests/qemu/kernel.config b/src/tests/qemu/kernel.config
index 84a9875..e6016d0 100644
--- a/src/tests/qemu/kernel.config
+++ b/src/tests/qemu/kernel.config
@@ -71,5 +71,4 @@ CONFIG_BLK_DEV_INITRD=y
CONFIG_LEGACY_VSYSCALL_NONE=y
CONFIG_KERNEL_GZIP=y
CONFIG_WIREGUARD=y
-CONFIG_WIREGUARD_PARALLEL=y
CONFIG_WIREGUARD_DEBUG=y
diff --git a/src/timers.c b/src/timers.c
index 9712c9e..e7cdd11 100644
--- a/src/timers.c
+++ b/src/timers.c
@@ -3,7 +3,8 @@
#include "timers.h"
#include "device.h"
#include "peer.h"
-#include "packets.h"
+#include "queueing.h"
+#include "socket.h"
/*
* Timer for retransmitting the handshake if we don't hear back after `REKEY_TIMEOUT + jitter` ms
@@ -30,10 +31,12 @@ static void expired_retransmit_handshake(unsigned long ptr)
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
pr_debug("%s: Handshake for peer %Lu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2);
- del_timer(&peer->timer_send_keepalive);
- /* We remove all existing packets and don't try again,
+ if (likely(peer->timers_enabled))
+ del_timer(&peer->timer_send_keepalive);
+ /* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake. */
- skb_queue_purge(&peer->tx_packet_queue);
+ skb_queue_purge(&peer->staged_packet_queue);
+
/* We set a timer for destroying any residue that might be left
* of a partial exchange. */
if (likely(peer->timers_enabled) && !timer_pending(&peer->timer_zero_key_material))
@@ -45,7 +48,7 @@ static void expired_retransmit_handshake(unsigned long ptr)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
socket_clear_peer_endpoint_src(peer);
- packet_queue_handshake_initiation(peer, true);
+ packet_send_queued_handshake_initiation(peer, true);
}
peer_put(peer);
}
@@ -56,7 +59,7 @@ static void expired_send_keepalive(unsigned long ptr)
packet_send_keepalive(peer);
if (peer->timer_need_another_keepalive) {
peer->timer_need_another_keepalive = false;
- if (peer->timers_enabled)
+ if (likely(peer->timers_enabled))
mod_timer(&peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT);
}
peer_put(peer);
@@ -68,14 +71,14 @@ static void expired_new_handshake(unsigned long ptr)
pr_debug("%s: Retrying handshake with peer %Lu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) / HZ);
/* We clear the endpoint address src address, in case this is the cause of trouble. */
socket_clear_peer_endpoint_src(peer);
- packet_queue_handshake_initiation(peer, false);
+ packet_send_queued_handshake_initiation(peer, false);
peer_put(peer);
}
static void expired_zero_key_material(unsigned long ptr)
{
peer_get_from_ptr(ptr);
- if (!queue_work(peer->device->peer_wq, &peer->clear_peer_work)) /* Takes our reference. */
+ if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Takes our reference. */
peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */
}
static void queued_expired_zero_key_material(struct work_struct *work)