aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-04-03 21:40:45 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2017-04-04 03:44:35 +0200
commitfd6d19bb46a666868abd6aeff4fc43dd067489b5 (patch)
treef4dc3a91443258e81a0cc4d402bce8451481f26a /src
parentchacha20poly1305: check return values of sgops (diff)
downloadwireguard-monolithic-historical-fd6d19bb46a666868abd6aeff4fc43dd067489b5.tar.xz
wireguard-monolithic-historical-fd6d19bb46a666868abd6aeff4fc43dd067489b5.zip
data: simplify flow
Diffstat (limited to 'src')
-rw-r--r--src/data.c117
-rw-r--r--src/packets.h9
-rw-r--r--src/receive.c9
-rw-r--r--src/send.c4
4 files changed, 45 insertions, 94 deletions
diff --git a/src/data.c b/src/data.c
index ddb99b0..d993007 100644
--- a/src/data.c
+++ b/src/data.c
@@ -18,7 +18,6 @@
struct encryption_ctx {
struct padata_priv padata;
struct sk_buff_head queue;
- packet_create_data_callback_t callback;
struct wireguard_peer *peer;
struct noise_keypair *keypair;
};
@@ -27,9 +26,7 @@ struct decryption_ctx {
struct padata_priv padata;
struct endpoint endpoint;
struct sk_buff *skb;
- packet_consume_data_callback_t callback;
struct noise_keypair *keypair;
- int ret;
};
#ifdef CONFIG_WIREGUARD_PARALLEL
@@ -225,7 +222,7 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_
bool have_simd = chacha20poly1305_init_simd();
skb_queue_walk_safe(queue, skb, tmp) {
if (unlikely(!skb_encrypt(skb, keypair, have_simd))) {
- skb_unlink(skb, queue);
+ __skb_unlink(skb, queue);
kfree_skb(skb);
continue;
}
@@ -236,32 +233,22 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_
}
#ifdef CONFIG_WIREGUARD_PARALLEL
-static void do_encryption(struct padata_priv *padata)
+static void begin_parallel_encryption(struct padata_priv *padata)
{
struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
-
queue_encrypt_reset(&ctx->queue, ctx->keypair);
padata_do_serial(padata);
}
-static void finish_encryption(struct padata_priv *padata)
+static void finish_parallel_encryption(struct padata_priv *padata)
{
struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata);
-
- ctx->callback(&ctx->queue, ctx->peer);
+ packet_create_data_done(&ctx->queue, ctx->peer);
atomic_dec(&ctx->peer->parallel_encryption_inflight);
peer_put(ctx->peer);
kmem_cache_free(encryption_ctx_cache, ctx);
}
-static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
-{
- memset(priv, 0, sizeof(struct padata_priv));
- priv->parallel = do_encryption;
- priv->serial = finish_encryption;
- return padata_do_parallel(padata, priv, cb_cpu);
-}
-
static inline unsigned int choose_cpu(__le32 key)
{
unsigned int cpu_index, cpu, cb_cpu;
@@ -276,7 +263,7 @@ static inline unsigned int choose_cpu(__le32 key)
}
#endif
-int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback)
+int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer)
{
int ret = -ENOKEY;
struct noise_keypair *keypair;
@@ -303,21 +290,21 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer,
#ifdef CONFIG_WIREGUARD_PARALLEL
if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || atomic_read(&peer->parallel_encryption_inflight) > 0) && cpumask_weight(cpu_online_mask) > 1) {
- unsigned int cpu = choose_cpu(keypair->remote_index);
struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC);
if (!ctx)
goto serial_encrypt;
skb_queue_head_init(&ctx->queue);
skb_queue_splice_init(queue, &ctx->queue);
- ctx->callback = callback;
+ memset(&ctx->padata, 0, sizeof(ctx->padata));
+ ctx->padata.parallel = begin_parallel_encryption;
+ ctx->padata.serial = finish_parallel_encryption;
ctx->keypair = keypair;
ctx->peer = peer_rcu_get(peer);
ret = -EBUSY;
if (unlikely(!ctx->peer))
goto err_parallel;
atomic_inc(&peer->parallel_encryption_inflight);
- ret = start_encryption(peer->device->parallel_send, &ctx->padata, cpu);
- if (unlikely(ret < 0)) {
+ if (unlikely(padata_do_parallel(peer->device->parallel_send, &ctx->padata, choose_cpu(keypair->remote_index)))) {
atomic_dec(&peer->parallel_encryption_inflight);
peer_put(ctx->peer);
err_parallel:
@@ -330,7 +317,7 @@ serial_encrypt:
#endif
{
queue_encrypt_reset(queue, keypair);
- callback(queue, peer);
+ packet_create_data_done(queue, peer);
}
return 0;
@@ -344,83 +331,56 @@ err_rcu:
static void begin_decrypt_packet(struct decryption_ctx *ctx)
{
- ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb);
- if (unlikely(ctx->ret < 0))
- goto err;
-
- ctx->ret = -ENOKEY;
- if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving)))
- goto err;
-
- ctx->ret = 0;
- return;
-
-err:
- peer_put(ctx->keypair->entry.peer);
+ if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
+ peer_put(ctx->keypair->entry.peer);
+ noise_keypair_put(ctx->keypair);
+ dev_kfree_skb(ctx->skb);
+ ctx->skb = NULL;
+ }
}
static void finish_decrypt_packet(struct decryption_ctx *ctx)
{
- struct noise_keypairs *keypairs;
- bool used_new_key = false;
- u64 nonce = PACKET_CB(ctx->skb)->nonce;
- int ret = ctx->ret;
- if (ret)
- goto err;
+ bool used_new_key;
- keypairs = &ctx->keypair->entry.peer->keypairs;
- ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE;
+ if (!ctx->skb)
+ return;
- if (likely(!ret))
- used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
- else {
- net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter);
+ if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) {
+ net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
peer_put(ctx->keypair->entry.peer);
- goto err;
+ noise_keypair_put(ctx->keypair);
+ dev_kfree_skb(ctx->skb);
+ return;
}
- noise_keypair_put(ctx->keypair);
-
+ used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
skb_reset(ctx->skb);
- ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0);
- return;
-
-err:
+ packet_consume_data_done(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key);
noise_keypair_put(ctx->keypair);
- ctx->callback(ctx->skb, NULL, NULL, false, ret);
}
#ifdef CONFIG_WIREGUARD_PARALLEL
-static void do_decryption(struct padata_priv *padata)
+static void begin_parallel_decryption(struct padata_priv *padata)
{
struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
begin_decrypt_packet(ctx);
padata_do_serial(padata);
}
-static void finish_decryption(struct padata_priv *padata)
+static void finish_parallel_decryption(struct padata_priv *padata)
{
struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata);
finish_decrypt_packet(ctx);
kmem_cache_free(decryption_ctx_cache, ctx);
}
-
-static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
-{
- memset(priv, 0, sizeof(struct padata_priv));
- priv->parallel = do_decryption;
- priv->serial = finish_decryption;
- return padata_do_parallel(padata, priv, cb_cpu);
-}
#endif
-void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback)
+void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg)
{
- int ret;
struct noise_keypair *keypair;
__le32 idx = ((struct message_data *)skb->data)->key_idx;
- ret = -EINVAL;
rcu_read_lock_bh();
keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
rcu_read_unlock_bh();
@@ -429,19 +389,15 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe
#ifdef CONFIG_WIREGUARD_PARALLEL
if (cpumask_weight(cpu_online_mask) > 1) {
- unsigned int cpu = choose_cpu(idx);
- struct decryption_ctx *ctx;
-
- ret = -ENOMEM;
- ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
+ struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC);
if (unlikely(!ctx))
goto err_peer;
-
ctx->skb = skb;
ctx->keypair = keypair;
- ctx->callback = callback;
- ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu);
- if (unlikely(ret)) {
+ memset(&ctx->padata, 0, sizeof(ctx->padata));
+ ctx->padata.parallel = begin_parallel_decryption;
+ ctx->padata.serial = finish_parallel_decryption;
+ if (unlikely(padata_do_parallel(wg->parallel_receive, &ctx->padata, choose_cpu(idx)))) {
kmem_cache_free(decryption_ctx_cache, ctx);
goto err_peer;
}
@@ -450,8 +406,7 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe
{
struct decryption_ctx ctx = {
.skb = skb,
- .keypair = keypair,
- .callback = callback
+ .keypair = keypair
};
begin_decrypt_packet(&ctx);
finish_decrypt_packet(&ctx);
@@ -464,5 +419,5 @@ err_peer:
noise_keypair_put(keypair);
#endif
err:
- callback(skb, NULL, NULL, false, ret);
+ dev_kfree_skb(skb);
}
diff --git a/src/packets.h b/src/packets.h
index a640847..be9cfd7 100644
--- a/src/packets.h
+++ b/src/packets.h
@@ -23,6 +23,7 @@ struct packet_cb {
/* receive.c */
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb);
void packet_process_queued_handshake_packets(struct work_struct *work);
+void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key);
/* send.c */
void packet_send_queue(struct wireguard_peer *peer);
@@ -31,12 +32,12 @@ void packet_queue_handshake_initiation(struct wireguard_peer *peer);
void packet_send_queued_handshakes(struct work_struct *work);
void packet_send_handshake_response(struct wireguard_peer *peer);
void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index);
+void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer);
+
/* data.c */
-typedef void (*packet_create_data_callback_t)(struct sk_buff_head *, struct wireguard_peer *);
-typedef void (*packet_consume_data_callback_t)(struct sk_buff *skb, struct wireguard_peer *, struct endpoint *, bool used_new_key, int err);
-int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback);
-void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback);
+int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer);
+void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg);
#ifdef CONFIG_WIREGUARD_PARALLEL
int packet_init_data_caches(void);
diff --git a/src/receive.c b/src/receive.c
index 3b375ae..929d723 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -205,17 +205,12 @@ static void keep_key_fresh(struct wireguard_peer *peer)
}
}
-static void receive_data_packet(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key, int err)
+void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key)
{
struct net_device *dev;
struct wireguard_peer *routed_peer;
struct wireguard_device *wg;
- if (unlikely(err < 0 || !peer || !endpoint)) {
- dev_kfree_skb(skb);
- return;
- }
-
socket_set_peer_endpoint(peer, endpoint);
wg = peer->device;
@@ -305,7 +300,7 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
break;
case MESSAGE_DATA:
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
- packet_consume_data(skb, wg, receive_data_packet);
+ packet_consume_data(skb, wg);
break;
default:
net_dbg_skb_ratelimited("Invalid packet from %pISpfsc\n", skb);
diff --git a/src/send.c b/src/send.c
index 046b62e..6ed660b 100644
--- a/src/send.c
+++ b/src/send.c
@@ -118,7 +118,7 @@ void packet_send_keepalive(struct wireguard_peer *peer)
packet_send_queue(peer);
}
-static void message_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer)
+void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer)
{
struct sk_buff *skb, *tmp;
bool is_keepalive, data_sent = false;
@@ -157,7 +157,7 @@ void packet_send_queue(struct wireguard_peer *peer)
return;
/* We submit it for encryption and sending. */
- switch (packet_create_data(&queue, peer, message_create_data_done)) {
+ switch (packet_create_data(&queue, peer)) {
case 0:
break;
case -EBUSY: