aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-08-01 15:59:37 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-08-03 00:14:18 +0200
commit81eb0e30f9b39e99d1bb7b56828fd32e50ea055a (patch)
tree7b9e212d2a73644bae8b09da164147a4491d98fa /src
parentnoise: free peer references on failure (diff)
downloadwireguard-monolithic-historical-81eb0e30f9b39e99d1bb7b56828fd32e50ea055a.tar.xz
wireguard-monolithic-historical-81eb0e30f9b39e99d1bb7b56828fd32e50ea055a.zip
peer: ensure destruction doesn't race
Completely rework peer removal to ensure peers don't jump between contexts and create races.
Diffstat (limited to 'src')
-rw-r--r--src/compat/compat.h3
-rw-r--r--src/cookie.c8
-rw-r--r--src/hashtables.c6
-rw-r--r--src/hashtables.h2
-rw-r--r--src/noise.c58
-rw-r--r--src/noise.h2
-rw-r--r--src/peer.c43
-rw-r--r--src/peer.h1
-rw-r--r--src/receive.c41
-rw-r--r--src/send.c34
-rw-r--r--src/timers.c60
11 files changed, 148 insertions, 110 deletions
diff --git a/src/compat/compat.h b/src/compat/compat.h
index 5b3075b..86df5f3 100644
--- a/src/compat/compat.h
+++ b/src/compat/compat.h
@@ -51,6 +51,9 @@
#ifndef READ_ONCE
#define READ_ONCE ACCESS_ONCE
#endif
+#ifndef WRITE_ONCE
+#define WRITE_ONCE(p, v) (ACCESS_ONCE(p) = (v))
+#endif
#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)
#include "udp_tunnel/udp_tunnel_partial_compat.h"
diff --git a/src/cookie.c b/src/cookie.c
index bc6d8be..9268630 100644
--- a/src/cookie.c
+++ b/src/cookie.c
@@ -165,15 +165,9 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua
{
u8 cookie[COOKIE_LEN];
struct wireguard_peer *peer = NULL;
- struct index_hashtable_entry *entry;
bool ret;
- rcu_read_lock_bh();
- entry = index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index);
- if (likely(entry))
- peer = entry->peer;
- rcu_read_unlock_bh();
- if (unlikely(!peer))
+ if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer)))
return;
down_read(&peer->latest_cookie.lock);
diff --git a/src/hashtables.c b/src/hashtables.c
index 03b9e21..ac6df59 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -152,7 +152,7 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl
}
/* Returns a strong reference to a entry->peer */
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index)
+struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer)
{
struct index_hashtable_entry *iter_entry, *entry = NULL;
@@ -166,7 +166,9 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab
}
if (likely(entry)) {
entry->peer = peer_get_maybe_zero(entry->peer);
- if (unlikely(!entry->peer))
+ if (likely(entry->peer))
+ *peer = entry->peer;
+ else
entry = NULL;
}
rcu_read_unlock_bh();
diff --git a/src/hashtables.h b/src/hashtables.h
index a2ef6f0..f64cd24 100644
--- a/src/hashtables.h
+++ b/src/hashtables.h
@@ -47,6 +47,6 @@ void index_hashtable_init(struct index_hashtable *table);
__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry);
bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry);
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index);
+struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer);
#endif /* _WG_HASHTABLES_H */
diff --git a/src/noise.c b/src/noise.c
index a1e094b..0f6e51b 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -103,24 +103,23 @@ static struct noise_keypair *keypair_create(struct wireguard_peer *peer)
static void keypair_free_rcu(struct rcu_head *rcu)
{
- struct noise_keypair *keypair = container_of(rcu, struct noise_keypair, rcu);
-
- net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
- kzfree(keypair);
+ kzfree(container_of(rcu, struct noise_keypair, rcu));
}
static void keypair_free_kref(struct kref *kref)
{
struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount);
-
+ net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
call_rcu_bh(&keypair->rcu, keypair_free_rcu);
}
-void noise_keypair_put(struct noise_keypair *keypair)
+void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
{
if (unlikely(!keypair))
return;
+ if (unlikely(unreference_now))
+ index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
kref_put(&keypair->refcount, keypair_free_kref);
}
@@ -139,13 +138,13 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs)
spin_lock_bh(&keypairs->keypair_update_lock);
old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->current_keypair, NULL);
- noise_keypair_put(old);
+ noise_keypair_put(old, true);
spin_unlock_bh(&keypairs->keypair_update_lock);
}
@@ -171,7 +170,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
*/
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
rcu_assign_pointer(keypairs->previous_keypair, next_keypair);
- noise_keypair_put(current_keypair);
+ noise_keypair_put(current_keypair, true);
} else /* If there wasn't an existing next keypair, we replace the
* previous with the current one.
*/
@@ -179,7 +178,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
/* At this point we can get rid of the old previous keypair, and set up
* the new keypair.
*/
- noise_keypair_put(previous_keypair);
+ noise_keypair_put(previous_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, new_keypair);
} else {
/* If we're the responder, it means we can't use the new keypair until
@@ -188,9 +187,9 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
* in the new next one.
*/
rcu_assign_pointer(keypairs->next_keypair, new_keypair);
- noise_keypair_put(next_keypair);
+ noise_keypair_put(next_keypair, true);
RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
- noise_keypair_put(previous_keypair);
+ noise_keypair_put(previous_keypair, true);
}
spin_unlock_bh(&keypairs->keypair_update_lock);
}
@@ -218,7 +217,7 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k
*/
old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)));
- noise_keypair_put(old_keypair);
+ noise_keypair_put(old_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, received_keypair);
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
@@ -542,7 +541,7 @@ out:
struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg)
{
struct noise_handshake *handshake;
- struct wireguard_peer *ret_peer = NULL;
+ struct wireguard_peer *peer = NULL, *ret_peer = NULL;
u8 key[NOISE_SYMMETRIC_KEY_LEN];
u8 hash[NOISE_HASH_LEN];
u8 chaining_key[NOISE_HASH_LEN];
@@ -556,7 +555,7 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
if (unlikely(!wg->static_identity.has_identity))
goto out;
- handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index);
+ handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer);
if (unlikely(!handshake))
goto out;
@@ -601,11 +600,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
handshake->remote_index = src->sender_index;
handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
up_write(&handshake->lock);
- ret_peer = handshake->entry.peer;
+ ret_peer = peer;
goto out;
fail:
- peer_put(handshake->entry.peer);
+ peer_put(peer);
out:
memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
memzero_explicit(hash, NOISE_HASH_LEN);
@@ -619,14 +618,15 @@ out:
bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs)
{
struct noise_keypair *new_keypair;
+ bool ret = false;
down_write(&handshake->lock);
if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
- goto fail;
+ goto out;
new_keypair = keypair_create(handshake->entry.peer);
if (!new_keypair)
- goto fail;
+ goto out;
new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE;
new_keypair->remote_index = handshake->remote_index;
@@ -636,14 +636,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key);
handshake_zero(handshake);
- add_new_keypair(keypairs, new_keypair);
- net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id);
- WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry));
- up_write(&handshake->lock);
-
- return true;
+ rcu_read_lock_bh();
+ if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) {
+ add_new_keypair(keypairs, new_keypair);
+ net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id);
+ ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry);
+ } else
+ kzfree(new_keypair);
+ rcu_read_unlock_bh();
-fail:
+out:
up_write(&handshake->lock);
- return false;
+ return ret;
}
diff --git a/src/noise.h b/src/noise.h
index 5804acf..be59587 100644
--- a/src/noise.h
+++ b/src/noise.h
@@ -95,7 +95,7 @@ struct wireguard_device;
void noise_init(void);
bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer);
void noise_handshake_clear(struct noise_handshake *handshake);
-void noise_keypair_put(struct noise_keypair *keypair);
+void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now);
struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair);
void noise_keypairs_clear(struct noise_keypairs *keypairs);
bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair);
diff --git a/src/peer.c b/src/peer.c
index 093f318..97a472a 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -86,18 +86,40 @@ void peer_remove(struct wireguard_peer *peer)
if (unlikely(!peer))
return;
lockdep_assert_held(&peer->device->device_update_lock);
+
+ /* Remove from configuration-time lookup structures so new packets can't enter. */
+ list_del_init(&peer->peer_list);
allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock);
pubkey_hashtable_remove(&peer->device->peer_hashtable, peer);
- skb_queue_purge(&peer->staged_packet_queue);
- noise_handshake_clear(&peer->handshake);
+
+ /* Mark as dead, so that we don't allow jumping contexts after. */
+ WRITE_ONCE(peer->is_dead, true);
+ synchronize_rcu_bh();
+
+ /* Now that no more keypairs can be created for this peer, we destroy existing ones. */
noise_keypairs_clear(&peer->keypairs);
- list_del_init(&peer->peer_list);
+
+ /* Destroy all ongoing timers that were in-flight at the beginning of this function. */
timers_stop(peer);
- flush_workqueue(peer->device->packet_crypt_wq); /* The first flush is for encrypt/decrypt. */
- flush_workqueue(peer->device->packet_crypt_wq); /* The second.1 flush is for send (but not receive, since that's napi). */
- napi_disable(&peer->napi); /* The second.2 flush is for receive (but not send, since that's wq). */
- flush_workqueue(peer->device->handshake_send_wq);
+
+ /* The transition between packet encryption/decryption queues isn't guarded
+ * by is_dead, but each reference's life is strictly bounded by two
+ * generations: once for parallel crypto and once for serial ingestion,
+ * so we can simply flush twice, and be sure that we no longer have references
+ * inside these queues.
+ *
+ * a) For encrypt/decrypt. */
+ flush_workqueue(peer->device->packet_crypt_wq);
+ /* b.1) For send (but not receive, since that's napi). */
+ flush_workqueue(peer->device->packet_crypt_wq);
+ /* b.2.1) For receive (but not send, since that's wq). */
+ napi_disable(&peer->napi);
+ /* b.2.1) It's now safe to remove the napi struct, which must be done here from process context. */
netif_napi_del(&peer->napi);
+
+ /* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */
+ flush_workqueue(peer->device->handshake_send_wq);
+
--peer->device->num_peers;
peer_put(peer);
}
@@ -105,8 +127,6 @@ void peer_remove(struct wireguard_peer *peer)
static void rcu_release(struct rcu_head *rcu)
{
struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu);
-
- pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
dst_cache_destroy(&peer->endpoint_cache);
packet_queue_free(&peer->rx_queue, false);
packet_queue_free(&peer->tx_queue, false);
@@ -116,9 +136,12 @@ static void rcu_release(struct rcu_head *rcu)
static void kref_release(struct kref *refcount)
{
struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount);
-
+ pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
+ /* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */
index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry);
+ /* Remove any lingering packets that didn't have a chance to be transmitted. */
skb_queue_purge(&peer->staged_packet_queue);
+ /* Free the memory used. */
call_rcu_bh(&peer->rcu, rcu_release);
}
diff --git a/src/peer.h b/src/peer.h
index 059fa64..8daa053 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -58,6 +58,7 @@ struct wireguard_peer {
struct list_head peer_list;
u64 internal_id;
struct napi_struct napi;
+ bool is_dead;
};
struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
diff --git a/src/receive.c b/src/receive.c
index 12af8ed..d3a698a 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -282,9 +282,9 @@ out:
}
#include "selftest/counter.h"
-static void packet_consume_data_done(struct sk_buff *skb, struct endpoint *endpoint)
+static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint)
{
- struct wireguard_peer *peer = PACKET_PEER(skb), *routed_peer;
+ struct wireguard_peer *routed_peer;
struct net_device *dev = peer->device->dev;
unsigned int len, len_before_trim;
@@ -400,11 +400,11 @@ int packet_rx_poll(struct napi_struct *napi, int budget)
goto next;
skb_reset(skb);
- packet_consume_data_done(skb, &endpoint);
+ packet_consume_data_done(peer, skb, &endpoint);
free = false;
next:
- noise_keypair_put(keypair);
+ noise_keypair_put(keypair, false);
peer_put(peer);
if (unlikely(free))
dev_kfree_skb(skb);
@@ -436,32 +436,31 @@ void packet_decrypt_worker(struct work_struct *work)
static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
{
- struct wireguard_peer *peer;
+ struct wireguard_peer *peer = NULL;
__le32 idx = ((struct message_data *)skb->data)->key_idx;
int ret;
rcu_read_lock_bh();
- PACKET_CB(skb)->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
- rcu_read_unlock_bh();
- if (unlikely(!PACKET_CB(skb)->keypair)) {
- dev_kfree_skb(skb);
- return;
- }
+ PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer);
+ if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair)))
+ goto err_keypair;
- /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_get(). */
- peer = PACKET_PEER(skb);
+ if (unlikely(peer->is_dead))
+ goto err;
ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
- if (likely(!ret))
- return; /* Successful. No need to drop references below. */
-
- if (ret == -EPIPE)
+ if (unlikely(ret == -EPIPE))
queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD);
- else {
- peer_put(peer);
- noise_keypair_put(PACKET_CB(skb)->keypair);
- dev_kfree_skb(skb);
+ if (likely(!ret || ret == -EPIPE)) {
+ rcu_read_unlock_bh();
+ return;
}
+err:
+ noise_keypair_put(PACKET_CB(skb)->keypair, false);
+err_keypair:
+ rcu_read_unlock_bh();
+ peer_put(peer);
+ dev_kfree_skb(skb);
}
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
diff --git a/src/send.c b/src/send.c
index 788ff60..481d153 100644
--- a/src/send.c
+++ b/src/send.c
@@ -58,13 +58,16 @@ void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool i
/* First checking the timestamp here is just an optimization; it will
* be caught while properly locked inside the actual work queue.
*/
- if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT))
- return;
+ rcu_read_lock_bh();
+ if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT) || unlikely(peer->is_dead))
+ goto out;
peer_get(peer);
/* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */
if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work))
peer_put(peer); /* If the work was already queued, we want to drop the extra reference */
+out:
+ rcu_read_unlock_bh();
}
void packet_send_handshake_response(struct wireguard_peer *peer)
@@ -233,7 +236,7 @@ void packet_tx_worker(struct work_struct *work)
else
skb_free_null_queue(first);
- noise_keypair_put(keypair);
+ noise_keypair_put(keypair, false);
peer_put(peer);
}
}
@@ -266,19 +269,22 @@ static void packet_create_data(struct sk_buff *first)
{
struct wireguard_peer *peer = PACKET_PEER(first);
struct wireguard_device *wg = peer->device;
- int ret;
+ int ret = -EINVAL;
- ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
- if (likely(!ret))
- return; /* Successful. No need to fall through to drop references below. */
+ rcu_read_lock_bh();
+ if (unlikely(peer->is_dead))
+ goto err;
- if (ret == -EPIPE)
+ ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
+ if (unlikely(ret == -EPIPE))
queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD);
- else {
- peer_put(peer);
- noise_keypair_put(PACKET_CB(first)->keypair);
- skb_free_null_queue(first);
- }
+err:
+ rcu_read_unlock_bh();
+ if (likely(!ret || ret == -EPIPE))
+ return;
+ noise_keypair_put(PACKET_CB(first)->keypair, false);
+ peer_put(peer);
+ skb_free_null_queue(first);
}
void packet_send_staged_packets(struct wireguard_peer *peer)
@@ -328,7 +334,7 @@ void packet_send_staged_packets(struct wireguard_peer *peer)
out_invalid:
key->is_valid = false;
out_nokey:
- noise_keypair_put(keypair);
+ noise_keypair_put(keypair, false);
/* We orphan the packets if we're waiting on a handshake, so that they
* don't block a socket's pool.
diff --git a/src/timers.c b/src/timers.c
index e8bb101..762152a 100644
--- a/src/timers.c
+++ b/src/timers.c
@@ -27,9 +27,20 @@
if (unlikely(!peer)) \
return;
-static inline bool timers_active(struct wireguard_peer *peer)
+static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list *timer, unsigned long expires)
{
- return netif_running(peer->device->dev) && !list_empty(&peer->peer_list);
+ rcu_read_lock_bh();
+ if (likely(netif_running(peer->device->dev) && !peer->is_dead))
+ mod_timer(timer, expires);
+ rcu_read_unlock_bh();
+}
+
+static inline void del_peer_timer(struct wireguard_peer *peer, struct timer_list *timer)
+{
+ rcu_read_lock_bh();
+ if (likely(netif_running(peer->device->dev) && !peer->is_dead))
+ del_timer(timer);
+ rcu_read_unlock_bh();
}
static void expired_retransmit_handshake(struct timer_list *timer)
@@ -39,8 +50,7 @@ static void expired_retransmit_handshake(struct timer_list *timer)
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2);
- if (likely(timers_active(peer)))
- del_timer(&peer->timer_send_keepalive);
+ del_peer_timer(peer, &peer->timer_send_keepalive);
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
@@ -49,8 +59,8 @@ static void expired_retransmit_handshake(struct timer_list *timer)
/* We set a timer for destroying any residue that might be left
* of a partial exchange.
*/
- if (likely(timers_active(peer)) && !timer_pending(&peer->timer_zero_key_material))
- mod_timer(&peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
+ if (!timer_pending(&peer->timer_zero_key_material))
+ mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
} else {
++peer->timer_handshake_attempts;
pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REKEY_TIMEOUT, peer->timer_handshake_attempts + 1);
@@ -70,8 +80,7 @@ static void expired_send_keepalive(struct timer_list *timer)
packet_send_keepalive(peer);
if (peer->timer_need_another_keepalive) {
peer->timer_need_another_keepalive = false;
- if (likely(timers_active(peer)))
- mod_timer(&peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
+ mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
}
peer_put(peer);
}
@@ -91,8 +100,12 @@ static void expired_zero_key_material(struct timer_list *timer)
{
peer_get_from_timer(timer_zero_key_material);
- if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Takes our reference. */
- peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */
+ rcu_read_lock_bh();
+ if (!peer->is_dead) {
+ if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Should take our reference. */
+ peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */
+ }
+ rcu_read_unlock_bh();
}
static void queued_expired_zero_key_material(struct work_struct *work)
{
@@ -116,16 +129,16 @@ static void expired_send_persistent_keepalive(struct timer_list *timer)
/* Should be called after an authenticated data packet is sent. */
void timers_data_sent(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)) && !timer_pending(&peer->timer_new_handshake))
- mod_timer(&peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ);
+ if (!timer_pending(&peer->timer_new_handshake))
+ mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ);
}
/* Should be called after an authenticated data packet is received. */
void timers_data_received(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer))) {
+ if (likely(netif_running(peer->device->dev))) {
if (!timer_pending(&peer->timer_send_keepalive))
- mod_timer(&peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
+ mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
else
peer->timer_need_another_keepalive = true;
}
@@ -134,29 +147,25 @@ void timers_data_received(struct wireguard_peer *peer)
/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
void timers_any_authenticated_packet_sent(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)))
- del_timer(&peer->timer_send_keepalive);
+ del_peer_timer(peer, &peer->timer_send_keepalive);
}
/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
void timers_any_authenticated_packet_received(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)))
- del_timer(&peer->timer_new_handshake);
+ del_peer_timer(peer, &peer->timer_new_handshake);
}
/* Should be called after a handshake initiation message is sent. */
void timers_handshake_initiated(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)))
- mod_timer(&peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
+ mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
}
/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
void timers_handshake_complete(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)))
- del_timer(&peer->timer_retransmit_handshake);
+ del_peer_timer(peer, &peer->timer_retransmit_handshake);
peer->timer_handshake_attempts = 0;
peer->sent_lastminute_handshake = false;
getnstimeofday(&peer->walltime_last_handshake);
@@ -165,15 +174,14 @@ void timers_handshake_complete(struct wireguard_peer *peer)
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
void timers_session_derived(struct wireguard_peer *peer)
{
- if (likely(timers_active(peer)))
- mod_timer(&peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
+ mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
}
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer)
{
- if (peer->persistent_keepalive_interval && likely(timers_active(peer)))
- mod_timer(&peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ);
+ if (peer->persistent_keepalive_interval)
+ mod_peer_timer(peer, &peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ);
}
void timers_init(struct wireguard_peer *peer)