aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatt Dunwoodie <ncon@mail.noconroy.net>2019-10-06 19:07:22 +0100
committerMatt Dunwoodie <ncon@mail.noconroy.net>2019-10-07 10:28:41 +0100
commit8e255652cab939a2590a3d093717cf39bfc04011 (patch)
treefcc766861b94f23bdd8edd57df7e119656773157
parentCouple of small fixes (diff)
downloadwireguard-openbsd-8e255652cab939a2590a3d093717cf39bfc04011.tar.xz
wireguard-openbsd-8e255652cab939a2590a3d093717cf39bfc04011.zip
Use mutexes in wireguard rather than rwlocks
There are still some nasty refcnt leaks when returning from wireguard.h, especially when an error occurs. TODO set a standard as for what to do when returning an error: * Leave session and put based on session == NULL * Always set session and always drop * Set session = NULL and drop ref I'm leaning towards the last
-rw-r--r--src/wireguard.c374
-rw-r--r--src/wireguard.h51
2 files changed, 225 insertions, 200 deletions
diff --git a/src/wireguard.c b/src/wireguard.c
index ebba73f..340a918 100644
--- a/src/wireguard.c
+++ b/src/wireguard.c
@@ -56,12 +56,10 @@ CTASSERT(WG_KEY_SIZE == WG_HASH_SIZE);
CTASSERT(WG_MSG_PADDING_SIZE == CHACHA20POLY1305_AUTHTAG_SIZE);
CTASSERT(WG_XNONCE_SIZE == XCHACHA20POLY1305_NONCE_SIZE);
+#define ret_error_peer(err) do { ret = err; goto leave_peer; } while (0)
#define ret_error(err) do { ret = err; goto leave; } while (0)
#define offsetof(type, member) ((size_t)(&((type *)0)->member))
-int wireguard_debug = 0;
-#define WGDEBUG(...) do { if (wireguard_debug) printf(__VA_ARGS__); } while (0)
-
/* wireguard.c */
void wg_kdf(uint8_t [WG_HASH_SIZE], uint8_t [WG_HASH_SIZE],
uint8_t [WG_HASH_SIZE], uint8_t [WG_KEY_SIZE], uint8_t *, size_t);
@@ -82,7 +80,8 @@ struct wg_session *wg_peer_hs_session(struct wg_peer *);
struct wg_session *wg_peer_ks_session(struct wg_peer *);
void wg_session_drop(struct wg_session *);
-void wg_peer_attach_session(struct wg_peer *, struct wg_session *);
+void wg_peer_attach_session(struct wg_peer *, struct wg_session *,
+ struct wg_handshake *, enum wg_state);
struct wg_session *wg_device_new_session(struct wg_device *);
/* Some crappy API */
@@ -99,16 +98,25 @@ wg_device_init(struct wg_device *dev, int ipl,
dev->d_cleanup = cleanup_fn;
fm_init(&dev->d_peers, 4, ipl);
fm_init(&dev->d_sessions, 12, ipl);
- rw_init(&dev->d_lock, "wg_device");
+ mtx_init(&dev->d_mtx, ipl);
/* d_cookie_maker, d_keypair initialised to 0 */
}
void
wg_device_setkey(struct wg_device *dev, struct wg_privkey *key)
{
- rw_enter_write(&dev->d_lock);
+ mtx_enter(&dev->d_mtx);
wg_keypair_from_key(&dev->d_keypair, key);
- rw_exit_write(&dev->d_lock);
+ mtx_leave(&dev->d_mtx);
+}
+
+void
+wg_device_getkey(struct wg_device *dev, struct wg_keypair *kp, int priv)
+{
+ /* TODO mask key based on priv */
+ mtx_enter(&dev->d_mtx);
+ *kp = dev->d_keypair;
+ mtx_leave(&dev->d_mtx);
}
void
@@ -116,6 +124,7 @@ wg_device_destroy(struct wg_device *dev)
{
struct map_item *item;
+ /* TODO lock fixed map */
FM_FOREACH_FILLED(item, &dev->d_peers)
wg_peer_drop(item->value);
@@ -132,12 +141,11 @@ wg_device_new_peer(struct wg_device *dev, struct wg_pubkey *key, void *arg)
peer->p_arg = arg;
peer->p_device = dev;
peer->p_remote = *key;
- refcnt_init(&peer->p_refcnt);
- rw_init(&peer->p_lock, "wg_peer");
+ mtx_init(&peer->p_mtx, dev->d_mtx.mtx_wantipl);
- rw_enter_write(&peer->p_lock);
+ mtx_enter(&peer->p_mtx);
peer->p_id = fm_insert(&dev->d_peers, peer);
- rw_exit_write(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
/* All other elements of wg_peer are nulled by M_ZERO */
return peer;
@@ -150,13 +158,11 @@ wg_device_new_session(struct wg_device *dev)
session = malloc(sizeof(*session), M_DEVBUF, M_WAITOK | M_ZERO);
getnanotime(&session->s_created);
- refcnt_init(&session->s_refcnt);
+ mtx_init(&session->s_mtx, dev->d_mtx.mtx_wantipl);
- rw_enter_write(&session->s_lock);
+ mtx_enter(&session->s_mtx);
session->s_local_id = fm_insert(&dev->d_sessions, session);
- rw_exit_write(&session->s_lock);
-
- WGDEBUG("wg_device_new_session: %x\n", session->s_local_id);
+ mtx_leave(&session->s_mtx);
return session;
}
@@ -168,7 +174,8 @@ wg_device_ref_peerkey(struct wg_device *dev, struct wg_pubkey *key)
* matching peer */
struct wg_peer *peer = NULL;
struct map_item *item;
- rw_enter_read(&dev->d_lock);
+
+ /* TODO lock, or use better data structure */
FM_FOREACH_FILLED(item, &dev->d_peers) {
peer = item->value;
if (memcmp(key->k, peer->p_remote.k, sizeof(key->k)) == 0)
@@ -176,7 +183,6 @@ wg_device_ref_peerkey(struct wg_device *dev, struct wg_pubkey *key)
else
peer = NULL;
}
- rw_exit_read(&dev->d_lock);
if (peer)
peer = fm_lookup(&dev->d_peers, peer->p_id);
@@ -212,22 +218,25 @@ wg_session_drop(struct wg_session *session)
}
void
-wg_peer_attach_session(struct wg_peer *peer, struct wg_session *session)
+wg_peer_attach_session(struct wg_peer *peer, struct wg_session *session,
+ struct wg_handshake *hs, enum wg_state state)
{
struct wg_session *old_session;
+
/* Assert the session is not attached to another peer.
- * Assert the session is for the same device that the peer is for.
* Assert the session is newly created, by checking s_state. */
+ mtx_enter(&session->s_mtx);
KASSERT(session->s_peer == NULL);
- //KASSERT(session->s_device == peer->p_device);
- KASSERT(session->s_state == WG_STATE_RECV_INITIATION ||
- session->s_state == WG_STATE_MADE_INITIATION);
-
- rw_enter_write(&peer->p_lock);
+ KASSERT(session->s_state == WG_STATE_NEW);
session->s_peer = peer;
+ session->s_state = state;
+ session->s_handshake = *hs;
+ mtx_leave(&session->s_mtx);
+
+ mtx_enter(&peer->p_mtx);
old_session = peer->p_hs_session;
peer->p_hs_session = session;
- rw_exit_write(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
if (old_session)
wg_session_drop(old_session);
@@ -236,17 +245,17 @@ wg_peer_attach_session(struct wg_peer *peer, struct wg_session *session)
void
wg_peer_reset_attempts(struct wg_peer *peer)
{
- rw_enter_write(&peer->p_lock);
+ mtx_enter(&peer->p_mtx);
peer->p_attempts = 0;
- rw_exit_write(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
}
void
wg_peer_clean(struct wg_peer *peer)
{
struct wg_session *hs, *ks, *ks_old;
- rw_enter_write(&peer->p_lock);
+ mtx_enter(&peer->p_mtx);
hs = peer->p_hs_session;
ks = peer->p_ks_session;
ks_old = peer->p_ks_session_old;
@@ -254,8 +263,7 @@ wg_peer_clean(struct wg_peer *peer)
peer->p_hs_session = NULL;
peer->p_ks_session = NULL;
peer->p_ks_session_old = NULL;
-
- rw_exit_write(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
if (hs != NULL)
wg_session_drop(hs);
@@ -274,8 +282,9 @@ wg_session_promote(struct wg_session *session)
struct wg_keyset *ks = &session->s_keyset;
struct wg_handshake *hs = &session->s_handshake;
+ /* TODO make better */
/* Setup session: derive keys, initialise the antireplay structure */
- rw_enter_write(&session->s_lock);
+ mtx_enter(&session->s_mtx);
if (session->s_state == WG_STATE_RECV_RESPONSE) {
session->s_state = WG_STATE_INITIATOR;
wg_kdf(ks->k_txkey.k, ks->k_rxkey.k, NULL, hs->h_ck, NULL, 0);
@@ -283,20 +292,21 @@ wg_session_promote(struct wg_session *session)
session->s_state = WG_STATE_RESPONDER;
wg_kdf(ks->k_rxkey.k, ks->k_txkey.k, NULL, hs->h_ck, NULL, 0);
} else {
- rw_exit_write(&session->s_lock);
+ mtx_leave(&session->s_mtx);
return;
}
antireplay_init(&ks->k_ar);
- rw_exit_write(&session->s_lock);
+ mtx_leave(&session->s_mtx);
+
- rw_enter_write(&session->s_lock);
+ mtx_enter(&peer->p_mtx);
old_session = peer->p_ks_session_old;
if (peer->p_ks_session != NULL)
peer->p_ks_session_old = peer->p_ks_session;
peer->p_ks_session = peer->p_hs_session;
peer->p_hs_session = NULL;
- rw_exit_write(&session->s_lock);
+ mtx_leave(&peer->p_mtx);
if (old_session != NULL)
wg_session_drop(old_session);
@@ -307,27 +317,28 @@ wg_session_promote(struct wg_session *session)
void
wg_peer_setshared(struct wg_peer *peer, struct wg_privkey *key)
{
- rw_enter_write(&peer->p_lock);
+ mtx_enter(&peer->p_mtx);
peer->p_shared = *key;
- rw_exit_write(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
}
void
-wg_peer_getshared(struct wg_peer *peer, struct wg_privkey *key)
+wg_peer_getshared(struct wg_peer *peer, struct wg_privkey *key, int priv)
{
- rw_enter_read(&peer->p_lock);
+ /* TODO check priv */
+ mtx_enter(&peer->p_mtx);
*key = peer->p_shared;
- rw_exit_read(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
}
struct timespec
wg_peer_last_handshake(struct wg_peer *peer)
{
struct timespec ret = { 0, 0 };
- rw_enter_read(&peer->p_lock);
- if (peer->p_ks_session)
+ mtx_enter(&peer->p_mtx);
+ if (peer->p_ks_session != NULL)
ret = peer->p_ks_session->s_created;
- rw_exit_read(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
return ret;
}
@@ -337,10 +348,10 @@ wg_peer_last_session(struct wg_peer *peer)
uint32_t id = 0;
struct wg_session *session;
- rw_enter_read(&peer->p_lock);
- if (peer->p_ks_session)
+ mtx_enter(&peer->p_mtx);
+ if (peer->p_ks_session != NULL)
id = peer->p_ks_session->s_local_id;
- rw_exit_read(&peer->p_lock);
+ mtx_leave(&peer->p_mtx);
if ((session = fm_lookup(&peer->p_device->d_sessions, id)) == NULL)
return NULL;
@@ -359,15 +370,20 @@ wg_device_rx_initiation(struct wg_device *dev, struct wg_msg_initiation *init,
struct wg_session **s)
{
struct wg_peer *peer;
+ struct wg_keypair kp;
struct wg_handshake hs;
struct wg_timestamp ts;
struct wg_pubkey remote;
- struct wg_session *session;
+ struct wg_session *session = NULL;
enum wg_error ret = WG_OK;
- rw_enter_read(&dev->d_lock);
+ /* We want to ensure that the keypair is not modified during the
+ * handshake, so we take a local copy here and bzero it before
+ * returning */
+ wg_device_getkey(dev, &kp, 1);
+ /* Noise handshake */
memcpy(hs.h_remote.k, init->ephemeral, WG_KEY_SIZE);
wg_keypair_generate(&hs.h_local);
@@ -375,56 +391,59 @@ wg_device_rx_initiation(struct wg_device *dev, struct wg_msg_initiation *init,
memcpy(hs.h_hash, hs.h_ck, WG_HASH_SIZE);
wg_mix_hash(&hs, WG_IDENTIFIER, strlen(WG_IDENTIFIER));
- wg_mix_hash(&hs, dev->d_keypair.pub.k, WG_KEY_SIZE);
+ wg_mix_hash(&hs, kp.pub.k, WG_KEY_SIZE);
wg_kdf(hs.h_ck, NULL, NULL, hs.h_ck, hs.h_remote.k, WG_KEY_SIZE);
wg_mix_hash(&hs, hs.h_remote.k, WG_KEY_SIZE);
- wg_mix_dh(&hs, dev->d_keypair.priv.k, hs.h_remote.k);
+ wg_mix_dh(&hs, kp.priv.k, hs.h_remote.k);
if (!wg_handshake_decrypt(&hs, remote.k, init->static_pub,
WG_ENCRYPTED_SIZE(sizeof(remote.k))))
ret_error(WG_DECRYPT);
wg_mix_hash(&hs, init->static_pub, sizeof(init->static_pub));
- wg_mix_dh(&hs, dev->d_keypair.priv.k, remote.k);
+ wg_mix_dh(&hs, kp.priv.k, remote.k);
if (!wg_handshake_decrypt(&hs, ts.t, init->timestamp,
WG_ENCRYPTED_SIZE(sizeof(ts.t))))
ret_error(WG_DECRYPT);
wg_mix_hash(&hs, init->timestamp, sizeof(init->timestamp));
- wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), dev->d_keypair.pub.k,
- WG_KEY_SIZE);
+ wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), kp.pub.k, WG_KEY_SIZE);
blake2s(hs.h_mac, (void *)init, hs.h_k, sizeof(hs.h_mac),
offsetof(struct wg_msg_initiation, mac1), sizeof(hs.h_k));
+ /* Check MAC matches */
if (timingsafe_bcmp(hs.h_mac, init->mac1, WG_MAC_SIZE))
ret_error(WG_MAC);
+ /* Lookup peer key that was specified in the packet, as we need to
+ * know what peer this is for. */
if ((peer = wg_device_ref_peerkey(dev, &remote)) == NULL)
ret_error(WG_UNKNOWN_PEER);
- if (memcmp(&ts, &peer->p_timestamp, WG_TIMESTAMP_SIZE) < 0)
- ret_error(WG_TIMESTAMP);
+ /* We want to ensure this packet is not replayed, so we validate that
+ * the timestamp (not necessarily representative of the real time) is
+ * greater than the last one we have received */
+ mtx_enter(&peer->p_mtx);
+ if (memcmp(ts.t, peer->p_timestamp.t, sizeof(ts.t)) >= 0)
+ peer->p_timestamp = ts;
+ mtx_leave(&peer->p_mtx);
- rw_exit_read(&dev->d_lock);
+ if (memcmp(ts.t, peer->p_timestamp.t, sizeof(ts.t)) != 0) {
+ wg_peer_put(peer);
+ ret_error(WG_TIMESTAMP);
+ }
- /* Create new session and add to peer */
session = wg_device_new_session(dev);
- session->s_handshake = hs;
session->s_remote_id = init->sender;
- session->s_state = WG_STATE_RECV_INITIATION;
- explicit_bzero(&hs, sizeof(hs));
- *s = session;
- wg_peer_attach_session(peer, session);
+ wg_peer_attach_session(peer, session, &hs, WG_STATE_RECV_INITIATION);
dev->d_outq(peer, WG_PKT_RESPONSE, session->s_local_id);
-
wg_peer_put(peer);
-
- return WG_OK;
+ *s = session;
leave:
- rw_exit_read(&dev->d_lock);
+ explicit_bzero(&kp, sizeof(kp));
explicit_bzero(&hs, sizeof(hs));
return ret;
}
@@ -433,61 +452,69 @@ enum wg_error
wg_device_rx_response(struct wg_device *dev, struct wg_msg_response *resp,
struct wg_session **s)
{
+ struct wg_keypair kp;
struct wg_handshake hs;
+ struct wg_privkey shared;
struct wg_session *session;
enum wg_error ret = WG_OK;
if ((session = fm_lookup(&dev->d_sessions, resp->receiver)) == NULL)
- return WG_ID;
-
- rw_enter_read(&dev->d_lock);
- rw_enter_write(&session->s_lock);
+ ret_error(WG_ID);
- if (session->s_state != WG_STATE_MADE_INITIATION)
- ret_error(WG_STATE);
+ /* Load requried values */
+ wg_device_getkey(dev, &kp, 1);
- /* Make a local copy so we don't clobber the real HS */
+ mtx_enter(&session->s_mtx);
hs = session->s_handshake;
+ mtx_leave(&session->s_mtx);
+
+ mtx_enter(&session->s_peer->p_mtx);
+ shared = session->s_peer->p_shared;
+ mtx_leave(&session->s_peer->p_mtx);
+ /* Noise recv handshake */
memcpy(hs.h_remote.k, resp->ephemeral, WG_KEY_SIZE);
wg_kdf(hs.h_ck, NULL, NULL, hs.h_ck, hs.h_remote.k, WG_KEY_SIZE);
wg_mix_hash(&hs, hs.h_remote.k, WG_KEY_SIZE);
wg_mix_dh(&hs, hs.h_local.priv.k, hs.h_remote.k);
- wg_mix_dh(&hs, dev->d_keypair.priv.k, hs.h_remote.k);
+ wg_mix_dh(&hs, kp.priv.k, hs.h_remote.k);
- wg_mix_psk(&hs, session->s_peer->p_shared.k);
+ wg_mix_psk(&hs, shared.k);
if (!wg_handshake_decrypt(&hs, NULL, resp->empty, WG_ENCRYPTED_SIZE(0)))
ret_error(WG_DECRYPT);
wg_mix_hash(&hs, resp->empty, WG_ENCRYPTED_SIZE(0));
- wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), dev->d_keypair.pub.k,
+ wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), kp.pub.k,
WG_KEY_SIZE);
blake2s(hs.h_mac, (void *)resp, hs.h_k, sizeof(hs.h_mac),
offsetof(struct wg_msg_response, mac1), sizeof(hs.h_k));
+ /* Compare macs */
if (timingsafe_bcmp(hs.h_mac, resp->mac1, WG_MAC_SIZE))
ret_error(WG_MAC);
- session->s_handshake = hs;
- session->s_remote_id = resp->sender;
- session->s_state = WG_STATE_RECV_RESPONSE;
-
- rw_exit_write(&session->s_lock);
- rw_exit_read(&dev->d_lock);
- explicit_bzero(&hs, sizeof(hs));
+ /* Update session only if we are in correct state */
+ mtx_enter(&session->s_mtx);
+ if (session->s_state == WG_STATE_MADE_INITIATION) {
+ session->s_handshake = hs;
+ session->s_remote_id = resp->sender;
+ session->s_state = WG_STATE_RECV_RESPONSE;
+ } else {
+ ret = WG_STATE;
+ }
+ mtx_leave(&session->s_mtx);
wg_session_promote(session);
- *s = session;
- return ret;
+ *s = session;
leave:
- rw_exit_write(&session->s_lock);
- rw_exit_read(&dev->d_lock);
+ explicit_bzero(&shared, sizeof(shared));
+ explicit_bzero(&kp, sizeof(kp));
explicit_bzero(&hs, sizeof(hs));
return ret;
@@ -504,30 +531,26 @@ wg_device_rx_cookie(struct wg_device *dev, struct wg_msg_cookie *cookie,
enum wg_error ret = WG_OK;
if ((session = fm_lookup(&dev->d_sessions, cookie->receiver)) == NULL)
- return WG_ID;
+ ret_error(WG_ID);
- rw_enter_write(&session->s_peer->p_lock);
- rw_enter_read(&session->s_lock);
+ wg_hash2(key, WG_COOKIE, strlen(WG_COOKIE),
+ session->s_peer->p_remote.k, WG_KEY_SIZE);
- if (session->s_state != WG_STATE_MADE_INITIATION &&
- session->s_state != WG_STATE_MADE_RESPONSE)
- ret_error(WG_STATE);
-
- wg_hash2(key, WG_COOKIE, strlen(WG_COOKIE), session->s_peer->p_remote.k,
- WG_KEY_SIZE);
-
- if(!xchacha20poly1305_decrypt(value, cookie->value, sizeof(cookie->value),
- session->s_handshake.h_mac, WG_MAC_SIZE, cookie->nonce, key))
+ /* TODO lock for h_mac? */
+ if(!xchacha20poly1305_decrypt(value, cookie->value,
+ sizeof(cookie->value), session->s_handshake.h_mac, WG_MAC_SIZE,
+ cookie->nonce, key))
ret_error(WG_DECRYPT);
+ /* Update peer with new cookie data */
+ mtx_enter(&session->s_peer->p_mtx);
memcpy(session->s_peer->p_cookie.cookie, value,
sizeof(session->s_peer->p_cookie.cookie));
getnanotime(&session->s_peer->p_cookie.time);
+ mtx_leave(&session->s_peer->p_mtx);
*s = session;
leave:
- rw_exit_read(&session->s_lock);
- rw_exit_write(&session->s_peer->p_lock);
return ret;
}
@@ -535,18 +558,17 @@ enum wg_error
wg_device_rx_transport(struct wg_device *dev, struct wg_msg_transport *msg,
size_t len, struct wg_session **s)
{
+ struct wg_session *session;
enum wg_error ret = WG_OK;
size_t data_len = len - offsetof(struct wg_msg_transport, data);
uint64_t counter = letoh64(msg->counter);
- struct wg_session *session;
if ((session = fm_lookup(&dev->d_sessions, msg->receiver)) == NULL)
- return WG_ID;
+ ret_error(WG_ID);
wg_session_promote(session);
- rw_enter_read(&session->s_lock);
-
+ /* TODO fix locks, at the moment we just kinda don't care */
if (session->s_state != WG_STATE_INITIATOR &&
session->s_state != WG_STATE_RESPONDER)
ret_error(WG_STATE);
@@ -559,7 +581,6 @@ wg_device_rx_transport(struct wg_device *dev, struct wg_msg_transport *msg,
msg->counter, session->s_keyset.k_rxkey.k))
ret_error(WG_DECRYPT);
- /* TODO fix read lock -> write lock */
if (antireplay_update(&session->s_keyset.k_ar, counter))
ret_error(WG_REPLAY);
@@ -572,7 +593,6 @@ wg_device_rx_transport(struct wg_device *dev, struct wg_msg_transport *msg,
*s = session;
leave:
- rw_exit_read(&session->s_lock);
return ret;
}
@@ -582,58 +602,61 @@ wg_device_tx_initiation(struct wg_device *dev, struct wg_msg_initiation *init,
uint32_t id, struct wg_session **s)
{
struct wg_peer *peer;
- struct wg_handshake *hs;
- struct wg_session *session;
+ struct wg_keypair kp;
+ struct wg_handshake hs;
+ struct wg_session *session = NULL;
enum wg_error ret = WG_OK;
if ((peer = fm_lookup(&dev->d_peers, id)) == NULL)
- return WG_ID;
+ ret_error(WG_ID);
- /* TODO better locking */
+ /* TODO do we care about locking these? */
if (!wg_timespec_timedout(&peer->p_last_initiation, WG_REKEY_TIMEOUT))
- ret_error(WG_HS_RATE);
+ ret_error_peer(WG_HS_RATE);
if (peer->p_attempts >= WG_REKEY_ATTEMPT_COUNT)
- ret_error(WG_HS_ATTEMPTS);
+ ret_error_peer(WG_HS_ATTEMPTS);
+ /* We need to generate the session here first, so we can use s_local_id
+ * below. We also want to operate on a local handshake, so we don't
+ * have to lock the session. */
session = wg_device_new_session(dev);
- hs = &session->s_handshake;
- rw_enter_read(&dev->d_lock);
- rw_enter_read(&peer->p_lock);
-
- wg_keypair_generate(&hs->h_local);
+ wg_device_getkey(dev, &kp, 1);
+ wg_keypair_generate(&hs.h_local);
+ /* Noise handshake */
init->type = WG_MSG_INITIATION;
init->sender = session->s_local_id;
- memcpy(init->ephemeral, hs->h_local.pub.k, WG_KEY_SIZE);
+ memcpy(init->ephemeral, hs.h_local.pub.k, WG_KEY_SIZE);
- wg_hash2(hs->h_ck, WG_CONSTRUCTION, strlen(WG_CONSTRUCTION), NULL, 0);
- memcpy(hs->h_hash, hs->h_ck, WG_HASH_SIZE);
- wg_mix_hash(hs, WG_IDENTIFIER, strlen(WG_IDENTIFIER));
+ wg_hash2(hs.h_ck, WG_CONSTRUCTION, strlen(WG_CONSTRUCTION), NULL, 0);
+ memcpy(hs.h_hash, hs.h_ck, WG_HASH_SIZE);
+ wg_mix_hash(&hs, WG_IDENTIFIER, strlen(WG_IDENTIFIER));
- wg_mix_hash(hs, peer->p_remote.k, WG_KEY_SIZE);
- wg_kdf(hs->h_ck, NULL, NULL, hs->h_ck, hs->h_local.pub.k, WG_KEY_SIZE);
- wg_mix_hash(hs, hs->h_local.pub.k, WG_KEY_SIZE);
- wg_mix_dh(hs, hs->h_local.priv.k, peer->p_remote.k);
+ wg_mix_hash(&hs, peer->p_remote.k, WG_KEY_SIZE);
+ wg_kdf(hs.h_ck, NULL, NULL, hs.h_ck, hs.h_local.pub.k, WG_KEY_SIZE);
+ wg_mix_hash(&hs, hs.h_local.pub.k, WG_KEY_SIZE);
+ wg_mix_dh(&hs, hs.h_local.priv.k, peer->p_remote.k);
- wg_handshake_encrypt(hs, init->static_pub, dev->d_keypair.pub.k,
+ wg_handshake_encrypt(&hs, init->static_pub, kp.pub.k,
WG_KEY_SIZE);
- wg_mix_hash(hs, init->static_pub, WG_ENCRYPTED_SIZE(WG_KEY_SIZE));
- wg_mix_dh(hs, dev->d_keypair.priv.k, peer->p_remote.k);
+ wg_mix_hash(&hs, init->static_pub, WG_ENCRYPTED_SIZE(WG_KEY_SIZE));
+ wg_mix_dh(&hs, kp.priv.k, peer->p_remote.k);
wg_timestamp_get(init->timestamp);
- wg_handshake_encrypt(hs, init->timestamp, init->timestamp, WG_TIMESTAMP_SIZE);
- wg_mix_hash(hs, init->timestamp, WG_ENCRYPTED_SIZE(WG_TIMESTAMP_SIZE));
- wg_hash2(hs->h_k, WG_MAC1, strlen(WG_MAC1), peer->p_remote.k, WG_KEY_SIZE);
+ wg_handshake_encrypt(&hs, init->timestamp, init->timestamp, WG_TIMESTAMP_SIZE);
+ wg_mix_hash(&hs, init->timestamp, WG_ENCRYPTED_SIZE(WG_TIMESTAMP_SIZE));
+ wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), peer->p_remote.k, WG_KEY_SIZE);
- blake2s(init->mac1, (void *) init, hs->h_k, sizeof(init->mac1),
- offsetof(struct wg_msg_initiation, mac1), sizeof(hs->h_k));
- memcpy(hs->h_mac, init->mac1, sizeof(hs->h_mac));
+ blake2s(init->mac1, (void *) init, hs.h_k, sizeof(init->mac1),
+ offsetof(struct wg_msg_initiation, mac1), sizeof(hs.h_k));
+ memcpy(hs.h_mac, init->mac1, sizeof(hs.h_mac));
+ /* TODO lock for cookie time? */
if (wg_timespec_timedout(&peer->p_cookie.time, WG_COOKIE_VALID_TIME))
bzero(init->mac2, WG_MAC_SIZE);
else
@@ -641,22 +664,21 @@ wg_device_tx_initiation(struct wg_device *dev, struct wg_msg_initiation *init,
sizeof(init->mac2), offsetof(struct wg_msg_initiation, mac2),
sizeof(peer->p_cookie.cookie));
- rw_exit_read(&peer->p_lock);
- rw_exit_read(&dev->d_lock);
-
- /* TODO lock? */
+ /* Update peer */
+ mtx_enter(&peer->p_mtx);
getnanotime(&peer->p_last_initiation);
peer->p_attempts++;
+ mtx_leave(&peer->p_mtx);
/* Attach session to peer */
- session->s_state = WG_STATE_MADE_INITIATION;
- wg_peer_attach_session(peer, session);
- *s = session;
+ wg_peer_attach_session(peer, session, &hs, WG_STATE_MADE_INITIATION);
+ *s = session;
+leave_peer:
wg_peer_put(peer);
- return WG_OK;
leave:
- wg_peer_put(peer);
+ explicit_bzero(&kp, sizeof(kp));
+ explicit_bzero(&hs, sizeof(hs));
return ret;
}
@@ -666,43 +688,41 @@ wg_device_tx_response(struct wg_device *dev, struct wg_msg_response *resp,
{
enum wg_error ret = WG_OK;
- struct wg_handshake *hs;
+ struct wg_handshake hs;
struct wg_session *session;
if ((session = fm_lookup(&dev->d_sessions, id)) == NULL)
- return WG_ID;
-
- rw_enter_read(&session->s_peer->p_lock);
- rw_enter_write(&session->s_lock);
-
- if (session->s_state != WG_STATE_RECV_INITIATION)
- ret_error(WG_STATE);
+ ret_error(WG_ID);
resp->type = WG_MSG_RESPONSE;
resp->sender = session->s_local_id;
resp->receiver = session->s_remote_id;
- hs = &session->s_handshake;
+ mtx_enter(&session->s_mtx);
+ hs = session->s_handshake;
+ mtx_leave(&session->s_mtx);
- wg_kdf(hs->h_ck, NULL, NULL, hs->h_ck, hs->h_local.pub.k, WG_KEY_SIZE);
- wg_mix_hash(hs, hs->h_local.pub.k, WG_KEY_SIZE);
+ /* Noise handshake */
+ wg_kdf(hs.h_ck, NULL, NULL, hs.h_ck, hs.h_local.pub.k, WG_KEY_SIZE);
+ wg_mix_hash(&hs, hs.h_local.pub.k, WG_KEY_SIZE);
- memcpy(resp->ephemeral, hs->h_local.pub.k, WG_KEY_SIZE);
+ memcpy(resp->ephemeral, hs.h_local.pub.k, WG_KEY_SIZE);
- wg_mix_dh(hs, hs->h_local.priv.k, hs->h_remote.k);
- wg_mix_dh(hs, hs->h_local.priv.k, session->s_peer->p_remote.k);
+ wg_mix_dh(&hs, hs.h_local.priv.k, hs.h_remote.k);
+ wg_mix_dh(&hs, hs.h_local.priv.k, session->s_peer->p_remote.k);
- wg_mix_psk(hs, session->s_peer->p_shared.k);
+ wg_mix_psk(&hs, session->s_peer->p_shared.k);
- wg_handshake_encrypt(hs, resp->empty, NULL, 0);
+ wg_handshake_encrypt(&hs, resp->empty, NULL, 0);
- wg_mix_hash(hs, resp->empty, WG_ENCRYPTED_SIZE(0));
+ wg_mix_hash(&hs, resp->empty, WG_ENCRYPTED_SIZE(0));
- wg_hash2(hs->h_k, WG_MAC1, strlen(WG_MAC1), session->s_peer->p_remote.k, WG_KEY_SIZE);
- blake2s(resp->mac1, (void *)resp, hs->h_k, sizeof(resp->mac1),
- offsetof(struct wg_msg_response, mac1), sizeof(hs->h_k));
- memcpy(hs->h_mac, resp->mac1, sizeof(hs->h_mac));
+ wg_hash2(hs.h_k, WG_MAC1, strlen(WG_MAC1), session->s_peer->p_remote.k, WG_KEY_SIZE);
+ blake2s(resp->mac1, (void *)resp, hs.h_k, sizeof(resp->mac1),
+ offsetof(struct wg_msg_response, mac1), sizeof(hs.h_k));
+ memcpy(hs.h_mac, resp->mac1, sizeof(hs.h_mac));
+ /* TODO lock for cookie time? */
if (wg_timespec_timedout(&session->s_peer->p_cookie.time, WG_COOKIE_VALID_TIME))
bzero(resp->mac2, WG_MAC_SIZE);
else
@@ -710,15 +730,16 @@ wg_device_tx_response(struct wg_device *dev, struct wg_msg_response *resp,
sizeof(resp->mac2), offsetof(struct wg_msg_response, mac2),
sizeof(session->s_peer->p_cookie.cookie));
- session->s_state = WG_STATE_MADE_RESPONSE;
+ /* Update session */
+ mtx_enter(&session->s_mtx);
+ if (session->s_state == WG_STATE_RECV_INITIATION)
+ session->s_state = WG_STATE_MADE_RESPONSE;
+ else
+ ret = WG_STATE;
+ mtx_leave(&session->s_mtx);
+
*s = session;
- rw_exit_write(&session->s_lock);
- rw_exit_read(&session->s_peer->p_lock);
- return ret;
leave:
- rw_exit_write(&session->s_lock);
- rw_exit_read(&session->s_peer->p_lock);
- wg_session_put(session);
return ret;
}
@@ -737,10 +758,9 @@ wg_device_tx_transport(struct wg_device *dev, struct wg_msg_transport *msg,
struct wg_session *session;
if ((session = fm_lookup(&dev->d_sessions, id)) == NULL)
- return WG_ID;
-
- rw_enter_read(&session->s_lock);
+ ret_error(WG_ID);
+ /* TODO we should do some locking */
if (session->s_state != WG_STATE_INITIATOR &&
session->s_state != WG_STATE_RESPONDER)
ret_error(WG_STATE);
@@ -763,10 +783,8 @@ wg_device_tx_transport(struct wg_device *dev, struct wg_msg_transport *msg,
dev->d_outq(session->s_peer, WG_PKT_INITIATION, session->s_peer->p_id);
session->s_peer->p_tx_bytes += len;
-
*s = session;
leave:
- rw_exit_read(&session->s_lock);
return ret;
}
diff --git a/src/wireguard.h b/src/wireguard.h
index ede3403..80b35ac 100644
--- a/src/wireguard.h
+++ b/src/wireguard.h
@@ -20,7 +20,7 @@
#include <sys/types.h>
#include <sys/time.h>
#include <sys/timeout.h>
-#include <sys/rwlock.h>
+#include <sys/mutex.h>
#include <sys/fixedmap.h>
#include <sys/antireplay.h>
@@ -144,14 +144,14 @@ struct wg_timers {
};
struct wg_session {
- uint32_t s_local_id; /* Static */
- uint32_t s_remote_id; /* Static */
- struct wg_peer *s_peer; /* Static */
- struct timespec s_created; /* Static */
- struct refcnt s_refcnt; /* Atomic */
-
- /* All protected by s_lock */
- struct rwlock s_lock;
+ /* Static */
+ uint32_t s_local_id;
+ uint32_t s_remote_id;
+ struct wg_peer *s_peer;
+ struct timespec s_created;
+
+ /* All protected by s_mtx */
+ struct mutex s_mtx;
enum wg_state s_state;
struct wg_handshake {
@@ -179,14 +179,16 @@ struct wg_peer {
struct wg_pubkey p_remote;
struct refcnt p_refcnt;
- /* All protected by p_lock */
- struct rwlock p_lock;
- struct wg_cookie p_cookie;
+ /* Atomic */
struct wg_timers p_timers;
+
+ /* All protected by p_mtx */
+ struct mutex p_mtx;
+ struct wg_privkey p_shared;
+ struct wg_cookie p_cookie;
struct timespec p_last_initiation;
uint64_t p_tx_bytes;
uint64_t p_rx_bytes;
- struct wg_privkey p_shared;
uint8_t p_attempts;
struct wg_session *p_hs_session;
struct wg_session *p_ks_session;
@@ -195,15 +197,20 @@ struct wg_peer {
};
struct wg_device {
+ /* Static */
void *d_arg;
void (*d_cleanup)(struct wg_peer *);
void (*d_notify)(struct wg_peer *);
void (*d_outq)(struct wg_peer *, enum wg_pkt_type, uint32_t);
- struct fixed_map d_peers;
- struct fixed_map d_sessions;
- struct rwlock d_lock;
+
+ /* Mutex */
+ struct mutex d_mtx;
struct wg_cookie_maker d_cookie_maker;
struct wg_keypair d_keypair;
+
+ /* Atomic */
+ struct fixed_map d_peers;
+ struct fixed_map d_sessions;
};
enum wg_error {
@@ -242,12 +249,12 @@ static char *wg_error_str[] = {
/* WireGuard functions */
-void wg_device_init(struct wg_device *, int,
- void (*)(struct wg_peer *),
- void (*)(struct wg_peer *, enum wg_pkt_type, uint32_t),
- void (*)(struct wg_peer *), void *);
-void wg_device_setkey(struct wg_device *, struct wg_privkey *);
-void wg_device_destroy(struct wg_device *);
+void wg_device_init(struct wg_device *, int,
+ void (*)(struct wg_peer *),
+ void (*)(struct wg_peer *, enum wg_pkt_type, uint32_t),
+ void (*)(struct wg_peer *), void *);
+void wg_device_setkey(struct wg_device *, struct wg_privkey *);
+void wg_device_destroy(struct wg_device *);
struct wg_peer *wg_device_new_peer(struct wg_device *, struct wg_pubkey *, void *);
struct wg_peer *wg_device_ref_peerkey(struct wg_device *, struct wg_pubkey *);