diff options
Diffstat (limited to '')
-rw-r--r-- | src/data.c | 149 |
1 file changed, 95 insertions, 54 deletions
@@ -11,6 +11,7 @@ #include <linux/slab.h> #include <linux/bitmap.h> #include <linux/scatterlist.h> +#include <net/ip_tunnels.h> #include <net/xfrm.h> #include <crypto/algapi.h> @@ -75,11 +76,21 @@ static inline void skb_reset(struct sk_buff *skb) skb_probe_transport_header(skb, 0); } -static inline void skb_encrypt(struct sk_buff *skb, struct packet_data_encryption_ctx *ctx) +struct packet_data_encryption_ctx { + uint8_t ds; + uint8_t num_frags; + unsigned int plaintext_len, trailer_len; + struct sk_buff *trailer; + uint64_t nonce; +}; + +static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair) { + struct packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb; struct scatterlist sg[ctx->num_frags]; /* This should be bound to at most 128 by the caller. */ struct message_data *header; + /* We have to remember to add the checksum to the inner packet, in case the receiver forwards it. */ if (likely(!skb_checksum_setup(skb, true))) skb_checksum_help(skb); @@ -87,17 +98,14 @@ static inline void skb_encrypt(struct sk_buff *skb, struct packet_data_encryptio /* Only after checksumming can we safely add on the padding at the end and the header. 
*/ header = (struct message_data *)skb_push(skb, sizeof(struct message_data)); header->header.type = MESSAGE_DATA; - header->key_idx = ctx->keypair->remote_index; + header->key_idx = keypair->remote_index; header->counter = cpu_to_le64(ctx->nonce); pskb_put(skb, ctx->trailer, ctx->trailer_len); /* Now we can encrypt the scattergather segments */ sg_init_table(sg, ctx->num_frags); skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(ctx->plaintext_len)); - chacha20poly1305_encrypt_sg(sg, sg, ctx->plaintext_len, NULL, 0, ctx->nonce, ctx->keypair->sending.key); - - /* When we're done, we free the reference to the key pair */ - noise_keypair_put(ctx->keypair); + chacha20poly1305_encrypt_sg(sg, sg, ctx->plaintext_len, NULL, 0, ctx->nonce, keypair->sending.key); } static inline bool skb_decrypt(struct sk_buff *skb, uint8_t num_frags, uint64_t nonce, struct noise_symmetric_key *key) @@ -140,23 +148,43 @@ static inline bool get_encryption_nonce(uint64_t *nonce, struct noise_symmetric_ return true; } +struct packet_bundle_ctx { + struct padata_priv padata; + struct sk_buff_head queue; + void (*callback)(struct sk_buff_head *, struct wireguard_peer *); + struct wireguard_peer *peer; + struct noise_keypair *keypair; +}; + +static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair) +{ + struct sk_buff *skb; + /* TODO: as a later optimization, we can activate the FPU just once + * for the entire loop, rather than turning it on and off for each + * packet. 
*/ + skb_queue_walk(queue, skb) { + skb_encrypt(skb, keypair); + skb_reset(skb); + } + noise_keypair_put(keypair); +} + #ifdef CONFIG_WIREGUARD_PARALLEL static void do_encryption(struct padata_priv *padata) { - struct packet_data_encryption_ctx *ctx = container_of(padata, struct packet_data_encryption_ctx, padata); - - skb_encrypt(ctx->skb, ctx); - skb_reset(ctx->skb); + struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata); + queue_encrypt_reset(&ctx->queue, ctx->keypair); padata_do_serial(padata); } static void finish_encryption(struct padata_priv *padata) { - struct packet_data_encryption_ctx *ctx = container_of(padata, struct packet_data_encryption_ctx, padata); + struct packet_bundle_ctx *ctx = container_of(padata, struct packet_bundle_ctx, padata); - ctx->callback(ctx->skb, ctx->peer); + ctx->callback(&ctx->queue, ctx->peer); peer_put(ctx->peer); + kfree(ctx); } static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) @@ -181,15 +209,11 @@ static inline unsigned int choose_cpu(__le32 key) } #endif -int packet_create_data(struct sk_buff *skb, struct wireguard_peer *peer, void(*callback)(struct sk_buff *, struct wireguard_peer *), bool parallel) +int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, void(*callback)(struct sk_buff_head *, struct wireguard_peer *)) { int ret = -ENOKEY; struct noise_keypair *keypair; - struct packet_data_encryption_ctx *ctx = NULL; - u64 nonce; - struct sk_buff *trailer = NULL; - unsigned int plaintext_len, padding_len, trailer_len; - unsigned int num_frags; + struct sk_buff *skb; rcu_read_lock(); keypair = noise_keypair_get(rcu_dereference(peer->keypairs.current_keypair)); @@ -197,60 +221,77 @@ int packet_create_data(struct sk_buff *skb, struct wireguard_peer *peer, void(*c goto err_rcu; rcu_read_unlock(); - if (unlikely(!get_encryption_nonce(&nonce, &keypair->sending))) - goto err; + skb_queue_walk(queue, skb) { + struct 
packet_data_encryption_ctx *ctx = (struct packet_data_encryption_ctx *)skb->cb; + unsigned int padding_len, num_frags; - padding_len = skb_padding(skb); - trailer_len = padding_len + noise_encrypted_len(0); - plaintext_len = skb->len + padding_len; + BUILD_BUG_ON(sizeof(struct packet_data_encryption_ctx) > sizeof(skb->cb)); - /* Expand data section to have room for padding and auth tag */ - ret = skb_cow_data(skb, trailer_len, &trailer); - if (unlikely(ret < 0)) - goto err; - num_frags = ret; - ret = -ENOMEM; - if (unlikely(num_frags > 128)) - goto err; + if (unlikely(!get_encryption_nonce(&ctx->nonce, &keypair->sending))) + goto err; - /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ - memset(skb_tail_pointer(trailer), 0, padding_len); + padding_len = skb_padding(skb); + ctx->trailer_len = padding_len + noise_encrypted_len(0); + ctx->plaintext_len = skb->len + padding_len; - /* Expand head section to have room for our header and the network stack's headers, - * plus our key and nonce in the head. */ - ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM); - if (unlikely(ret < 0)) - goto err; + /* Store the ds bit in the cb */ + ctx->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? 
*/, ip_hdr(skb), skb); - ctx = (struct packet_data_encryption_ctx *)skb->head; - ctx->skb = skb; - ctx->callback = callback; - ctx->peer = peer; - ctx->num_frags = num_frags; - ctx->trailer_len = trailer_len; - ctx->trailer = trailer; - ctx->plaintext_len = plaintext_len; - ctx->nonce = nonce; - ctx->keypair = keypair; + /* Expand data section to have room for padding and auth tag */ + ret = skb_cow_data(skb, ctx->trailer_len, &ctx->trailer); + if (unlikely(ret < 0)) + goto err; + num_frags = ret; + ret = -ENOMEM; + if (unlikely(num_frags > 128)) + goto err; + ctx->num_frags = num_frags; + + /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ + memset(skb_tail_pointer(ctx->trailer), 0, padding_len); + + /* Expand head section to have room for our header and the network stack's headers. */ + ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM); + if (unlikely(ret < 0)) + goto err; + + /* After the first time through the loop, if we've succeeded with a legitimate nonce, + * then we don't want a -ENOKEY error if subsequent nonces fail. Rather, if this + * condition arises, we simply want to error out hard, and drop the entire queue. This + * is partially lazy programming and TODO: this could be made to only requeue the + * ones that had no nonce. But I'm not sure it's worth the added complexity, given + * how rarely that condition should arise. 
*/ + ret = -EPIPE; + } #ifdef CONFIG_WIREGUARD_PARALLEL - if ((parallel || padata_queue_len(peer->device->parallel_send) > 0) && cpumask_weight(cpu_online_mask) > 1) { + if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || padata_queue_len(peer->device->parallel_send) > 0) && cpumask_weight(cpu_online_mask) > 1) { unsigned int cpu = choose_cpu(keypair->remote_index); - ret = -EBUSY; + struct packet_bundle_ctx *ctx = kmalloc(sizeof(struct packet_bundle_ctx), GFP_ATOMIC); + if (!ctx) + goto serial; + skb_queue_head_init(&ctx->queue); + skb_queue_splice_init(queue, &ctx->queue); + ctx->callback = callback; + ctx->keypair = keypair; ctx->peer = peer_rcu_get(peer); + ret = -EBUSY; if (unlikely(!ctx->peer)) - goto err; + goto err_parallel; ret = start_encryption(peer->device->parallel_send, &ctx->padata, cpu); if (unlikely(ret < 0)) { peer_put(ctx->peer); +err_parallel: + skb_queue_splice(&ctx->queue, queue); + kfree(ctx); goto err; } } else #endif { - skb_encrypt(skb, ctx); - skb_reset(skb); - callback(skb, peer); +serial: + queue_encrypt_reset(queue, keypair); + callback(queue, peer); } return 0; |