From 074b21a75a87f9d37209cd17f35d779435688d52 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 4 Nov 2016 16:00:22 +0100 Subject: data: use a memory cache for parallel ctx --- src/data.c | 141 +++++++++++++++++++++++++++++++++------------------------- src/main.c | 9 +++- src/packets.h | 2 + 3 files changed, 91 insertions(+), 61 deletions(-) diff --git a/src/data.c b/src/data.c index 1f6c5b6..e027510 100644 --- a/src/data.c +++ b/src/data.c @@ -15,6 +15,56 @@ #include #include +struct encryption_skb_cb { + uint8_t ds; + uint8_t num_frags; + unsigned int plaintext_len, trailer_len; + struct sk_buff *trailer; + uint64_t nonce; +}; + +struct encryption_ctx { + struct padata_priv padata; + struct sk_buff_head queue; + void (*callback)(struct sk_buff_head *, struct wireguard_peer *); + struct wireguard_peer *peer; + struct noise_keypair *keypair; +}; + +struct decryption_ctx { + struct padata_priv padata; + struct sk_buff *skb; + void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err); + struct noise_keypair *keypair; + struct sockaddr_storage addr; + uint64_t nonce; + uint8_t num_frags; + int ret; +}; + +static struct kmem_cache *encryption_ctx_cache; +static struct kmem_cache *decryption_ctx_cache; + +int packet_init_data_caches(void) +{ + BUILD_BUG_ON(sizeof(struct encryption_skb_cb) > sizeof(((struct sk_buff *)0)->cb)); + encryption_ctx_cache = kmem_cache_create("wireguard_encryption_ctx", sizeof(struct encryption_ctx), 0, 0, NULL); + if (!encryption_ctx_cache) + return -ENOMEM; + decryption_ctx_cache = kmem_cache_create("wireguard_decryption_ctx", sizeof(struct decryption_ctx), 0, 0, NULL); + if (!decryption_ctx_cache) { + kmem_cache_destroy(encryption_ctx_cache); + return -ENOMEM; + } + return 0; +} + +void packet_deinit_data_caches(void) +{ + kmem_cache_destroy(encryption_ctx_cache); + kmem_cache_destroy(decryption_ctx_cache); +} + /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ static inline bool counter_validate(union noise_counter *counter, u64 their_counter) { @@ -76,18 +126,10 @@ static inline void skb_reset(struct sk_buff *skb) skb_probe_transport_header(skb, 0); } -struct packet_data_encryption_ctx { - uint8_t ds; - uint8_t num_frags; - unsigned int plaintext_len, trailer_len; - struct sk_buff *trailer; - uint64_t nonce; -}; - static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd) { - struct packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb; - struct scatterlist sg[ctx->num_frags]; /* This should be bound to at most 128 by the caller. */ + struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb; + struct scatterlist sg[cb->num_frags]; /* This should be bound to at most 128 by the caller. */ struct message_data *header; /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */ @@ -98,13 +140,13 @@ static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai header = (struct message_data *)skb_push(skb, sizeof(struct message_data)); header->header.type = MESSAGE_DATA; header->key_idx = keypair->remote_index; - header->counter = cpu_to_le64(ctx->nonce); - pskb_put(skb, ctx->trailer, ctx->trailer_len); + header->counter = cpu_to_le64(cb->nonce); + pskb_put(skb, cb->trailer, cb->trailer_len); /* Now we can encrypt the scattergather segments */ - sg_init_table(sg, ctx->num_frags); - skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(ctx->plaintext_len)); - chacha20poly1305_encrypt_sg(sg, sg, ctx->plaintext_len, NULL, 0, ctx->nonce, keypair->sending.key, have_simd); + sg_init_table(sg, cb->num_frags); + skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(cb->plaintext_len)); + chacha20poly1305_encrypt_sg(sg, sg, cb->plaintext_len, NULL, 0, cb->nonce, keypair->sending.key, have_simd); } static inline bool skb_decrypt(struct sk_buff *skb, uint8_t num_frags, uint64_t nonce, struct noise_symmetric_key *key) @@ -147,14 +189,6 @@ static inline bool get_encryption_nonce(uint64_t *nonce, struct noise_symmetric_ return true; } -struct packet_bundle_ctx { - struct padata_priv padata; - struct sk_buff_head queue; - void (*callback)(struct sk_buff_head *, struct wireguard_peer *); - struct wireguard_peer *peer; - struct noise_keypair *keypair; -}; - static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair) { struct sk_buff *skb; @@ -170,7 +204,7 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_ #ifdef CONFIG_WIREGUARD_PARALLEL static void do_encryption(struct padata_priv *padata) { - struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata); + struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); queue_encrypt_reset(&ctx->queue, ctx->keypair); padata_do_serial(padata); @@ -178,11 +212,11 @@ static void do_encryption(struct padata_priv *padata) static void finish_encryption(struct padata_priv *padata) { - struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata); + struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); ctx->callback(&ctx->queue, ctx->peer); peer_put(ctx->peer); - kfree(ctx); + kmem_cache_free(encryption_ctx_cache, ctx); } static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) @@ -220,33 +254,31 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, rcu_read_unlock(); skb_queue_walk(queue, skb) { - struct packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb; + struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb; unsigned int padding_len, num_frags; - BUILD_BUG_ON(sizeof(struct packet_data_encryption_ctx) > sizeof(skb->cb)); - - if (unlikely(!get_encryption_nonce(&ctx->nonce, &keypair->sending))) + if (unlikely(!get_encryption_nonce(&cb->nonce, &keypair->sending))) goto err; padding_len = skb_padding(skb); - ctx->trailer_len = padding_len + noise_encrypted_len(0); - ctx->plaintext_len = skb->len + padding_len; + cb->trailer_len = padding_len + noise_encrypted_len(0); + cb->plaintext_len = skb->len + padding_len; /* Store the ds bit in the cb */ - ctx->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); + cb->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); /* Expand data section to have room for padding and auth tag */ - ret = skb_cow_data(skb, ctx->trailer_len, &ctx->trailer); + ret = skb_cow_data(skb, cb->trailer_len, &cb->trailer); if (unlikely(ret < 0)) goto err; num_frags = ret; ret = -ENOMEM; if (unlikely(num_frags > 128)) goto err; - ctx->num_frags = num_frags; + cb->num_frags = num_frags; /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ - memset(skb_tail_pointer(ctx->trailer), 0, padding_len); + memset(skb_tail_pointer(cb->trailer), 0, padding_len); /* Expand head section to have room for our header and the network stack's headers. */ ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM); @@ -265,9 +297,9 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, #ifdef CONFIG_WIREGUARD_PARALLEL if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || padata_queue_len(peer->device->parallel_send) > 0) && cpumask_weight(cpu_online_mask) > 1) { unsigned int cpu = choose_cpu(keypair->remote_index); - struct packet_bundle_ctx *ctx = kmalloc(sizeof(struct packet_bundle_ctx), GFP_ATOMIC); + struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC); if (!ctx) - goto serial; + goto serial_encrypt; skb_queue_head_init(&ctx->queue); skb_queue_splice_init(queue, &ctx->queue); ctx->callback = callback; @@ -281,13 +313,13 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, peer_put(ctx->peer); err_parallel: skb_queue_splice(&ctx->queue, queue); - kfree(ctx); + kmem_cache_free(encryption_ctx_cache, ctx); goto err; } } else #endif { -serial: +serial_encrypt: queue_encrypt_reset(queue, keypair); callback(queue, peer); } @@ -301,18 +333,7 @@ err_rcu: return ret; } -struct packet_data_decryption_ctx { - struct padata_priv padata; - struct sk_buff *skb; - void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err); - struct noise_keypair *keypair; - struct sockaddr_storage addr; - uint64_t nonce; - uint8_t num_frags; - int ret; -}; - -static void begin_decrypt_packet(struct packet_data_decryption_ctx *ctx) +static void begin_decrypt_packet(struct decryption_ctx *ctx) { if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving))) goto err; @@ -326,7 +347,7 @@ err: peer_put(ctx->keypair->entry.peer); } -static void finish_decrypt_packet(struct packet_data_decryption_ctx *ctx) +static void finish_decrypt_packet(struct decryption_ctx *ctx) { struct noise_keypairs *keypairs; bool used_new_key = false; @@ -357,16 +378,16 @@ err: #ifdef CONFIG_WIREGUARD_PARALLEL static void do_decryption(struct padata_priv *padata) { - struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata); + struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); begin_decrypt_packet(ctx); padata_do_serial(padata); } static void finish_decryption(struct padata_priv *padata) { - struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata); + struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); finish_decrypt_packet(ctx); - kfree(ctx); + kmem_cache_free(decryption_ctx_cache, ctx); } static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) @@ -420,11 +441,11 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de rcu_read_unlock(); #ifdef CONFIG_WIREGUARD_PARALLEL if (cpumask_weight(cpu_online_mask) > 1) { - struct packet_data_decryption_ctx *ctx; unsigned int cpu = choose_cpu(idx); + struct decryption_ctx *ctx; ret = -ENOMEM; - ctx = kzalloc(sizeof(struct packet_data_decryption_ctx), GFP_ATOMIC); + ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC); if (unlikely(!ctx)) goto err_peer; @@ -436,13 +457,13 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de ctx->addr = addr; ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu); if (unlikely(ret)) { - kfree(ctx); + kmem_cache_free(decryption_ctx_cache, ctx); goto err_peer; } } else #endif { - struct packet_data_decryption_ctx ctx = { + struct decryption_ctx ctx = { .skb = skb, .keypair = keypair, .callback = callback, diff --git a/src/main.c b/src/main.c index 1399953..e381d09 100644 --- a/src/main.c +++ b/src/main.c @@ -29,10 +29,16 @@ static int __init mod_init(void) chacha20poly1305_init(); noise_init(); - ret = device_init(); + ret = packet_init_data_caches(); if (ret < 0) return ret; + ret = device_init(); + if (ret < 0) { + packet_deinit_data_caches(); + return ret; + } + pr_info("WireGuard loaded. See www.wireguard.io for information.\n"); pr_info("(C) Copyright 2015-2016 Jason A. Donenfeld . All Rights Reserved.\n"); return ret; @@ -41,6 +47,7 @@ static int __init mod_init(void) static void __exit mod_exit(void) { device_uninit(); + packet_deinit_data_caches(); pr_debug("WireGuard has been unloaded\n"); } diff --git a/src/packets.h b/src/packets.h index 31abb57..c9d82d1 100644 --- a/src/packets.h +++ b/src/packets.h @@ -41,6 +41,8 @@ void packet_send_queued_handshakes(struct work_struct *work); /* data.c */ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, void(*callback)(struct sk_buff_head *, struct wireguard_peer *)); void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, void(*callback)(struct sk_buff *, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err)); +int packet_init_data_caches(void); +void packet_deinit_data_caches(void); #define DATA_PACKET_HEAD_ROOM ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4) -- cgit v1.2.3-59-g8ed1b