aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-05-24 19:18:04 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2017-05-30 18:07:28 +0200
commit6e6986e17d0dc99903c5134c7c0c8e78fbec7831 (patch)
treef1a571f1c04e021113abb3572dc31b44401a7634
parentnoise: precompute static-static ECDH operation (diff)
downloadwireguard-monolithic-historical-6e6986e17d0dc99903c5134c7c0c8e78fbec7831.tar.xz
wireguard-monolithic-historical-6e6986e17d0dc99903c5134c7c0c8e78fbec7831.zip
peer: use iterator macro instead of callback
-rw-r--r--src/config.c32
-rw-r--r--src/device.c50
-rw-r--r--src/noise.c8
-rw-r--r--src/noise.h2
-rw-r--r--src/peer.c27
-rw-r--r--src/peer.h31
6 files changed, 73 insertions, 77 deletions
diff --git a/src/config.c b/src/config.c
index c3fe154..48afae1 100644
--- a/src/config.c
+++ b/src/config.c
@@ -8,20 +8,15 @@
#include "hashtables.h"
#include "peer.h"
#include "uapi.h"
-
-static int clear_peer_endpoint_src(struct wireguard_peer *peer, void *data)
-{
- socket_clear_peer_endpoint_src(peer);
- return 0;
-}
-
static int set_device_port(struct wireguard_device *wg, u16 port)
{
+ struct wireguard_peer *peer, *temp;
socket_uninit(wg);
wg->incoming_port = port;
if (!(netdev_pub(wg)->flags & IFF_UP))
return 0;
- peer_for_each_unlocked(wg, clear_peer_endpoint_src, NULL);
+ peer_for_each (wg, peer, temp, false)
+ socket_clear_peer_endpoint_src(peer);
return socket_init(wg);
}
@@ -133,6 +128,7 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device)
{
int ret;
size_t i, offset;
+ struct wireguard_peer *peer, *temp;
struct wgdevice in_device;
void __user *user_peer;
bool modified_static_identity = false;
@@ -152,7 +148,8 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device)
if (in_device.fwmark || (!in_device.fwmark && (in_device.flags & WGDEVICE_REMOVE_FWMARK))) {
wg->fwmark = in_device.fwmark;
- peer_for_each_unlocked(wg, clear_peer_endpoint_src, NULL);
+ peer_for_each (wg, peer, temp, false)
+ socket_clear_peer_endpoint_src(peer);
}
if (in_device.port) {
@@ -183,8 +180,10 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device)
}
if (modified_static_identity) {
- if (peer_for_each_unlocked(wg, noise_precompute_static_static, NULL) < 0)
- noise_set_static_identity_private_key(&wg->static_identity, NULL);
+ peer_for_each (wg, peer, temp, false) {
+ if (!noise_precompute_static_static(peer))
+ peer_remove(peer);
+ }
cookie_checker_precompute_device_keys(&wg->cookie_checker);
}
@@ -242,10 +241,9 @@ static int populate_ipmask(void *ctx, union nf_inet_addr ip, u8 cidr, int family
return ret;
}
-static int populate_peer(struct wireguard_peer *peer, void *ctx)
+static int populate_peer(struct wireguard_peer *peer, struct data_remaining *data)
{
int ret = 0;
- struct data_remaining *data = ctx;
void __user *upeer = data->data;
struct wgpeer out_peer;
struct data_remaining ipmasks_data = { NULL };
@@ -289,6 +287,7 @@ static int populate_peer(struct wireguard_peer *peer, void *ctx)
int config_get_device(struct wireguard_device *wg, void __user *user_device)
{
int ret;
+ struct wireguard_peer *peer, *temp;
struct net_device *dev = netdev_pub(wg);
struct data_remaining peer_data = { NULL };
struct wgdevice out_device;
@@ -330,7 +329,12 @@ int config_get_device(struct wireguard_device *wg, void __user *user_device)
peer_data.out_len = in_device.peers_size;
peer_data.data = user_device + sizeof(struct wgdevice);
- ret = peer_for_each_unlocked(wg, populate_peer, &peer_data);
+
+ peer_for_each (wg, peer, temp, false) {
+ ret = populate_peer(peer, &peer_data);
+ if (ret)
+ break;
+ }
if (ret)
goto out;
out_device.num_peers = peer_data.count;
diff --git a/src/device.c b/src/device.c
index e10aeed..a06750a 100644
--- a/src/device.c
+++ b/src/device.c
@@ -26,18 +26,10 @@
#include <net/netfilter/nf_nat_core.h>
#endif
-static int open_peer(struct wireguard_peer *peer, void *data)
-{
- timers_init_peer(peer);
- packet_send_queue(peer);
- if (peer->persistent_keepalive_interval)
- packet_send_keepalive(peer);
- return 0;
-}
-
static int open(struct net_device *dev)
{
int ret;
+ struct wireguard_peer *peer, *temp;
struct wireguard_device *wg = netdev_priv(dev);
#if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0)
struct inet6_dev *dev_v6 = __in6_dev_get(dev);
@@ -64,16 +56,12 @@ static int open(struct net_device *dev)
ret = socket_init(wg);
if (ret < 0)
return ret;
- peer_for_each(wg, open_peer, NULL);
- return 0;
-}
-
-static int clear_noise_peer(struct wireguard_peer *peer, void *data)
-{
- noise_handshake_clear(&peer->handshake);
- noise_keypairs_clear(&peer->keypairs);
- if (peer->timers_enabled)
- del_timer(&peer->timer_kill_ephemerals);
+ peer_for_each (wg, peer, temp, true) {
+ timers_init_peer(peer);
+ packet_send_queue(peer);
+ if (peer->persistent_keepalive_interval)
+ packet_send_keepalive(peer);
+ }
return 0;
}
@@ -81,25 +69,31 @@ static int clear_noise_peer(struct wireguard_peer *peer, void *data)
static int suspending_clear_noise_peers(struct notifier_block *nb, unsigned long action, void *data)
{
struct wireguard_device *wg = container_of(nb, struct wireguard_device, clear_peers_on_suspend);
+ struct wireguard_peer *peer, *temp;
if (action == PM_HIBERNATION_PREPARE || action == PM_SUSPEND_PREPARE) {
- peer_for_each(wg, clear_noise_peer, NULL);
+ peer_for_each (wg, peer, temp, true) {
+ noise_handshake_clear(&peer->handshake);
+ noise_keypairs_clear(&peer->keypairs);
+ if (peer->timers_enabled)
+ del_timer(&peer->timer_kill_ephemerals);
+ }
rcu_barrier_bh();
}
return 0;
}
#endif
-static int stop_peer(struct wireguard_peer *peer, void *data)
-{
- timers_uninit_peer(peer);
- clear_noise_peer(peer, data);
- return 0;
-}
-
static int stop(struct net_device *dev)
{
struct wireguard_device *wg = netdev_priv(dev);
- peer_for_each(wg, stop_peer, NULL);
+ struct wireguard_peer *peer, *temp;
+ peer_for_each (wg, peer, temp, true) {
+ timers_uninit_peer(peer);
+ noise_handshake_clear(&peer->handshake);
+ noise_keypairs_clear(&peer->keypairs);
+ if (peer->timers_enabled)
+ del_timer(&peer->timer_kill_ephemerals);
+ }
skb_queue_purge(&wg->incoming_handshakes);
socket_uninit(wg);
return 0;
diff --git a/src/noise.c b/src/noise.c
index 9e7fab0..c9d8148 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -38,12 +38,12 @@ void noise_init(void)
blake2s_final(&blake, handshake_init_hash, NOISE_HASH_LEN);
}
-int noise_precompute_static_static(struct wireguard_peer *peer, void *ctx)
+bool noise_precompute_static_static(struct wireguard_peer *peer)
{
if (peer->handshake.static_identity->has_identity)
- return curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static) ? 0 : -EINVAL;
+ return curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static);
memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN);
- return 0;
+ return true;
}
bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer)
@@ -56,7 +56,7 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static
memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN);
handshake->static_identity = static_identity;
handshake->state = HANDSHAKE_ZEROED;
- return !noise_precompute_static_static(peer, static_identity);
+ return noise_precompute_static_static(peer);
}
void noise_handshake_clear(struct noise_handshake *handshake)
diff --git a/src/noise.h b/src/noise.h
index 5e4d9af..c2d7e63 100644
--- a/src/noise.h
+++ b/src/noise.h
@@ -109,7 +109,7 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs);
bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair);
void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]);
-int noise_precompute_static_static(struct wireguard_peer *peer, void *ctx);
+bool noise_precompute_static_static(struct wireguard_peer *peer);
bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake);
struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg);
diff --git a/src/peer.c b/src/peer.c
index 411a82e..fc7ee0a 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -108,33 +108,6 @@ void peer_put(struct wireguard_peer *peer)
kref_put(&peer->refcount, kref_release);
}
-int peer_for_each_unlocked(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data)
-{
- struct wireguard_peer *peer, *temp;
- int ret = 0;
-
- lockdep_assert_held(&wg->device_update_lock);
- list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) {
- peer = peer_rcu_get(peer);
- if (unlikely(!peer))
- continue;
- ret = fn(peer, data);
- peer_put(peer);
- if (ret < 0)
- break;
- }
- return ret;
-}
-
-int peer_for_each(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data)
-{
- int ret;
- mutex_lock(&wg->device_update_lock);
- ret = peer_for_each_unlocked(wg, fn, data);
- mutex_unlock(&wg->device_update_lock);
- return ret;
-}
-
void peer_remove_all(struct wireguard_device *wg)
{
struct wireguard_peer *peer, *temp;
diff --git a/src/peer.h b/src/peer.h
index 8cb759a..d12c3c8 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -67,9 +67,34 @@ void peer_remove_all(struct wireguard_device *wg);
struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, u32 index);
-int peer_for_each_unlocked(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data);
-int peer_for_each(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data);
-
unsigned int peer_total_count(struct wireguard_device *wg);
+/* This is a macro iterator of essentially this:
+ *
+ * if (__should_lock)
+ * mutex_lock(&(__wg)->device_update_lock);
+ * else
+ * lockdep_assert_held(&(__wg)->device_update_lock)
+ * list_for_each_entry_safe (__peer, __temp, &(__wg)->peer_list, peer_list) {
+ * __peer = peer_rcu_get(__peer);
+ * if (!__peer)
+ * continue;
+ * ITERATOR_BODY
+ * peer_put(__peer);
+ * }
+ * if (__should_lock)
+ * mutex_unlock(&(__wg)->device_update_lock);
+ *
+ * While it's really ugly to look at, the code gcc produces from it is actually perfect.
+ */
+#define pfe_label(n) __PASTE(__PASTE(pfe_label_, n ## _), __LINE__)
+#define peer_for_each(__wg, __peer, __temp, __should_lock) \
+ if (1) { if (__should_lock) mutex_lock(&(__wg)->device_update_lock); else lockdep_assert_held(&(__wg)->device_update_lock); goto pfe_label(1); } else pfe_label(1): \
+ if (1) goto pfe_label(2); else while (1) if (1) { if (__should_lock) mutex_unlock(&(__wg)->device_update_lock); break; } else pfe_label(2): \
+ list_for_each_entry_safe (__peer, __temp, &(__wg)->peer_list, peer_list) \
+ if (0) pfe_label(3): break; else \
+ if (0); else for (__peer = peer_rcu_get(peer); __peer;) if (1) { goto pfe_label(4); pfe_label(5): break; } else while (1) if (1) goto pfe_label(5); else pfe_label(4): \
+ if (1) { goto pfe_label(6); pfe_label(7):; } else while (1) if (1) goto pfe_label(3); else while (1) if (1) goto pfe_label(7); else pfe_label(6): \
+ if (1) { goto pfe_label(8); pfe_label(9): peer_put(__peer); break; pfe_label(10): peer_put(__peer); } else while (1) if (1) goto pfe_label(9); else while (1) if (1) goto pfe_label(10); else pfe_label(8):
+
#endif