-rw-r--r--   src/data.c     141
-rw-r--r--   src/main.c       9
-rw-r--r--   src/packets.h    2
3 files changed, 91 insertions, 61 deletions
diff --git a/src/data.c b/src/data.c
index 1f6c5b6..e027510 100644
--- a/src/data.c
+++ b/src/data.c
@@ -15,6 +15,56 @@
#include <net/xfrm.h>
#include <crypto/algapi.h>
+struct encryption_skb_cb {
+ uint8_t ds;
+ uint8_t num_frags;
+ unsigned int plaintext_len, trailer_len;
+ struct sk_buff *trailer;
+ uint64_t nonce;
+};
+
+struct encryption_ctx {
+ struct padata_priv padata;
+ struct sk_buff_head queue;
+ void (*callback)(struct sk_buff_head *, struct wireguard_peer *);
+ struct wireguard_peer *peer;
+ struct noise_keypair *keypair;
+};
+
+struct decryption_ctx {
+ struct padata_priv padata;
+ struct sk_buff *skb;
+ void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err);
+ struct noise_keypair *keypair;
+ struct sockaddr_storage addr;
+ uint64_t nonce;
+ uint8_t num_frags;
+ int ret;
+};
+
+static struct kmem_cache *encryption_ctx_cache;
+static struct kmem_cache *decryption_ctx_cache;
+
+int packet_init_data_caches(void)
+{
+ BUILD_BUG_ON(sizeof(struct encryption_skb_cb) > sizeof(((struct sk_buff *)0)->cb));
+ encryption_ctx_cache = kmem_cache_create("wireguard_encryption_ctx", sizeof(struct encryption_ctx), 0, 0, NULL);
+ if (!encryption_ctx_cache)
+ return -ENOMEM;
+ decryption_ctx_cache = kmem_cache_create("wireguard_decryption_ctx", sizeof(struct decryption_ctx), 0, 0, NULL);
+ if (!decryption_ctx_cache) {
+ kmem_cache_destroy(encryption_ctx_cache);
+ return -ENOMEM;
+ }
+ return 0;
+}
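
The BUILD_BUG_ON above exists because the new encryption_skb_cb is overlaid on the sk_buff control buffer, a fixed 48-byte scratch area owned by whichever layer currently holds the skb. A minimal illustration of that pattern, with hypothetical names, assuming only that the per-packet struct must never outgrow skb->cb:

struct example_cb {				/* hypothetical per-packet state */
	u64 nonce;
	unsigned int plaintext_len;
};

static inline struct example_cb *example_cb(struct sk_buff *skb)
{
	/* Compile-time guard: the overlay must fit inside the 48-byte cb[]. */
	BUILD_BUG_ON(sizeof(struct example_cb) > sizeof(skb->cb));
	return (struct example_cb *)skb->cb;
}
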
+
+void packet_deinit_data_caches(void)
+{
+ kmem_cache_destroy(encryption_ctx_cache);
+ kmem_cache_destroy(decryption_ctx_cache);
+}
+
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
{
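
The body of counter_validate() falls outside this hunk; for readers unfamiliar with RFC6479, here is a hedged, simplified sketch of the idea the comment refers to. The window size and helper names are illustrative only, not the patch's actual code: the replay window is an array of word-sized bitmaps, whole words are cleared as the window advances, and a counter's bit is located by plain indexing rather than by shifting one large bitmap.

#define SKETCH_WINDOW_WORDS	8		/* hypothetical window size */
#define SKETCH_WINDOW_BITS	(SKETCH_WINDOW_WORDS * BITS_PER_LONG)

struct sketch_replay_window {
	u64 last;				/* greatest counter accepted so far */
	unsigned long bitmap[SKETCH_WINDOW_WORDS];
};

static bool sketch_counter_validate(struct sketch_replay_window *w, u64 counter)
{
	unsigned long index = (unsigned long)(counter / BITS_PER_LONG);

	if (counter > w->last) {
		/* Window advances: clear only the words that slide out of it. */
		unsigned long top = (unsigned long)(w->last / BITS_PER_LONG);
		unsigned long i, diff = min(index - top, (unsigned long)SKETCH_WINDOW_WORDS);

		for (i = 1; i <= diff; ++i)
			w->bitmap[(top + i) % SKETCH_WINDOW_WORDS] = 0;
		w->last = counter;
	} else if (w->last - counter >= SKETCH_WINDOW_BITS) {
		return false;			/* too old: fell out of the window */
	}
	/* In the window: reject replays, otherwise mark this counter as seen. */
	return !test_and_set_bit(counter % BITS_PER_LONG,
				 &w->bitmap[index % SKETCH_WINDOW_WORDS]);
}
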
@@ -76,18 +126,10 @@ static inline void skb_reset(struct sk_buff *skb)
skb_probe_transport_header(skb, 0);
}
-struct packet_data_encryption_ctx {
- uint8_t ds;
- uint8_t num_frags;
- unsigned int plaintext_len, trailer_len;
- struct sk_buff *trailer;
- uint64_t nonce;
-};
-
static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
{
- struct packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb;
- struct scatterlist sg[ctx->num_frags]; /* This should be bound to at most 128 by the caller. */
+ struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
+ struct scatterlist sg[cb->num_frags]; /* This should be bound to at most 128 by the caller. */
struct message_data *header;
/* We have to remember to add the checksum to the inner packet, in case the receiver forwards it. */
@@ -98,13 +140,13 @@ static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai
header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
header->header.type = MESSAGE_DATA;
header->key_idx = keypair->remote_index;
- header->counter = cpu_to_le64(ctx->nonce);
- pskb_put(skb, ctx->trailer, ctx->trailer_len);
+ header->counter = cpu_to_le64(cb->nonce);
+ pskb_put(skb, cb->trailer, cb->trailer_len);
/* Now we can encrypt the scattergather segments */
- sg_init_table(sg, ctx->num_frags);
- skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(ctx->plaintext_len));
- chacha20poly1305_encrypt_sg(sg, sg, ctx->plaintext_len, NULL, 0, ctx->nonce, keypair->sending.key, have_simd);
+ sg_init_table(sg, cb->num_frags);
+ skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(cb->plaintext_len));
+ chacha20poly1305_encrypt_sg(sg, sg, cb->plaintext_len, NULL, 0, cb->nonce, keypair->sending.key, have_simd);
}
static inline bool skb_decrypt(struct sk_buff *skb, uint8_t num_frags, uint64_t nonce, struct noise_symmetric_key *key)
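
For orientation, an illustrative sketch of the buffer layout that skb_encrypt() above produces, assuming that noise_encrypted_len(x) is x plus the 16-byte chacha20poly1305 authentication tag and that skb_padding() rounds the plaintext up to the message padding multiple:

/*
 * Outgoing skb after skb_encrypt(), under those assumptions:
 *
 *   [ message_data header ][ plaintext ... | zero padding ][ 16-byte auth tag ]
 *     ^ added by skb_push()  ^ cb->plaintext_len bytes       ^ padding and tag together
 *                                                              form the pskb_put() trailer
 *
 *   cb->trailer_len   = padding_len + noise_encrypted_len(0);   (padding + tag)
 *   cb->plaintext_len = original skb->len + padding_len;
 *
 * chacha20poly1305_encrypt_sg() then encrypts cb->plaintext_len bytes in place,
 * starting just past the header, and writes the tag at the end.
 */
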
@@ -147,14 +189,6 @@ static inline bool get_encryption_nonce(uint64_t *nonce, struct noise_symmetric_
return true;
}
-struct packet_bundle_ctx {
- struct padata_priv padata;
- struct sk_buff_head queue;
- void (*callback)(struct sk_buff_head *, struct wireguard_peer *);
- struct wireguard_peer *peer;
- struct noise_keypair *keypair;
-};
-
static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair)
{
struct sk_buff *skb;
@@ -170,7 +204,7 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_
#ifdef CONFIG_WIREGUARD_PARALLEL
static void do_encryption(struct padata_priv *padata)
{
- struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata);
+ struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
queue_encrypt_reset(&ctx->queue, ctx->keypair);
padata_do_serial(padata);
@@ -178,11 +212,11 @@ static void do_encryption(struct padata_priv *padata)
static void finish_encryption(struct padata_priv *padata)
{
- struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata);
+ struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
ctx->callback(&ctx->queue, ctx->peer);
peer_put(ctx->peer);
- kfree(ctx);
+ kmem_cache_free(encryption_ctx_cache, ctx);
}
static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
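
The body of start_encryption() is also outside this diff. A hedged sketch of how it presumably ties the two callbacks above into padata, using the padata interface of this kernel era (do_encryption as the parallel worker, finish_encryption as the in-order completion on cb_cpu); this is illustrative, not the patch's actual implementation:

static inline int start_encryption_sketch(struct padata_instance *padata,
					  struct padata_priv *priv, int cb_cpu)
{
	memset(priv, 0, sizeof(*priv));
	priv->parallel = do_encryption;		/* runs on a worker CPU */
	priv->serial = finish_encryption;	/* runs later, in submission order */
	return padata_do_parallel(padata, priv, cb_cpu);
}
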
@@ -220,33 +254,31 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
rcu_read_unlock();
skb_queue_walk(queue, skb) {
- struct packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb;
+ struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
unsigned int padding_len, num_frags;
- BUILD_BUG_ON(sizeof(struct packet_data_encryption_ctx) > sizeof(skb->cb));
-
- if (unlikely(!get_encryption_nonce(&ctx->nonce, &keypair->sending)))
+ if (unlikely(!get_encryption_nonce(&cb->nonce, &keypair->sending)))
goto err;
padding_len = skb_padding(skb);
- ctx->trailer_len = padding_len + noise_encrypted_len(0);
- ctx->plaintext_len = skb->len + padding_len;
+ cb->trailer_len = padding_len + noise_encrypted_len(0);
+ cb->plaintext_len = skb->len + padding_len;
/* Store the ds bit in the cb */
- ctx->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
+ cb->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
/* Expand data section to have room for padding and auth tag */
- ret = skb_cow_data(skb, ctx->trailer_len, &ctx->trailer);
+ ret = skb_cow_data(skb, cb->trailer_len, &cb->trailer);
if (unlikely(ret < 0))
goto err;
num_frags = ret;
ret = -ENOMEM;
if (unlikely(num_frags > 128))
goto err;
- ctx->num_frags = num_frags;
+ cb->num_frags = num_frags;
/* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
- memset(skb_tail_pointer(ctx->trailer), 0, padding_len);
+ memset(skb_tail_pointer(cb->trailer), 0, padding_len);
/* Expand head section to have room for our header and the network stack's headers. */
ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM);
@@ -265,9 +297,9 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
#ifdef CONFIG_WIREGUARD_PARALLEL
if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || padata_queue_len(peer->device->parallel_send) > 0) && cpumask_weight(cpu_online_mask) > 1) {
unsigned int cpu = choose_cpu(keypair->remote_index);
- struct packet_bundle_ctx *ctx = kmalloc(sizeof(struct packet_bundle_ctx), GFP_ATOMIC);
+ struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC);
if (!ctx)
- goto serial;
+ goto serial_encrypt;
skb_queue_head_init(&ctx->queue);
skb_queue_splice_init(queue, &ctx->queue);
ctx->callback = callback;
@@ -281,13 +313,13 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
peer_put(ctx->peer);
err_parallel:
skb_queue_splice(&ctx->queue, queue);
- kfree(ctx);
+ kmem_cache_free(encryption_ctx_cache, ctx);
goto err;
}
} else
#endif
{
-serial:
+serial_encrypt:
queue_encrypt_reset(queue, keypair);
callback(queue, peer);
}
@@ -301,18 +333,7 @@ err_rcu:
return ret;
}
-struct packet_data_decryption_ctx {
- struct padata_priv padata;
- struct sk_buff *skb;
- void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err);
- struct noise_keypair *keypair;
- struct sockaddr_storage addr;
- uint64_t nonce;
- uint8_t num_frags;
- int ret;
-};
-
-static void begin_decrypt_packet(struct packet_data_decryption_ctx *ctx)
+static void begin_decrypt_packet(struct decryption_ctx *ctx)
{
if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving)))
goto err;
@@ -326,7 +347,7 @@ err:
peer_put(ctx->keypair->entry.peer);
}
-static void finish_decrypt_packet(struct packet_data_decryption_ctx *ctx)
+static void finish_decrypt_packet(struct decryption_ctx *ctx)
{
struct noise_keypairs *keypairs;
bool used_new_key = false;
@@ -357,16 +378,16 @@ err:
#ifdef CONFIG_WIREGUARD_PARALLEL
static void do_decryption(struct padata_priv *padata)
{
- struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata);
+ struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
begin_decrypt_packet(ctx);
padata_do_serial(padata);
}
static void finish_decryption(struct padata_priv *padata)
{
- struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata);
+ struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
finish_decrypt_packet(ctx);
- kfree(ctx);
+ kmem_cache_free(decryption_ctx_cache, ctx);
}
static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
@@ -420,11 +441,11 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
rcu_read_unlock();
#ifdef CONFIG_WIREGUARD_PARALLEL
if (cpumask_weight(cpu_online_mask) > 1) {
- struct packet_data_decryption_ctx *ctx;
unsigned int cpu = choose_cpu(idx);
+ struct decryption_ctx *ctx;
ret = -ENOMEM;
- ctx = kzalloc(sizeof(struct packet_data_decryption_ctx), GFP_ATOMIC);
+ ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
if (unlikely(!ctx))
goto err_peer;
@@ -436,13 +457,13 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
ctx->addr = addr;
ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu);
if (unlikely(ret)) {
- kfree(ctx);
+ kmem_cache_free(decryption_ctx_cache, ctx);
goto err_peer;
}
} else
#endif
{
- struct packet_data_decryption_ctx ctx = {
+ struct decryption_ctx ctx = {
.skb = skb,
.keypair = keypair,
.callback = callback,
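
One practical note on the receive path above: kmem_cache_alloc() returns uninitialized memory, unlike the kzalloc() it replaces, so any field the completion path reads must be assigned explicitly rather than assumed to be zero. An illustrative alloc/free pairing for the new cache, mirroring the hunk above rather than adding patch code:

struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
if (unlikely(!ctx))
	goto err_peer;				/* same fallback as in the hunk above */
ctx->skb = skb;
ctx->keypair = keypair;
/* ... callback, nonce, num_frags, addr, ret ... */
if (unlikely(start_decryption(wg->parallel_receive, &ctx->padata, cpu)))
	kmem_cache_free(decryption_ctx_cache, ctx);	/* never queued: free it here */
/* otherwise finish_decryption() frees it once the serial completion has run */
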
diff --git a/src/main.c b/src/main.c
index 1399953..e381d09 100644
--- a/src/main.c
+++ b/src/main.c
@@ -29,10 +29,16 @@ static int __init mod_init(void)
chacha20poly1305_init();
noise_init();
- ret = device_init();
+ ret = packet_init_data_caches();
if (ret < 0)
return ret;
+ ret = device_init();
+ if (ret < 0) {
+ packet_deinit_data_caches();
+ return ret;
+ }
+
pr_info("WireGuard loaded. See www.wireguard.io for information.\n");
pr_info("(C) Copyright 2015-2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.\n");
return ret;
@@ -41,6 +47,7 @@ static int __init mod_init(void)
static void __exit mod_exit(void)
{
device_uninit();
+ packet_deinit_data_caches();
pr_debug("WireGuard has been unloaded\n");
}
diff --git a/src/packets.h b/src/packets.h
index 31abb57..c9d82d1 100644
--- a/src/packets.h
+++ b/src/packets.h
@@ -41,6 +41,8 @@ void packet_send_queued_handshakes(struct work_struct *work);
/* data.c */
int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, void(*callback)(struct sk_buff_head *, struct wireguard_peer *));
void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, void(*callback)(struct sk_buff *, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err));
+int packet_init_data_caches(void);
+void packet_deinit_data_caches(void);
#define DATA_PACKET_HEAD_ROOM ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4)