Diffstat (limited to 'src/data.c')
-rw-r--r--  src/data.c  477
1 file changed, 477 insertions(+), 0 deletions(-)
diff --git a/src/data.c b/src/data.c
new file mode 100644
index 0000000..5b3c781
--- /dev/null
+++ b/src/data.c
@@ -0,0 +1,477 @@
+/* Copyright 2015-2016 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+#include "wireguard.h"
+#include "noise.h"
+#include "messages.h"
+#include "packets.h"
+#include "hashtables.h"
+#include <crypto/algapi.h>
+#include <net/xfrm.h>
+#include <linux/rcupdate.h>
+#include <linux/slab.h>
+#include <linux/bitmap.h>
+#include <linux/scatterlist.h>
+
+/* This is the sliding-window replay-detection bitmap from appendix C of RFC 2401. */
+static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
+{
+ bool ret = false;
+ u64 difference;
+ spin_lock_bh(&counter->receive.lock);
+
+ if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES))
+ goto out;
+
+ ++their_counter;
+
+ if (likely(their_counter > counter->receive.counter)) {
+ difference = their_counter - counter->receive.counter;
+ if (likely(difference < BITS_PER_LONG)) {
+ counter->receive.backtrack <<= difference;
+ counter->receive.backtrack |= 1;
+ } else
+ counter->receive.backtrack = 1;
+ counter->receive.counter = their_counter;
+ ret = true;
+ goto out;
+ }
+
+ difference = counter->receive.counter - their_counter;
+ if (unlikely(difference >= BITS_PER_LONG))
+ goto out;
+ ret = !test_and_set_bit(difference, &counter->receive.backtrack);
+
+out:
+ spin_unlock_bh(&counter->receive.lock);
+ return ret;
+}
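+
+/* A worked example of the window (illustrative, assuming BITS_PER_LONG == 64):
+ * after counter 100 is accepted, bit 0 of backtrack stands for 100, bit 1 for
+ * 99, and so on. A replayed 100 finds bit 0 already set and is rejected; a
+ * late 98 is accepted exactly once via bit 2; and anything 64 or more behind
+ * the newest counter (36 or below here) is rejected outright. */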
+
+#ifdef DEBUG
+void packet_counter_selftest(void)
+{
+ bool success = true;
+ unsigned int i = 0;
+ union noise_counter counter = { { 0 } };
+ spin_lock_init(&counter.receive.lock);
+
+#define T(n, v) do { ++i; if (counter_validate(&counter, n) != v) { pr_info("nonce counter self-test %u: FAIL\n", i); success = false; } } while (0)
+ T(0, true);
+ T(1, true);
+ T(1, false);
+ T(9, true);
+ T(8, true);
+ T(7, true);
+ T(7, false);
+ T(BITS_PER_LONG, true);
+ T(BITS_PER_LONG - 1, true);
+ T(BITS_PER_LONG - 1, false);
+ T(BITS_PER_LONG - 2, true);
+ T(2, true);
+ T(2, false);
+ T(BITS_PER_LONG + 16, true);
+ T(3, false);
+ T(BITS_PER_LONG + 16, false);
+ T(BITS_PER_LONG * 4, true);
+ T(BITS_PER_LONG * 4 - (BITS_PER_LONG - 1), true);
+ T(10, false);
+ T(BITS_PER_LONG * 4 - BITS_PER_LONG, false);
+ T(BITS_PER_LONG * 4 - (BITS_PER_LONG + 1), false);
+ T(BITS_PER_LONG * 4 - (BITS_PER_LONG - 2), true);
+ T(BITS_PER_LONG * 4 + 1 - BITS_PER_LONG, false);
+ T(0, false);
+ T(REJECT_AFTER_MESSAGES, false);
+ T(REJECT_AFTER_MESSAGES - 1, true);
+ T(REJECT_AFTER_MESSAGES, false);
+ T(REJECT_AFTER_MESSAGES - 1, false);
+ T(REJECT_AFTER_MESSAGES - 2, true);
+ T(REJECT_AFTER_MESSAGES + 1, false);
+ T(REJECT_AFTER_MESSAGES + 2, false);
+ T(REJECT_AFTER_MESSAGES - 2, false);
+ T(REJECT_AFTER_MESSAGES - 3, true);
+ T(0, false);
+#undef T
+
+ if (success)
+ pr_info("nonce counter self-tests: pass\n");
+}
+#endif
+
+static inline size_t skb_padding(struct sk_buff *skb)
+{
+ /* We do this modulo business with the MTU, just in case the networking layer
+ * gives us a packet that's bigger than the MTU. Now that we support GSO, this
+ * shouldn't be a real problem, and this can likely be removed. But, caution! */
+ size_t last_unit = skb->len % skb->dev->mtu;
+ size_t padded_size = (last_unit + MESSAGE_PADDING_MULTIPLE - 1) & ~(MESSAGE_PADDING_MULTIPLE - 1);
+ if (padded_size > skb->dev->mtu)
+ padded_size = skb->dev->mtu;
+ return padded_size - last_unit;
+}
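+
+/* Illustrative arithmetic for the above, assuming MESSAGE_PADDING_MULTIPLE is
+ * 16 and a 1500-byte MTU: a 37-byte last unit rounds up to 48, so 11 bytes of
+ * padding are returned; a 1496-byte last unit would round up to 1504, but is
+ * clamped to the MTU, so only 4 bytes are returned. */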
+
+static inline void skb_reset(struct sk_buff *skb)
+{
+ skb_scrub_packet(skb, false);
+ memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start));
+ skb->queue_mapping = 0;
+ skb->nohdr = 0;
+ skb->peeked = 0;
+ skb->mac_len = 0;
+ skb->dev = NULL;
+ skb->hdr_len = skb_headroom(skb);
+ skb->mac_header = (typeof(skb->mac_header))~0U;
+ skb->transport_header = (typeof(skb->transport_header))~0U;
+ skb_reset_network_header(skb);
+}
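+
+/* skb_scrub_packet() plus the header-field wipe above give the skb a clean
+ * slate, so that a just-encrypted or just-decrypted packet re-enters the
+ * stack carrying no state from the path it arrived on. */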
+
+static inline void skb_encrypt(struct sk_buff *skb, struct packet_data_encryption_ctx *ctx)
+{
+ struct scatterlist sg[ctx->num_frags]; /* This should be bounded to at most 128 by the caller. */
+ struct message_data *header;
+
+ /* We have to remember to add the checksum to the inner packet, in case the receiver forwards it. */
+ if (likely(!skb_checksum_setup(skb, true)))
+ skb_checksum_help(skb);
+
+ /* Only after checksumming can we safely add on the padding at the end and the header. */
+ header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
+ header->header.type = MESSAGE_DATA;
+ header->key_idx = ctx->keypair->remote_index;
+ header->counter = cpu_to_le64(ctx->nonce);
+ pskb_put(skb, ctx->trailer, ctx->trailer_len);
+
+ /* Now we can encrypt the scatter-gather segments */
+ sg_init_table(sg, ctx->num_frags);
+ skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(ctx->plaintext_len));
+ chacha20poly1305_encrypt_sg(sg, sg, ctx->plaintext_len, NULL, 0, ctx->nonce, ctx->keypair->sending.key);
+
+ /* When we're done, we free the reference to the key pair */
+ noise_keypair_put(ctx->keypair);
+}
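+
+/* For reference, the resulting wire format of a data message is roughly:
+ *
+ *   [message_data header: type, key_idx, little-endian counter]
+ *   [ChaCha20Poly1305 ciphertext of the plaintext plus zero padding]
+ *   [16-byte Poly1305 authentication tag]
+ *
+ * which is why noise_encrypted_len() adds the tag length to the plaintext length. */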
+
+static inline bool skb_decrypt(struct sk_buff *skb, unsigned int num_frags, uint64_t nonce, struct noise_symmetric_key *key)
+{
+ struct scatterlist sg[num_frags]; /* This should be bounded to at most 128 by the caller. */
+
+ if (unlikely(!key))
+ return false;
+
+ if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
+ key->is_valid = false;
+ return false;
+ }
+
+ sg_init_table(sg, num_frags);
+ skb_to_sgvec(skb, sg, 0, skb->len);
+
+ if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, nonce, key->key))
+ return false;
+
+ return pskb_trim(skb, skb->len - noise_encrypted_len(0)) == 0;
+}
+
+static inline bool get_encryption_nonce(uint64_t *nonce, struct noise_symmetric_key *key)
+{
+ if (unlikely(!key))
+ return false;
+
+ if (unlikely(!key->is_valid || time_is_before_eq_jiffies64(key->birthdate + REJECT_AFTER_TIME))) {
+ key->is_valid = false;
+ return false;
+ }
+
+ *nonce = atomic64_inc_return(&key->counter.counter) - 1;
+ if (*nonce >= REJECT_AFTER_MESSAGES) {
+ key->is_valid = false;
+ return false;
+ }
+
+ return true;
+}
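+
+/* Note that atomic64_inc_return() hands out each nonce exactly once, even with
+ * concurrent senders, so a given key never reuses a nonce; and once
+ * REJECT_AFTER_MESSAGES nonces have been consumed, the key is invalidated,
+ * forcing a fresh handshake. */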
+
+#ifdef CONFIG_WIREGUARD_PARALLEL
+static void do_encryption(struct padata_priv *padata)
+{
+ struct packet_data_encryption_ctx *ctx = container_of(padata, struct packet_data_encryption_ctx, padata);
+
+ skb_encrypt(ctx->skb, ctx);
+ skb_reset(ctx->skb);
+
+ padata_do_serial(padata);
+}
+
+static void finish_encryption(struct padata_priv *padata)
+{
+ struct packet_data_encryption_ctx *ctx = container_of(padata, struct packet_data_encryption_ctx, padata);
+
+ ctx->callback(ctx->skb, ctx->peer);
+}
+
+static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
+{
+ memset(priv, 0, sizeof(struct padata_priv));
+ priv->parallel = do_encryption;
+ priv->serial = finish_encryption;
+ return padata_do_parallel(padata, priv, cb_cpu);
+}
+
+static inline unsigned int choose_cpu(__le32 key)
+{
+ unsigned int cpu_index, cpu, cb_cpu;
+
+ /* This ensures that packets encrypted to the same key are sent in-order. */
+ cpu_index = ((__force unsigned int)key) % cpumask_weight(cpu_online_mask);
+ cb_cpu = cpumask_first(cpu_online_mask);
+ for (cpu = 0; cpu < cpu_index; ++cpu)
+ cb_cpu = cpumask_next(cb_cpu, cpu_online_mask);
+
+ return cb_cpu;
+}
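+
+/* For example (illustrative): with four CPUs online and a key index of 10,
+ * cpu_index is 10 % 4 == 2, so the third online CPU is always chosen for that
+ * index. Pinning an index to a single callback CPU is what keeps packets for
+ * one keypair in-order through the parallel workers. */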
+#endif
+
+int packet_create_data(struct sk_buff *skb, struct wireguard_peer *peer, void (*callback)(struct sk_buff *, struct wireguard_peer *), bool parallel)
+{
+ int ret = -ENOKEY;
+ struct noise_keypair *keypair;
+ struct packet_data_encryption_ctx *ctx = NULL;
+ u64 nonce;
+ struct sk_buff *trailer = NULL;
+ size_t plaintext_len, padding_len, trailer_len;
+ unsigned int num_frags;
+
+ rcu_read_lock();
+ keypair = rcu_dereference(peer->keypairs.current_keypair);
+ if (unlikely(!keypair))
+ goto err_rcu;
+ kref_get(&keypair->refcount);
+ rcu_read_unlock();
+
+ if (unlikely(!get_encryption_nonce(&nonce, &keypair->sending)))
+ goto err;
+
+ padding_len = skb_padding(skb);
+ trailer_len = padding_len + noise_encrypted_len(0);
+ plaintext_len = skb->len + padding_len;
+
+ /* Expand data section to have room for padding and auth tag */
+ ret = skb_cow_data(skb, trailer_len, &trailer);
+ if (unlikely(ret < 0))
+ goto err;
+ num_frags = ret;
+ ret = -ENOMEM;
+ if (unlikely(num_frags > 128))
+ goto err;
+
+ /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
+ memset(skb_tail_pointer(trailer), 0, padding_len);
+
+ /* Expand head section to have room for our header and the network stack's headers,
+ * plus our key and nonce in the head. */
+ ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM);
+ if (unlikely(ret < 0))
+ goto err;
+
+ ctx = (struct packet_data_encryption_ctx *)skb->head;
+ ctx->skb = skb;
+ ctx->callback = callback;
+ ctx->peer = peer;
+ ctx->num_frags = num_frags;
+ ctx->trailer_len = trailer_len;
+ ctx->trailer = trailer;
+ ctx->plaintext_len = plaintext_len;
+ ctx->nonce = nonce;
+ ctx->keypair = keypair;
+
+#ifdef CONFIG_WIREGUARD_PARALLEL
+ if (parallel && cpumask_weight(cpu_online_mask) > 1) {
+ unsigned int cpu = choose_cpu(keypair->remote_index);
+ ret = start_encryption(peer->device->parallel_send, &ctx->padata, cpu);
+ if (unlikely(ret < 0))
+ goto err;
+ } else
+#endif
+ {
+ skb_encrypt(skb, ctx);
+ skb_reset(skb);
+ callback(skb, peer);
+ }
+ return 0;
+
+err:
+ noise_keypair_put(keypair);
+ return ret;
+err_rcu:
+ rcu_read_unlock();
+ return ret;
+}
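+
+/* A minimal usage sketch (hypothetical caller; send_off is an assumed name,
+ * not part of this file):
+ *
+ *   static void send_off(struct sk_buff *skb, struct wireguard_peer *peer)
+ *   {
+ *           // transmit the now-encrypted skb to the peer's endpoint
+ *   }
+ *
+ *   if (packet_create_data(skb, peer, send_off, true) < 0)
+ *           kfree_skb(skb);
+ */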
+
+struct packet_data_decryption_ctx {
+ struct padata_priv padata;
+ struct sk_buff *skb;
+ void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err);
+ struct noise_keypair *keypair;
+ struct sockaddr_storage addr;
+ uint64_t nonce;
+ unsigned int num_frags;
+ int ret;
+};
+
+static void begin_decrypt_packet(struct packet_data_decryption_ctx *ctx)
+{
+ if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving)))
+ goto err;
+
+ skb_reset(ctx->skb);
+ ctx->ret = 0;
+ return;
+
+err:
+ ctx->ret = -ENOKEY;
+ peer_put(ctx->keypair->entry.peer);
+}
+
+static void finish_decrypt_packet(struct packet_data_decryption_ctx *ctx)
+{
+ struct noise_keypairs *keypairs;
+ struct wireguard_peer *peer;
+ bool used_new_key = false;
+ int ret = ctx->ret;
+ if (ret)
+ goto err;
+
+ peer = ctx->keypair->entry.peer;
+ keypairs = &peer->keypairs;
+ ret = counter_validate(&ctx->keypair->receiving.counter, ctx->nonce) ? 0 : -ERANGE;
+
+ if (likely(!ret))
+ used_new_key = noise_received_with_keypair(keypairs, ctx->keypair);
+ else {
+ /* TODO: currently either the nonce window is not big enough, or we're sending things in
+ * the wrong order. Try uncommenting the below code to see for yourself. This is a problem
+ * that needs to be solved.
+ *
+ * Debug with:
+ * #define XSTR(s) STR(s)
+ * #define STR(s) #s
+ * net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu, backtrack %" XSTR(BITS_PER_LONG) "pbl)\n", ctx->nonce, ctx->keypair->receiving.counter.receive.counter, &ctx->keypair->receiving.counter.receive.backtrack);
+ */
+ peer_put(peer);
+ goto err;
+ }
+
+ /* The keypair reference must be dropped before the callback runs, but the put may
+ * free the keypair, so hand the callback the peer pointer saved above. */
+ noise_keypair_put(ctx->keypair);
+ ctx->callback(ctx->skb, peer, &ctx->addr, used_new_key, 0);
+ return;
+
+err:
+ noise_keypair_put(ctx->keypair);
+ ctx->callback(ctx->skb, NULL, NULL, false, ret);
+}
+
+#ifdef CONFIG_WIREGUARD_PARALLEL
+static void do_decryption(struct padata_priv *padata)
+{
+ struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata);
+ begin_decrypt_packet(ctx);
+ padata_do_serial(padata);
+}
+
+static void finish_decryption(struct padata_priv *padata)
+{
+ struct packet_data_decryption_ctx *ctx = container_of(padata, struct packet_data_decryption_ctx, padata);
+ finish_decrypt_packet(ctx);
+ kfree(ctx);
+}
+
+static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
+{
+ priv->parallel = do_decryption;
+ priv->serial = finish_decryption;
+ return padata_do_parallel(padata, priv, cb_cpu);
+}
+#endif
+
+void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, void (*callback)(struct sk_buff *skb, struct wireguard_peer *, struct sockaddr_storage *, bool used_new_key, int err))
+{
+ int ret;
+ struct sockaddr_storage addr = { 0 };
+ unsigned int num_frags;
+ struct sk_buff *trailer;
+ struct message_data *header;
+ struct noise_keypair *keypair;
+ uint64_t nonce;
+ __le32 idx;
+
+ ret = socket_addr_from_skb(&addr, skb);
+ if (unlikely(ret < 0))
+ goto err;
+
+ ret = -ENOMEM;
+ if (unlikely(!pskb_may_pull(skb, offset + sizeof(struct message_data))))
+ goto err;
+
+ header = (struct message_data *)(skb->data + offset);
+ offset += sizeof(struct message_data);
+ skb_pull(skb, offset);
+
+ idx = header->key_idx;
+ nonce = le64_to_cpu(header->counter);
+
+ ret = skb_cow_data(skb, 0, &trailer);
+ if (unlikely(ret < 0))
+ goto err;
+ num_frags = ret;
+ ret = -ENOMEM;
+ if (unlikely(num_frags > 128))
+ goto err;
+ ret = -EINVAL;
+ rcu_read_lock();
+ keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx);
+ if (unlikely(!keypair)) {
+ rcu_read_unlock();
+ goto err;
+ }
+ kref_get(&keypair->refcount);
+ rcu_read_unlock();
+#ifdef CONFIG_WIREGUARD_PARALLEL
+ if (cpumask_weight(cpu_online_mask) > 1) {
+ struct packet_data_decryption_ctx *ctx;
+ unsigned int cpu = choose_cpu(idx);
+
+ ret = -ENOMEM;
+ ctx = kzalloc(sizeof(struct packet_data_decryption_ctx), GFP_ATOMIC);
+ if (unlikely(!ctx))
+ goto err_peer;
+
+ ctx->skb = skb;
+ ctx->keypair = keypair;
+ ctx->callback = callback;
+ ctx->nonce = nonce;
+ ctx->num_frags = num_frags;
+ ctx->addr = addr;
+ ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu);
+ if (unlikely(ret)) {
+ kfree(ctx);
+ goto err_peer;
+ }
+ } else
+#endif
+ {
+ struct packet_data_decryption_ctx ctx = {
+ .skb = skb,
+ .keypair = keypair,
+ .callback = callback,
+ .nonce = nonce,
+ .num_frags = num_frags,
+ .addr = addr
+ };
+ begin_decrypt_packet(&ctx);
+ finish_decrypt_packet(&ctx);
+ }
+ return;
+
+#ifdef CONFIG_WIREGUARD_PARALLEL
+err_peer:
+ peer_put(keypair->entry.peer);
+ noise_keypair_put(keypair);
+#endif
+err:
+ callback(skb, NULL, NULL, false, ret);
+}