From e1fead9769303cf160addfb9f16f8f9fda1ff617 Mon Sep 17 00:00:00 2001
From: Samuel Holland
Date: Wed, 7 Jun 2017 01:39:08 -0500
Subject: data: entirely rework parallel system

This removes our dependency on padata.

Signed-off-by: Samuel Holland
---
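A minimal userspace model of the ordering scheme that replaces padata's serialization, for readers following the diff below. Names here are invented for illustration; the kernel code uses struct crypt_queue, per-peer and shared queue helpers, atomics, and workqueues. Each ctx goes onto a per-peer FIFO (which fixes the order) and onto a shared queue (which spreads the crypto work across CPUs). Workers may mark contexts CTX_FINISHED in any order, but the per-peer consumer drains only finished contexts from the head, so packets still leave in submission order:

#include <stdio.h>

enum ctx_state { CTX_NEW, CTX_FINISHED };

struct ctx {
	int id;
	enum ctx_state state;	/* atomic_t in the kernel */
};

#define QLEN 8
static struct ctx *fifo[QLEN];	/* per-peer FIFO, e.g. send_queue */
static unsigned int head, tail;

static void enqueue(struct ctx *c)
{
	fifo[tail++ % QLEN] = c;	/* submission order is fixed here */
}

/* Parallel workers (cf. packet_encrypt_worker) complete contexts in
 * whatever order their CPU finishes the crypto. */
static void finish(struct ctx *c)
{
	c->state = CTX_FINISHED;
}

/* The per-peer consumer (cf. packet_send_worker) drains only from the
 * head, and only while the head is finished, turning out-of-order
 * completion back into in-order transmission. */
static void drain(void)
{
	while (head != tail && fifo[head % QLEN]->state == CTX_FINISHED)
		printf("sending ctx %d\n", fifo[head++ % QLEN]->id);
}

int main(void)
{
	struct ctx c[4] = { { 0 }, { 1 }, { 2 }, { 3 } };
	int i;

	for (i = 0; i < 4; ++i)
		enqueue(&c[i]);
	finish(&c[2]);	/* crypto completed out of order... */
	finish(&c[0]);
	drain();	/* ...but only ctx 0 may be sent so far */
	finish(&c[1]);
	finish(&c[3]);
	drain();	/* now 1, 2, 3 follow, still in order */
	return 0;
}

This prints ctx 0 first and ctx 1, 2, 3 only once their predecessors have finished, mirroring how the workers below stop at the first ctx whose state is not yet CTX_FINISHED.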
 src/data.c | 432 +++++++++++++++++++++++++++++++------------------------------
 1 file changed, 220 insertions(+), 212 deletions(-)

diff --git a/src/data.c b/src/data.c
index fb91861..b5569c7 100644
--- a/src/data.c
+++ b/src/data.c
@@ -5,6 +5,8 @@
 #include "peer.h"
 #include "messages.h"
 #include "packets.h"
+#include "queue.h"
+#include "timers.h"
 #include "hashtables.h"
 
 #include
@@ -15,43 +17,42 @@
 #include
 #include
 
-struct encryption_ctx {
-	struct padata_priv padata;
-	struct sk_buff_head queue;
-	struct wireguard_peer *peer;
-	struct noise_keypair *keypair;
-};
-
-struct decryption_ctx {
-	struct padata_priv padata;
-	struct endpoint endpoint;
-	struct sk_buff *skb;
-	struct noise_keypair *keypair;
-};
+static struct kmem_cache *crypt_ctx_cache __read_mostly;
 
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static struct kmem_cache *encryption_ctx_cache __read_mostly;
-static struct kmem_cache *decryption_ctx_cache __read_mostly;
-
-int __init packet_init_data_caches(void)
+int __init init_crypt_cache(void)
 {
-	encryption_ctx_cache = KMEM_CACHE(encryption_ctx, 0);
-	if (!encryption_ctx_cache)
-		return -ENOMEM;
-	decryption_ctx_cache = KMEM_CACHE(decryption_ctx, 0);
-	if (!decryption_ctx_cache) {
-		kmem_cache_destroy(encryption_ctx_cache);
+	crypt_ctx_cache = KMEM_CACHE(crypt_ctx, 0);
+	if (!crypt_ctx_cache)
 		return -ENOMEM;
-	}
 	return 0;
 }
 
-void packet_deinit_data_caches(void)
+void deinit_crypt_cache(void)
 {
-	kmem_cache_destroy(encryption_ctx_cache);
-	kmem_cache_destroy(decryption_ctx_cache);
+	kmem_cache_destroy(crypt_ctx_cache);
 }
-#endif
+
+static void drop_ctx(struct crypt_ctx *ctx, bool sending)
+{
+	if (ctx->keypair)
+		noise_keypair_put(ctx->keypair);
+	peer_put(ctx->peer);
+	if (sending)
+		skb_queue_purge(&ctx->packets);
+	else
+		dev_kfree_skb(ctx->skb);
+	kmem_cache_free(crypt_ctx_cache, ctx);
+}
+
+#define drop_ctx_and_continue(ctx, sending) ({ \
+	drop_ctx(ctx, sending); \
+	continue; \
+})
+
+#define drop_ctx_and_return(ctx, sending) ({ \
+	drop_ctx(ctx, sending); \
+	return; \
+})
 
 /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
 static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
@@ -195,236 +196,243 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
 	return !pskb_trim(skb, skb->len - noise_encrypted_len(0));
 }
 
-static inline bool get_encryption_nonce(u64 *nonce, struct noise_symmetric_key *key)
+static inline bool packet_initialize_ctx(struct crypt_ctx *ctx)
 {
-	if (unlikely(!key))
-		return false;
-
-	if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME))) {
-		key->is_valid = false;
-		return false;
-	}
+	struct noise_symmetric_key *key;
+	struct sk_buff *skb;
 
-	*nonce = atomic64_inc_return(&key->counter.counter) - 1;
-	if (*nonce >= REJECT_AFTER_MESSAGES) {
-		key->is_valid = false;
+	rcu_read_lock_bh();
+	ctx->keypair = noise_keypair_get(rcu_dereference_bh(ctx->peer->keypairs.current_keypair));
+	rcu_read_unlock_bh();
+	if (unlikely(!ctx->keypair))
 		return false;
+	key = &ctx->keypair->sending;
+	if (unlikely(!key || !key->is_valid))
+		goto out_nokey;
+	if (unlikely(time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME)))
+		goto out_invalid;
+
+	skb_queue_walk(&ctx->packets, skb) {
+		PACKET_CB(skb)->nonce = atomic64_inc_return(&key->counter.counter) - 1;
+		if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
+			goto out_invalid;
 	}
 	return true;
+
+out_invalid:
+	key->is_valid = false;
+out_nokey:
+	noise_keypair_put(ctx->keypair);
+	ctx->keypair = NULL;
+	return false;
 }
 
-static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair)
+void packet_send_worker(struct work_struct *work)
 {
+	struct crypt_ctx *ctx;
+	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
 	struct sk_buff *skb, *tmp;
-	bool have_simd = chacha20poly1305_init_simd();
-	skb_queue_walk_safe (queue, skb, tmp) {
-		if (unlikely(!skb_encrypt(skb, keypair, have_simd))) {
-			__skb_unlink(skb, queue);
-			kfree_skb(skb);
-			continue;
+	struct wireguard_peer *peer = container_of(queue, struct wireguard_peer, send_queue);
+	bool data_sent = false;
+
+	timers_any_authenticated_packet_traversal(peer);
+	while ((ctx = queue_first_peer(queue)) != NULL && atomic_read(&ctx->state) == CTX_FINISHED) {
+		queue_dequeue(queue);
+		skb_queue_walk_safe(&ctx->packets, skb, tmp) {
+			bool is_keepalive = skb->len == message_data_len(0);
+			if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive))
+				data_sent = true;
 		}
-		skb_reset(skb);
+		noise_keypair_put(ctx->keypair);
+		peer_put(ctx->peer);
+		kmem_cache_free(crypt_ctx_cache, ctx);
 	}
-	chacha20poly1305_deinit_simd(have_simd);
-	noise_keypair_put(keypair);
+	if (likely(data_sent))
+		timers_data_sent(peer);
+	keep_key_fresh_send(peer);
 }
 
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static void begin_parallel_encryption(struct padata_priv *padata)
+void packet_encrypt_worker(struct work_struct *work)
 {
-	struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
-	local_bh_enable();
-#endif
-	queue_encrypt_reset(&ctx->queue, ctx->keypair);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
-	local_bh_disable();
-#endif
-	padata_do_serial(padata);
-}
+	struct crypt_ctx *ctx;
+	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+	struct sk_buff *skb, *tmp;
+	struct wireguard_peer *peer;
+	bool have_simd = chacha20poly1305_init_simd();
 
-static void finish_parallel_encryption(struct padata_priv *padata)
-{
-	struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
-	packet_create_data_done(&ctx->queue, ctx->peer);
-	atomic_dec(&ctx->peer->parallel_encryption_inflight);
-	peer_put(ctx->peer);
-	kmem_cache_free(encryption_ctx_cache, ctx);
+	while ((ctx = queue_dequeue_shared(queue)) != NULL) {
+		skb_queue_walk_safe(&ctx->packets, skb, tmp) {
+			if (likely(skb_encrypt(skb, ctx->keypair, have_simd))) {
+				skb_reset(skb);
+			} else {
+				__skb_unlink(skb, &ctx->packets);
+				dev_kfree_skb(skb);
+			}
+		}
+		/* Dereferencing ctx is unsafe once ctx->state == CTX_FINISHED. */
+		peer = peer_rcu_get(ctx->peer);
+		atomic_set(&ctx->state, CTX_FINISHED);
+		queue_work_on(peer->work_cpu, peer->device->crypt_wq, &peer->send_queue.work);
+		peer_put(peer);
+	}
+	chacha20poly1305_deinit_simd(have_simd);
 }
 
-static inline unsigned int choose_cpu(__le32 key)
+void packet_init_worker(struct work_struct *work)
 {
-	unsigned int cpu_index, cpu, cb_cpu;
-
-	/* This ensures that packets encrypted to the same key are sent in-order. */
-	cpu_index = ((__force unsigned int)key) % cpumask_weight(cpu_online_mask);
-	cb_cpu = cpumask_first(cpu_online_mask);
-	for (cpu = 0; cpu < cpu_index; ++cpu)
-		cb_cpu = cpumask_next(cb_cpu, cpu_online_mask);
-
-	return cb_cpu;
+	struct crypt_ctx *ctx;
+	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+	struct wireguard_peer *peer = container_of(queue, struct wireguard_peer, init_queue);
+
+	spin_lock(&peer->init_queue_lock);
+	while ((ctx = queue_first_peer(queue)) != NULL) {
+		if (unlikely(!packet_initialize_ctx(ctx))) {
+			packet_queue_handshake_initiation(peer, false);
+			break;
+		}
+		queue_dequeue(queue);
+		if (unlikely(!queue_enqueue_peer(&peer->send_queue, ctx)))
+			drop_ctx_and_continue(ctx, true);
+		queue_enqueue_shared(peer->device->encrypt_queue, ctx, peer->device->crypt_wq, &peer->device->encrypt_cpu);
+	}
+	spin_unlock(&peer->init_queue_lock);
 }
-#endif
 
-int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer)
+void packet_create_data(struct wireguard_peer *peer, struct sk_buff_head *packets)
 {
-	int ret = -ENOKEY;
-	struct noise_keypair *keypair;
+	struct crypt_ctx *ctx;
 	struct sk_buff *skb;
+	struct wireguard_device *wg = peer->device;
+	bool need_handshake = false;
 
-	rcu_read_lock_bh();
-	keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair));
-	rcu_read_unlock_bh();
-	if (unlikely(!keypair))
-		return ret;
-
-	skb_queue_walk (queue, skb) {
-		if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending)))
-			goto err;
-
-		/* After the first time through the loop, if we've suceeded with a legitimate nonce,
-		 * then we don't want a -ENOKEY error if subsequent nonces fail. Rather, if this
-		 * condition arises, we simply want error out hard, and drop the entire queue. This
-		 * is partially lazy programming and TODO: this could be made to only requeue the
-		 * ones that had no nonce. But I'm not sure it's worth the added complexity, given
-		 * how rarely that condition should arise. */
-		ret = -EPIPE;
+	ctx = kmem_cache_alloc(crypt_ctx_cache, GFP_ATOMIC);
+	if (unlikely(!ctx)) {
+		skb_queue_purge(packets);
+		return;
 	}
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-	if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || atomic_read(&peer->parallel_encryption_inflight) > 0) && cpumask_weight(cpu_online_mask) > 1) {
-		struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC);
-		if (!ctx)
-			goto serial_encrypt;
-		skb_queue_head_init(&ctx->queue);
-		skb_queue_splice_init(queue, &ctx->queue);
-		memset(&ctx->padata, 0, sizeof(ctx->padata));
-		ctx->padata.parallel = begin_parallel_encryption;
-		ctx->padata.serial = finish_parallel_encryption;
-		ctx->keypair = keypair;
-		ctx->peer = peer_rcu_get(peer);
-		ret = -EBUSY;
-		if (unlikely(!ctx->peer))
-			goto err_parallel;
-		atomic_inc(&peer->parallel_encryption_inflight);
-		if (unlikely(padata_do_parallel(peer->device->encrypt_pd, &ctx->padata, choose_cpu(keypair->remote_index)))) {
-			atomic_dec(&peer->parallel_encryption_inflight);
-			peer_put(ctx->peer);
-err_parallel:
-			skb_queue_splice(&ctx->queue, queue);
-			kmem_cache_free(encryption_ctx_cache, ctx);
-			goto err;
+	skb_queue_head_init(&ctx->packets);
+	skb_queue_splice_tail(packets, &ctx->packets);
+	ctx->peer = peer_rcu_get(peer);
+	ctx->keypair = NULL;
+	atomic_set(&ctx->state, CTX_NEW);
+
+	/* If there are already packets on the init queue, these must go behind
+	 * them to maintain the correct order, so we can only take the fast path
	 * when the queue is empty. */
+	if (likely(queue_empty(&peer->init_queue))) {
+		if (likely(packet_initialize_ctx(ctx))) {
+			if (unlikely(!queue_enqueue_peer(&peer->send_queue, ctx)))
+				drop_ctx_and_return(ctx, true);
+			queue_enqueue_shared(wg->encrypt_queue, ctx, wg->crypt_wq, &wg->encrypt_cpu);
+			return;
 		}
-	} else
-serial_encrypt:
-#endif
-	{
-		queue_encrypt_reset(queue, keypair);
-		packet_create_data_done(queue, peer);
+		/* Initialization failed, so we need a new keypair. */
+		need_handshake = true;
 	}
-	return 0;
-
-err:
-	noise_keypair_put(keypair);
-	return ret;
-}
-
-static void begin_decrypt_packet(struct decryption_ctx *ctx)
-{
-	if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
-		peer_put(ctx->keypair->entry.peer);
-		noise_keypair_put(ctx->keypair);
-		dev_kfree_skb(ctx->skb);
-		ctx->skb = NULL;
+	/* Packets are kept around in the init queue as long as there is an
+	 * ongoing handshake. Throw out the oldest packets instead of the new
+	 * ones. If we cannot acquire the lock, packets are being dequeued on
+	 * another thread. */
+	if (unlikely(queue_full(&peer->init_queue)) && spin_trylock(&peer->init_queue_lock)) {
+		struct crypt_ctx *tmp = queue_dequeue_peer(&peer->init_queue);
+		if (likely(tmp))
+			drop_ctx(tmp, true);
+		spin_unlock(&peer->init_queue_lock);
 	}
+	skb_queue_walk(&ctx->packets, skb)
+		skb_orphan(skb);
+	if (unlikely(!queue_enqueue_peer(&peer->init_queue, ctx)))
+		drop_ctx_and_return(ctx, true);
+	if (need_handshake)
+		packet_queue_handshake_initiation(peer, false);
+	/* If we have a valid keypair, but took the slow path because init_queue
+	 * had packets on it, init_queue.worker() may have finished
	 * processing the existing packets and returned since we checked if the
	 * init_queue was empty. Run the worker again if this is the only ctx
	 * remaining on the queue. */
+	else if (unlikely(queue_first_peer(&peer->init_queue) == ctx))
+		queue_work(peer->device->crypt_wq, &peer->init_queue.work);
 }
 
-static void finish_decrypt_packet(struct decryption_ctx *ctx)
+void packet_receive_worker(struct work_struct *work)
 {
-	bool used_new_key;
+	struct crypt_ctx *ctx;
+	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+	struct sk_buff *skb;
 
-	if (!ctx->skb)
-		return;
-
-	if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) {
-		net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->keypair->entry.peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
-		peer_put(ctx->keypair->entry.peer);
+	while ((ctx = queue_first_peer(queue)) != NULL && atomic_read(&ctx->state) == CTX_FINISHED) {
+		queue_dequeue(queue);
+		if (likely(skb = ctx->skb)) {
+			if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(skb)->nonce))) {
+				net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
+				dev_kfree_skb(skb);
+			} else {
+				skb_reset(skb);
+				packet_consume_data_done(skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
+			}
+		}
 		noise_keypair_put(ctx->keypair);
-		dev_kfree_skb(ctx->skb);
-		return;
+		peer_put(ctx->peer);
+		kmem_cache_free(crypt_ctx_cache, ctx);
 	}
-
-	used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
-	skb_reset(ctx->skb);
-	packet_consume_data_done(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key);
-	noise_keypair_put(ctx->keypair);
 }
 
-#ifdef CONFIG_WIREGUARD_PARALLEL
-static void begin_parallel_decryption(struct padata_priv *padata)
+void packet_decrypt_worker(struct work_struct *work)
 {
-	struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
-	local_bh_enable();
-#endif
-	begin_decrypt_packet(ctx);
-#if IS_ENABLED(CONFIG_KERNEL_MODE_NEON) && defined(CONFIG_ARM)
-	local_bh_disable();
-#endif
-	padata_do_serial(padata);
-}
+	struct crypt_ctx *ctx;
+	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+	struct wireguard_peer *peer;
 
-static void finish_parallel_decryption(struct padata_priv *padata)
-{
-	struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
-	finish_decrypt_packet(ctx);
-	kmem_cache_free(decryption_ctx_cache, ctx);
+	while ((ctx = queue_dequeue_shared(queue)) != NULL) {
+		if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
+			dev_kfree_skb(ctx->skb);
+			ctx->skb = NULL;
+		}
+		/* Dereferencing ctx is unsafe once ctx->state == CTX_FINISHED. */
+		peer = peer_rcu_get(ctx->peer);
+		atomic_set(&ctx->state, CTX_FINISHED);
+		queue_work_on(peer->work_cpu, peer->device->crypt_wq, &peer->receive_queue.work);
+		peer_put(peer);
+	}
 }
-#endif
 
 void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg)
 {
-	struct noise_keypair *keypair;
+	struct crypt_ctx *ctx;
 	__le32 idx = ((struct message_data *)skb->data)->key_idx;
 
+	ctx = kmem_cache_alloc(crypt_ctx_cache, GFP_ATOMIC);
+	if (unlikely(!ctx)) {
+		dev_kfree_skb(skb);
+		return;
+	}
 	rcu_read_lock_bh();
-	keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
+	ctx->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
 	rcu_read_unlock_bh();
-	if (unlikely(!keypair))
-		goto err;
-
-#ifdef CONFIG_WIREGUARD_PARALLEL
-	if (cpumask_weight(cpu_online_mask) > 1) {
-		struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
-		if (unlikely(!ctx))
-			goto err_peer;
-		ctx->skb = skb;
-		ctx->keypair = keypair;
-		memset(&ctx->padata, 0, sizeof(ctx->padata));
-		ctx->padata.parallel = begin_parallel_decryption;
-		ctx->padata.serial = finish_parallel_decryption;
-		if (unlikely(padata_do_parallel(wg->decrypt_pd, &ctx->padata, choose_cpu(idx)))) {
-			kmem_cache_free(decryption_ctx_cache, ctx);
-			goto err_peer;
-		}
-	} else
-#endif
-	{
-		struct decryption_ctx ctx = {
-			.skb = skb,
-			.keypair = keypair
-		};
-		begin_decrypt_packet(&ctx);
-		finish_decrypt_packet(&ctx);
+	if (unlikely(!ctx->keypair)) {
+		kmem_cache_free(crypt_ctx_cache, ctx);
+		dev_kfree_skb(skb);
+		return;
 	}
-	return;
+	ctx->skb = skb;
+	/* index_hashtable_lookup() already gets a reference to peer. */
+	ctx->peer = ctx->keypair->entry.peer;
+	atomic_set(&ctx->state, CTX_NEW);
+
+	if (unlikely(!queue_enqueue_peer(&ctx->peer->receive_queue, ctx)))
+		drop_ctx_and_return(ctx, false);
+	queue_enqueue_shared(wg->decrypt_queue, ctx, wg->crypt_wq, &wg->decrypt_cpu);
+}
 
-#ifdef CONFIG_WIREGUARD_PARALLEL
-err_peer:
-	peer_put(keypair->entry.peer);
-	noise_keypair_put(keypair);
-#endif
-err:
-	dev_kfree_skb(skb);
+void peer_purge_queues(struct wireguard_peer *peer)
+{
+	struct crypt_ctx *ctx;
+
+	if (!spin_trylock(&peer->init_queue_lock))
+		return;
+	while ((ctx = queue_dequeue_peer(&peer->init_queue)) != NULL)
+		drop_ctx(ctx, true);
+	spin_unlock(&peer->init_queue_lock);
 }
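For reference, counter_validate(), which the hunks above keep in place, is described in the source as an RFC 6479 replay-detection bitmap that avoids bitshifts. Below is a rough userspace sketch of that algorithm with invented names and a fixed 512-bit window; it is an illustration of the idea, not the in-tree implementation, which also enforces the REJECT_AFTER_MESSAGES limit and handles locking:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define BITS_PER_WORD	64
#define WINDOW_WORDS	8	/* power of two; 512-bit window */
#define WINDOW_BITS	(WINDOW_WORDS * BITS_PER_WORD)

struct replay_state {
	uint64_t greatest;		/* highest counter seen, 1-based */
	uint64_t backtrack[WINDOW_WORDS];
};

static bool counter_validate(struct replay_state *s, uint64_t their_counter)
{
	uint64_t index, index_cur, top, i, bit;

	++their_counter;	/* 1-based internally; 0 means "none yet" */

	/* Too far behind the window: reject.  One word of slack is lost
	 * because clearing happens at word granularity. */
	if (their_counter + WINDOW_BITS - BITS_PER_WORD < s->greatest)
		return false;

	index = their_counter / BITS_PER_WORD;
	if (their_counter > s->greatest) {
		/* Window slides forward: zero every word it slid over,
		 * instead of shifting the whole bitmap. */
		index_cur = s->greatest / BITS_PER_WORD;
		top = index - index_cur;
		if (top > WINDOW_WORDS)
			top = WINDOW_WORDS;
		for (i = 1; i <= top; ++i)
			s->backtrack[(index_cur + i) % WINDOW_WORDS] = 0;
		s->greatest = their_counter;
	}

	/* Test-and-set this counter's bit; an already-set bit is a replay. */
	bit = UINT64_C(1) << (their_counter % BITS_PER_WORD);
	index %= WINDOW_WORDS;
	if (s->backtrack[index] & bit)
		return false;
	s->backtrack[index] |= bit;
	return true;
}

int main(void)
{
	struct replay_state s = { 0 };

	printf("%d\n", counter_validate(&s, 0));	/* 1: fresh */
	printf("%d\n", counter_validate(&s, 0));	/* 0: replay */
	printf("%d\n", counter_validate(&s, 5));	/* 1: fresh, ahead */
	printf("%d\n", counter_validate(&s, 3));	/* 1: late, in window */
	return 0;
}

The word-granular clearing is what lets the receive path validate nonces under a lock without ever shifting a multi-word bitmap, which is the property the "avoids bitshifts" comment refers to.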