summaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-03-15 19:20:58 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2017-03-20 01:02:06 +0100
commite16ba338168e00b0c2702ec5529280301c514d67 (patch)
tree8e92e5f428f263cee69f42a922047c234984f03a
parentcurve25519: add AVX implementation (diff)
downloadwireguard-monolithic-historical-e16ba338168e00b0c2702ec5529280301c514d67.tar.xz
wireguard-monolithic-historical-e16ba338168e00b0c2702ec5529280301c514d67.zip
data: big refactoring
-rw-r--r--src/cookie.c12
-rw-r--r--src/cookie.h4
-rw-r--r--src/data.c167
-rw-r--r--src/messages.h22
-rw-r--r--src/packets.h10
-rw-r--r--src/ratelimiter.c2
-rw-r--r--src/receive.c91
-rw-r--r--src/send.c9
8 files changed, 158 insertions, 159 deletions
diff --git a/src/cookie.c b/src/cookie.c
index 1c188c6..66f5d45 100644
--- a/src/cookie.c
+++ b/src/cookie.c
@@ -103,12 +103,12 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cooki
up_read(&checker->secret_lock);
}
-enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, void *data_start, size_t data_len, bool check_cookie)
+enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie)
{
u8 computed_mac[COOKIE_LEN];
u8 cookie[COOKIE_LEN];
enum cookie_mac_state ret;
- struct message_macs *macs = (struct message_macs *)((u8 *)data_start + data_len - sizeof(struct message_macs));
+ struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs));
ret = INVALID_MAC;
down_read(&checker->device->static_identity.lock);
@@ -116,7 +116,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str
up_read(&checker->device->static_identity.lock);
goto out;
}
- compute_mac1(computed_mac, data_start, data_len, checker->device->static_identity.static_public, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL);
+ compute_mac1(computed_mac, skb->data, skb->len, checker->device->static_identity.static_public, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL);
up_read(&checker->device->static_identity.lock);
if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN))
goto out;
@@ -128,7 +128,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str
make_cookie(cookie, skb, checker);
- compute_mac2(computed_mac, data_start, data_len, cookie);
+ compute_mac2(computed_mac, skb->data, skb->len, cookie);
if (crypto_memneq(computed_mac, macs->mac2, COOKIE_LEN))
goto out;
@@ -168,9 +168,9 @@ void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *
up_read(&peer->latest_cookie.lock);
}
-void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, void *data_start, size_t data_len, __le32 index, struct cookie_checker *checker)
+void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, __le32 index, struct cookie_checker *checker)
{
- struct message_macs *macs = (struct message_macs *)((u8 *)data_start + data_len - sizeof(struct message_macs));
+ struct message_macs *macs = (struct message_macs *)((u8 *)skb->data + skb->len - sizeof(struct message_macs));
u8 cookie[COOKIE_LEN];
dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE);
diff --git a/src/cookie.h b/src/cookie.h
index e1c8d8e..87a0e5a 100644
--- a/src/cookie.h
+++ b/src/cookie.h
@@ -42,10 +42,10 @@ void cookie_checker_uninit(struct cookie_checker *checker);
void cookie_checker_precompute_keys(struct cookie_checker *checker, struct wireguard_peer *peer);
void cookie_init(struct cookie *cookie);
-enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, void *data_start, size_t data_len, bool check_cookie);
+enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie);
void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer);
-void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, void *data_start, size_t data_len, __le32 index, struct cookie_checker *checker);
+void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, __le32 index, struct cookie_checker *checker);
void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg);
#endif
diff --git a/src/data.c b/src/data.c
index e91b150..dcbbd10 100644
--- a/src/data.c
+++ b/src/data.c
@@ -15,14 +15,6 @@
#include <net/xfrm.h>
#include <crypto/algapi.h>
-struct encryption_skb_cb {
- u8 ds;
- u8 num_frags;
- unsigned int plaintext_len, trailer_len;
- struct sk_buff *trailer;
- u64 nonce;
-};
-
struct encryption_ctx {
struct padata_priv padata;
struct sk_buff_head queue;
@@ -33,13 +25,11 @@ struct encryption_ctx {
struct decryption_ctx {
struct padata_priv padata;
+ struct endpoint endpoint;
struct sk_buff *skb;
packet_consume_data_callback_t callback;
struct noise_keypair *keypair;
- struct endpoint endpoint;
- u64 nonce;
int ret;
- u8 num_frags;
};
#ifdef CONFIG_WIREGUARD_PARALLEL
@@ -48,7 +38,6 @@ static struct kmem_cache *decryption_ctx_cache __read_mostly;
int packet_init_data_caches(void)
{
- BUILD_BUG_ON(sizeof(struct encryption_skb_cb) > sizeof(((struct sk_buff *)0)->cb));
encryption_ctx_cache = kmem_cache_create("wireguard_encryption_ctx", sizeof(struct encryption_ctx), 0, 0, NULL);
if (!encryption_ctx_cache)
return -ENOMEM;
@@ -130,13 +119,36 @@ static inline void skb_reset(struct sk_buff *skb)
skb_reset_mac_header(skb);
skb_reset_network_header(skb);
skb_probe_transport_header(skb, 0);
+ skb_reset_inner_headers(skb);
}
-static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
+static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
{
- struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
- struct scatterlist sg[cb->num_frags]; /* This should be bound to at most 128 by the caller. */
+ struct scatterlist *sg;
struct message_data *header;
+ unsigned int padding_len, plaintext_len, trailer_len;
+ int num_frags;
+ struct sk_buff *trailer;
+
+ /* Store the ds bit in the cb */
+ PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
+
+ /* Calculate lengths */
+ padding_len = skb_padding(skb);
+ trailer_len = padding_len + noise_encrypted_len(0);
+ plaintext_len = skb->len + padding_len;
+
+ /* Expand data section to have room for padding and auth tag */
+ num_frags = skb_cow_data(skb, trailer_len, &trailer);
+ if (unlikely(num_frags < 0 || num_frags > 128))
+ return false;
+
+ /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
+ memset(skb_tail_pointer(trailer), 0, padding_len);
+
+ /* Expand head section to have room for our header and the network stack's headers. */
+ if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
+ return false;
/* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */
if (likely(!skb_checksum_setup(skb, true)))
@@ -146,18 +158,23 @@ static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai
header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
header->header.type = cpu_to_le32(MESSAGE_DATA);
header->key_idx = keypair->remote_index;
- header->counter = cpu_to_le64(cb->nonce);
- pskb_put(skb, cb->trailer, cb->trailer_len);
+ header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
+ pskb_put(skb, trailer, trailer_len);
/* Now we can encrypt the scattergather segments */
- sg_init_table(sg, cb->num_frags);
- skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(cb->plaintext_len));
- chacha20poly1305_encrypt_sg(sg, sg, cb->plaintext_len, NULL, 0, cb->nonce, keypair->sending.key, have_simd);
+ sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */
+ sg_init_table(sg, num_frags);
+ skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len));
+ chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, have_simd);
+
+ return true;
}
-static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, struct noise_symmetric_key *key)
+static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key)
{
- struct scatterlist sg[num_frags]; /* This should be bound to at most 128 by the caller. */
+ struct scatterlist *sg;
+ struct sk_buff *trailer;
+ int num_frags;
if (unlikely(!key))
return false;
@@ -167,10 +184,17 @@ static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, str
return false;
}
+ PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
+ skb_pull(skb, sizeof(struct message_data));
+ num_frags = skb_cow_data(skb, 0, &trailer);
+ if (unlikely(num_frags < 0 || num_frags > 128))
+ return false;
+ sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */
+
sg_init_table(sg, num_frags);
skb_to_sgvec(skb, sg, 0, skb->len);
- if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, nonce, key->key))
+ if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
return false;
return pskb_trim(skb, skb->len - noise_encrypted_len(0)) == 0;
@@ -197,10 +221,14 @@ static inline bool get_encryption_nonce(u64 *nonce, struct noise_symmetric_key *
static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair)
{
- struct sk_buff *skb;
+ struct sk_buff *skb, *tmp;
bool have_simd = chacha20poly1305_init_simd();
- skb_queue_walk(queue, skb) {
- skb_encrypt(skb, keypair, have_simd);
+ skb_queue_walk_safe(queue, skb, tmp) {
+ if (unlikely(!skb_encrypt(skb, keypair, have_simd))) {
+ skb_unlink(skb, queue);
+ kfree_skb(skb);
+ continue;
+ }
skb_reset(skb);
}
chacha20poly1305_deinit_simd(have_simd);
@@ -261,35 +289,7 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
rcu_read_unlock();
skb_queue_walk(queue, skb) {
- struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
- unsigned int padding_len, num_frags;
-
- if (unlikely(!get_encryption_nonce(&cb->nonce, &keypair->sending)))
- goto err;
-
- padding_len = skb_padding(skb);
- cb->trailer_len = padding_len + noise_encrypted_len(0);
- cb->plaintext_len = skb->len + padding_len;
-
- /* Store the ds bit in the cb */
- cb->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
-
- /* Expand data section to have room for padding and auth tag */
- ret = skb_cow_data(skb, cb->trailer_len, &cb->trailer);
- if (unlikely(ret < 0))
- goto err;
- num_frags = ret;
- ret = -ENOMEM;
- if (unlikely(num_frags > 128))
- goto err;
- cb->num_frags = num_frags;
-
- /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
- memset(skb_tail_pointer(cb->trailer), 0, padding_len);
-
- /* Expand head section to have room for our header and the network stack's headers. */
- ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM);
- if (unlikely(ret < 0))
+ if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending)))
goto err;
/* After the first time through the loop, if we've suceeded with a legitimate nonce,
@@ -344,15 +344,18 @@ err_rcu:
static void begin_decrypt_packet(struct decryption_ctx *ctx)
{
- if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving)))
+ ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb);
+ if (unlikely(ctx->ret < 0))
+ goto err;
+
+ ctx->ret = -ENOKEY;
+ if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving)))
goto err;
- skb_reset(ctx->skb);
ctx->ret = 0;
return;
err:
- ctx->ret = -ENOKEY;
peer_put(ctx->keypair->entry.peer);
}
@@ -360,22 +363,25 @@ static void finish_decrypt_packet(struct decryption_ctx *ctx)
{
struct noise_keypairs *keypairs;
bool used_new_key = false;
+ u64 nonce = PACKET_CB(ctx->skb)->nonce;
int ret = ctx->ret;
if (ret)
goto err;
keypairs = &ctx->keypair->entry.peer->keypairs;
- ret = counter_validate(&ctx->keypair->receiving.counter, ctx->nonce) ? 0 : -ERANGE;
+ ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE;
if (likely(!ret))
used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
else {
- net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", ctx->nonce, ctx->keypair->receiving.counter.receive.counter);
+ net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter);
peer_put(ctx->keypair->entry.peer);
goto err;
}
noise_keypair_put(ctx->keypair);
+
+ skb_reset(ctx->skb);
ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0);
return;
@@ -401,51 +407,26 @@ static void finish_decryption(struct padata_priv *padata)
static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
{
+ memset(priv, 0, sizeof(struct padata_priv));
priv->parallel = do_decryption;
priv->serial = finish_decryption;
return padata_do_parallel(padata, priv, cb_cpu);
}
#endif
-void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, packet_consume_data_callback_t callback)
+void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback)
{
int ret;
- struct endpoint endpoint;
- unsigned int num_frags;
- struct sk_buff *trailer;
- struct message_data *header;
struct noise_keypair *keypair;
- u64 nonce;
- __le32 idx;
+ __le32 idx = ((struct message_data *)skb->data)->key_idx;
- ret = socket_endpoint_from_skb(&endpoint, skb);
- if (unlikely(ret < 0))
- goto err;
-
- ret = -ENOMEM;
- if (unlikely(!pskb_may_pull(skb, offset + sizeof(struct message_data))))
- goto err;
-
- header = (struct message_data *)(skb->data + offset);
- offset += sizeof(struct message_data);
- skb_pull(skb, offset);
-
- idx = header->key_idx;
- nonce = le64_to_cpu(header->counter);
-
- ret = skb_cow_data(skb, 0, &trailer);
- if (unlikely(ret < 0))
- goto err;
- num_frags = ret;
- ret = -ENOMEM;
- if (unlikely(num_frags > 128))
- goto err;
ret = -EINVAL;
rcu_read_lock();
keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
rcu_read_unlock();
if (unlikely(!keypair))
goto err;
+
#ifdef CONFIG_WIREGUARD_PARALLEL
if (cpumask_weight(cpu_online_mask) > 1) {
unsigned int cpu = choose_cpu(idx);
@@ -459,9 +440,6 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
ctx->skb = skb;
ctx->keypair = keypair;
ctx->callback = callback;
- ctx->nonce = nonce;
- ctx->num_frags = num_frags;
- ctx->endpoint = endpoint;
ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu);
if (unlikely(ret)) {
kmem_cache_free(decryption_ctx_cache, ctx);
@@ -473,10 +451,7 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
struct decryption_ctx ctx = {
.skb = skb,
.keypair = keypair,
- .callback = callback,
- .nonce = nonce,
- .num_frags = num_frags,
- .endpoint = endpoint
+ .callback = callback
};
begin_decrypt_packet(&ctx);
finish_decrypt_packet(&ctx);
diff --git a/src/messages.h b/src/messages.h
index 7dc09aa..defc831 100644
--- a/src/messages.h
+++ b/src/messages.h
@@ -13,6 +13,7 @@
#include <linux/kernel.h>
#include <linux/param.h>
+#include <linux/skbuff.h>
enum noise_lengths {
NOISE_PUBLIC_KEY_LEN = CURVE25519_POINT_SIZE,
@@ -124,18 +125,25 @@ enum {
HANDSHAKE_DSCP = 0b10001000 /* AF41, plus 00 ECN */
};
-static inline enum message_type message_determine_type(void *src, size_t src_len)
+static const unsigned int message_header_sizes[MESSAGE_TOTAL] = {
+ [MESSAGE_HANDSHAKE_INITIATION] = sizeof(struct message_handshake_initiation),
+ [MESSAGE_HANDSHAKE_RESPONSE] = sizeof(struct message_handshake_response),
+ [MESSAGE_HANDSHAKE_COOKIE] = sizeof(struct message_handshake_cookie),
+ [MESSAGE_DATA] = sizeof(struct message_data)
+};
+
+static inline enum message_type message_determine_type(struct sk_buff *skb)
{
- struct message_header *header = src;
- if (unlikely(src_len < sizeof(struct message_header)))
+ struct message_header *header = (struct message_header *)skb->data;
+ if (unlikely(skb->len < sizeof(struct message_header)))
return MESSAGE_INVALID;
- if (header->type == cpu_to_le32(MESSAGE_DATA) && src_len >= MESSAGE_MINIMUM_LENGTH)
+ if (header->type == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH)
return MESSAGE_DATA;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && src_len == sizeof(struct message_handshake_initiation))
+ if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation))
return MESSAGE_HANDSHAKE_INITIATION;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && src_len == sizeof(struct message_handshake_response))
+ if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response))
return MESSAGE_HANDSHAKE_RESPONSE;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && src_len == sizeof(struct message_handshake_cookie))
+ if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie))
return MESSAGE_HANDSHAKE_COOKIE;
return MESSAGE_INVALID;
}
diff --git a/src/packets.h b/src/packets.h
index 6530048..a640847 100644
--- a/src/packets.h
+++ b/src/packets.h
@@ -14,6 +14,12 @@ struct wireguard_device;
struct wireguard_peer;
struct sk_buff;
+struct packet_cb {
+ u64 nonce;
+ u8 ds;
+};
+#define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
+
/* receive.c */
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb);
void packet_process_queued_handshake_packets(struct work_struct *work);
@@ -24,13 +30,13 @@ void packet_send_keepalive(struct wireguard_peer *peer);
void packet_queue_handshake_initiation(struct wireguard_peer *peer);
void packet_send_queued_handshakes(struct work_struct *work);
void packet_send_handshake_response(struct wireguard_peer *peer);
-void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, void *data, size_t data_len, __le32 sender_index);
+void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index);
/* data.c */
typedef void (*packet_create_data_callback_t)(struct sk_buff_head *, struct wireguard_peer *);
typedef void (*packet_consume_data_callback_t)(struct sk_buff *skb, struct wireguard_peer *, struct endpoint *, bool used_new_key, int err);
int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback);
-void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, packet_consume_data_callback_t callback);
+void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback);
#ifdef CONFIG_WIREGUARD_PARALLEL
int packet_init_data_caches(void);
diff --git a/src/ratelimiter.c b/src/ratelimiter.c
index 12282fd..ab8f93d 100644
--- a/src/ratelimiter.c
+++ b/src/ratelimiter.c
@@ -25,7 +25,7 @@ static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family)
cfg->srcmask = 32;
else if (family == NFPROTO_IPV6)
cfg->srcmask = 96;
- cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT */
+ cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT, but we don't really want to do that. It would also cause problems since we skb_pull early on, and hashlimit's nexthdr stuff isn't so nice. */
cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND; /* 30 per second per IP */
cfg->burst = RATELIMITER_PACKETS_BURSTABLE; /* Allow bursts of 5 at a time */
cfg->gc_interval = 1000; /* same as expiration date */
diff --git a/src/receive.c b/src/receive.c
index 5707ab2..f791a2e 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -30,9 +30,11 @@ static inline void update_latest_addr(struct wireguard_peer *peer, struct sk_buf
socket_set_peer_endpoint(peer, &endpoint);
}
-static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size_t *data_len)
+static inline int skb_prepare_header(struct sk_buff *skb)
{
struct udphdr *udp;
+ size_t data_offset, data_len;
+ enum message_type message_type;
if (unlikely(skb->len < sizeof(struct iphdr)))
return -EINVAL;
@@ -42,35 +44,50 @@ static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size
return -EINVAL;
udp = udp_hdr(skb);
- *data_offset = (u8 *)udp - skb->data;
- if (unlikely(*data_offset > U16_MAX)) {
+ data_offset = (u8 *)udp - skb->data;
+ if (unlikely(data_offset > U16_MAX)) {
net_dbg_skb_ratelimited("Packet has offset at impossible location from %pISpfsc\n", skb);
return -EINVAL;
}
- if (unlikely(*data_offset + sizeof(struct udphdr) > skb->len)) {
+ if (unlikely(data_offset + sizeof(struct udphdr) > skb->len)) {
net_dbg_skb_ratelimited("Packet isn't big enough to have UDP fields from %pISpfsc\n", skb);
return -EINVAL;
}
- *data_len = ntohs(udp->len);
- if (unlikely(*data_len < sizeof(struct udphdr))) {
+ data_len = ntohs(udp->len);
+ if (unlikely(data_len < sizeof(struct udphdr))) {
net_dbg_skb_ratelimited("UDP packet is reporting too small of a size from %pISpfsc\n", skb);
return -EINVAL;
}
- if (unlikely(*data_len > skb->len - *data_offset)) {
+ if (unlikely(data_len > skb->len - data_offset)) {
net_dbg_skb_ratelimited("UDP packet is lying about its size from %pISpfsc\n", skb);
return -EINVAL;
}
- *data_len -= sizeof(struct udphdr);
- *data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
- if (!pskb_may_pull(skb, *data_offset + sizeof(struct message_header))) {
+ data_len -= sizeof(struct udphdr);
+ data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
+ if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)))) {
net_dbg_skb_ratelimited("Could not pull header into data section from %pISpfsc\n", skb);
return -EINVAL;
}
-
- return 0;
+ if (pskb_trim(skb, data_len + data_offset) < 0) {
+ net_dbg_skb_ratelimited("Could not trim packet from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ skb_pull(skb, data_offset);
+ if (unlikely(skb->len != data_len)) {
+ net_dbg_skb_ratelimited("Final len does not agree with calculated len from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ message_type = message_determine_type(skb);
+ __skb_push(skb, data_offset);
+ if (unlikely(!pskb_may_pull(skb, data_offset + message_header_sizes[message_type]))) {
+ net_dbg_skb_ratelimited("Could not pull full header into data section from %pISpfsc\n", skb);
+ return -EINVAL;
+ }
+ __skb_pull(skb, data_offset);
+ return message_type;
}
-static void receive_handshake_packet(struct wireguard_device *wg, void *data, size_t len, struct sk_buff *skb)
+static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb)
{
struct wireguard_peer *peer = NULL;
enum message_type message_type;
@@ -78,16 +95,16 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
enum cookie_mac_state mac_state;
bool packet_needs_cookie;
- message_type = message_determine_type(data, len);
+ message_type = message_determine_type(skb);
if (message_type == MESSAGE_HANDSHAKE_COOKIE) {
net_dbg_skb_ratelimited("Receiving cookie response from %pISpfsc\n", skb);
- cookie_message_consume(data, wg);
+ cookie_message_consume((struct message_handshake_cookie *)skb->data, wg);
return;
}
under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 2;
- mac_state = cookie_validate_packet(&wg->cookie_checker, skb, data, len, under_load);
+ mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load);
if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
packet_needs_cookie = false;
else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
@@ -98,13 +115,13 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
}
switch (message_type) {
- case MESSAGE_HANDSHAKE_INITIATION:
+ case MESSAGE_HANDSHAKE_INITIATION: {
+ struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data;
if (packet_needs_cookie) {
- struct message_handshake_initiation *message = data;
- packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index);
+ packet_send_handshake_cookie(wg, skb, message->sender_index);
return;
}
- peer = noise_handshake_consume_initiation(data, wg);
+ peer = noise_handshake_consume_initiation(message, wg);
if (unlikely(!peer)) {
net_dbg_skb_ratelimited("Invalid handshake initiation from %pISpfsc\n", skb);
return;
@@ -113,13 +130,14 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
net_dbg_ratelimited("Receiving handshake initiation from peer %Lu (%pISpfsc)\n", peer->internal_id, &peer->endpoint.addr);
packet_send_handshake_response(peer);
break;
- case MESSAGE_HANDSHAKE_RESPONSE:
+ }
+ case MESSAGE_HANDSHAKE_RESPONSE: {
+ struct message_handshake_response *message = (struct message_handshake_response *)skb->data;
if (packet_needs_cookie) {
- struct message_handshake_response *message = data;
- packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index);
+ packet_send_handshake_cookie(wg, skb, message->sender_index);
return;
}
- peer = noise_handshake_consume_response(data, wg);
+ peer = noise_handshake_consume_response(message, wg);
if (unlikely(!peer)) {
net_dbg_skb_ratelimited("Invalid handshake response from %pISpfsc\n", skb);
return;
@@ -137,6 +155,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
packet_send_keepalive(peer);
}
break;
+ }
default:
WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n");
return;
@@ -144,7 +163,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si
BUG_ON(!peer);
- rx_stats(peer, len);
+ rx_stats(peer, skb->len);
timers_any_authenticated_packet_received(peer);
timers_any_authenticated_packet_traversal(peer);
peer_put(peer);
@@ -154,12 +173,10 @@ void packet_process_queued_handshake_packets(struct work_struct *work)
{
struct wireguard_device *wg = container_of(work, struct wireguard_device, incoming_handshakes_work);
struct sk_buff *skb;
- size_t len, offset;
size_t num_processed = 0;
while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
- if (!skb_data_offset(skb, &offset, &len))
- receive_handshake_packet(wg, skb->data + offset, len, skb);
+ receive_handshake_packet(wg, skb);
dev_kfree_skb(skb);
if (++num_processed == MAX_BURST_INCOMING_HANDSHAKES) {
queue_work(wg->workqueue, &wg->incoming_handshakes_work);
@@ -188,11 +205,6 @@ static void keep_key_fresh(struct wireguard_peer *peer)
}
}
-struct packet_cb {
- u8 ds;
-};
-#define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
-
static void receive_data_packet(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key, int err)
{
struct net_device *dev;
@@ -276,11 +288,10 @@ continue_processing:
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
{
- size_t len, offset;
-
- if (unlikely(skb_data_offset(skb, &offset, &len) < 0))
+ int message_type = skb_prepare_header(skb);
+ if (unlikely(message_type < 0))
goto err;
- switch (message_determine_type(skb->data + offset, len)) {
+ switch (message_type) {
case MESSAGE_HANDSHAKE_INITIATION:
case MESSAGE_HANDSHAKE_RESPONSE:
case MESSAGE_HANDSHAKE_COOKIE:
@@ -288,17 +299,13 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
net_dbg_skb_ratelimited("Too many handshakes queued, dropping packet from %pISpfsc\n", skb);
goto err;
}
- if (skb_linearize(skb) < 0) {
- net_dbg_skb_ratelimited("Unable to linearize handshake skb from %pISpfsc\n", skb);
- goto err;
- }
skb_queue_tail(&wg->incoming_handshakes, skb);
/* Queues up a call to packet_process_queued_handshake_packets(skb): */
queue_work(wg->workqueue, &wg->incoming_handshakes_work);
break;
case MESSAGE_DATA:
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
- packet_consume_data(skb, offset, wg, receive_data_packet);
+ packet_consume_data(skb, wg, receive_data_packet);
break;
default:
net_dbg_skb_ratelimited("Invalid packet from %pISpfsc\n", skb);
diff --git a/src/send.c b/src/send.c
index e04a245..f5414e1 100644
--- a/src/send.c
+++ b/src/send.c
@@ -77,12 +77,12 @@ void packet_send_handshake_response(struct wireguard_peer *peer)
}
}
-void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, void *data, size_t data_len, __le32 sender_index)
+void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index)
{
struct message_handshake_cookie packet;
net_dbg_skb_ratelimited("Sending cookie response for denied handshake message for %pISpfsc\n", initiating_skb);
- cookie_message_create(&packet, initiating_skb, data, data_len, sender_index, &wg->cookie_checker);
+ cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker);
socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet));
}
@@ -123,10 +123,13 @@ static void message_create_data_done(struct sk_buff_head *queue, struct wireguar
struct sk_buff *skb, *tmp;
bool is_keepalive, data_sent = false;
+ if (unlikely(!skb_queue_len(queue)))
+ return;
+
timers_any_authenticated_packet_traversal(peer);
skb_queue_walk_safe(queue, skb, tmp) {
is_keepalive = skb->len == message_data_len(0);
- if (likely(!socket_send_skb_to_peer(peer, skb, *(u8 *)skb->cb) && !is_keepalive))
+ if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive))
data_sent = true;
}
if (likely(data_sent))