diff options
Diffstat (limited to 'src/wg_noise.c')
-rw-r--r-- | src/wg_noise.c | 1415 |
1 files changed, 889 insertions, 526 deletions
diff --git a/src/wg_noise.c b/src/wg_noise.c index b5bd5c5..5d92750 100644 --- a/src/wg_noise.c +++ b/src/wg_noise.c @@ -9,147 +9,474 @@ #include <sys/lock.h> #include <sys/rwlock.h> #include <sys/systm.h> +#include <sys/malloc.h> +#include <sys/refcount.h> +#include <sys/epoch.h> +#include <sys/ck.h> -#include "support.h" +#include "crypto.h" #include "wg_noise.h" +#include "support.h" -/* Private functions */ -static struct noise_keypair * - noise_remote_keypair_allocate(struct noise_remote *); -static void - noise_remote_keypair_free(struct noise_remote *, - struct noise_keypair *); -static uint32_t noise_remote_handshake_index_get(struct noise_remote *); -static void noise_remote_handshake_index_drop(struct noise_remote *); - -static uint64_t noise_counter_send(struct noise_counter *); -static int noise_counter_recv(struct noise_counter *, uint64_t); +/* Protocol string constants */ +#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" +#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com" + +/* Constants for the counter */ +#define COUNTER_BITS_TOTAL 8192 +#define COUNTER_BITS (sizeof(unsigned long) * 8) +#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS) +#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS) + +/* Constants for the keypair */ +#define REKEY_AFTER_MESSAGES (1ull << 60) +#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1) +#define REKEY_AFTER_TIME 120 +#define REKEY_AFTER_TIME_RECV 165 +#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */ +/* 24 = floor(log2(REJECT_INTERVAL)) */ +#define REJECT_INTERVAL_MASK (~((1ull<<24)-1)) +#define TIMER_RESET (struct timespec){ -(REKEY_TIMEOUT+1), 0 } + +#define HT_INDEX_SIZE (1 << 13) +#define HT_INDEX_MASK (HT_INDEX_SIZE - 1) +#define HT_REMOTE_SIZE (1 << 11) +#define HT_REMOTE_MASK (HT_REMOTE_SIZE - 1) +#define MAX_REMOTE_PER_LOCAL (1 << 20) + +struct noise_index { + CK_LIST_ENTRY(noise_index) i_entry; + uint32_t i_local_index; + uint32_t i_remote_index; + int i_is_keypair; +}; + +struct noise_keypair { + struct noise_index kp_index; + u_int kp_refcnt; + int kp_can_send; + int kp_is_initiator; + struct timespec kp_birthdate; /* nanouptime */ + struct noise_remote *kp_remote; + + uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN]; + + /* Counter elements */ + struct rwlock kp_nonce_lock; + uint64_t kp_nonce_send; + uint64_t kp_nonce_recv; + unsigned long kp_backtrack[COUNTER_NUM]; + + struct epoch_context kp_smr; +}; + +struct noise_handshake { + uint8_t hs_e[NOISE_PUBLIC_KEY_LEN]; + uint8_t hs_hash[NOISE_HASH_LEN]; + uint8_t hs_ck[NOISE_HASH_LEN]; +}; + +struct noise_remote { + struct noise_index r_index; + + CK_LIST_ENTRY(noise_remote) r_entry; + uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; + + struct rwlock r_handshake_lock; + struct noise_handshake r_handshake; + int r_handshake_alive; + int r_handshake_initiator; + struct timespec r_last_sent; /* nanouptime */ + struct timespec r_last_init_recv; /* nanouptime */ + uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; + uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; + + u_int r_refcnt; + struct noise_local *r_local; + void *r_arg; + + struct rwlock r_keypair_lock; + struct noise_keypair *r_next, *r_current, *r_previous; + + struct epoch_context r_smr; + void (*r_cleanup)(struct noise_remote *); +}; + +struct noise_local { + struct rwlock l_identity_lock; + int l_has_identity; + uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; + + u_int l_refcnt; + SIPHASH_KEY l_hash_key; + void *l_arg; + void (*l_cleanup)(struct noise_local *); + + struct rwlock l_remote_lock; + size_t l_remote_num; + CK_LIST_HEAD(,noise_remote) l_remote_hash[HT_REMOTE_SIZE]; + + struct rwlock l_index_lock; + CK_LIST_HEAD(,noise_index) l_index_hash[HT_INDEX_SIZE]; +}; + +static void noise_precompute_ss(struct noise_local *, struct noise_remote *); + +static void noise_remote_index_insert(struct noise_local *, struct noise_remote *); +static int noise_remote_index_remove(struct noise_local *, struct noise_remote *); +static void noise_remote_expire_current(struct noise_remote *); + + +static void noise_add_new_keypair(struct noise_local *, struct noise_remote *, struct noise_keypair *); +static int noise_received_with(struct noise_keypair *); +static int noise_begin_session(struct noise_remote *); +static void noise_keypair_drop(struct noise_keypair *); static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *, - size_t, size_t, size_t, size_t, - const uint8_t [NOISE_HASH_LEN]); -static int noise_mix_dh( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN]); -static int noise_mix_ss( - uint8_t ck[NOISE_HASH_LEN], - uint8_t key[NOISE_SYMMETRIC_KEY_LEN], - const uint8_t ss[NOISE_PUBLIC_KEY_LEN]); -static void noise_mix_hash( - uint8_t [NOISE_HASH_LEN], - const uint8_t *, - size_t); -static void noise_mix_psk( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); -static void noise_param_init( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - const uint8_t [NOISE_PUBLIC_KEY_LEN]); - + size_t, size_t, size_t, size_t, + const uint8_t [NOISE_HASH_LEN]); +static int noise_mix_dh(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static int noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static void noise_mix_hash(uint8_t [NOISE_HASH_LEN], const uint8_t *, size_t); +static void noise_mix_psk(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_SYMMETRIC_KEY_LEN], const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); +static void noise_param_init(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t, - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - uint8_t [NOISE_HASH_LEN]); + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t, - uint8_t [NOISE_SYMMETRIC_KEY_LEN], - uint8_t [NOISE_HASH_LEN]); -static void noise_msg_ephemeral( - uint8_t [NOISE_HASH_LEN], - uint8_t [NOISE_HASH_LEN], - const uint8_t src[NOISE_PUBLIC_KEY_LEN]); - + uint8_t [NOISE_SYMMETRIC_KEY_LEN], uint8_t [NOISE_HASH_LEN]); +static void noise_msg_ephemeral(uint8_t [NOISE_HASH_LEN], uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]); static int noise_timer_expired(struct timespec *, time_t, long); -/* Set/Get noise parameters */ -void -noise_local_init(struct noise_local *l, struct noise_upcall *upcall) +/* I can't find where FreeBSD defines such behaviours, so that is temporarily here. */ +#define epoch_ptr_read(p) ck_pr_load_ptr(p) +#define epoch_ptr_write(p, v) ck_pr_store_ptr(p, v) +/* Back to regular programming... */ + +MALLOC_DEFINE(M_NOISE, "NOISE", "wgnoise"); + +/* Local configuration */ +struct noise_local * +noise_local_alloc(void *arg) { - bzero(l, sizeof(*l)); - rw_init(&l->l_identity_lock, "noise_local_identity"); - l->l_upcall = *upcall; + struct noise_local *l; + size_t i; + + if ((l = malloc(sizeof(*l), M_NOISE, M_NOWAIT)) == NULL) + return (NULL); + + rw_init(&l->l_identity_lock, "noise_identity"); + l->l_has_identity = 0; + bzero(l->l_public, NOISE_PUBLIC_KEY_LEN); + bzero(l->l_private, NOISE_PUBLIC_KEY_LEN); + + refcount_init(&l->l_refcnt, 1); + arc4random_buf(&l->l_hash_key, sizeof(l->l_hash_key)); + l->l_arg = arg; + l->l_cleanup = NULL; + + rw_init(&l->l_remote_lock, "noise_remote"); + l->l_remote_num = 0; + for (i = 0; i < HT_REMOTE_SIZE; i++) + CK_LIST_INIT(&l->l_remote_hash[i]); + + rw_init(&l->l_index_lock, "noise_index"); + for (i = 0; i < HT_INDEX_SIZE; i++) + CK_LIST_INIT(&l->l_index_hash[i]); + + return (l); +} + +struct noise_local * +noise_local_ref(struct noise_local *l) +{ + refcount_acquire(&l->l_refcnt); + return (l); } void -noise_local_lock_identity(struct noise_local *l) +noise_local_put(struct noise_local *l) { - rw_enter_write(&l->l_identity_lock); + if (refcount_release(&l->l_refcnt)) { + if (l->l_cleanup != NULL) + l->l_cleanup(l); + explicit_bzero(l, sizeof(*l)); + free(l, M_NOISE); + } } void -noise_local_unlock_identity(struct noise_local *l) +noise_local_free(struct noise_local *l, void (*cleanup)(struct noise_local *)) { - rw_exit_write(&l->l_identity_lock); + l->l_cleanup = cleanup; + noise_local_put(l); } -int -noise_local_set_private(struct noise_local *l, - const uint8_t private[NOISE_PUBLIC_KEY_LEN]) +void * +noise_local_arg(struct noise_local *l) { - rw_assert_wrlock(&l->l_identity_lock); + return (l->l_arg); +} +void +noise_local_private(struct noise_local *l, const uint8_t private[NOISE_PUBLIC_KEY_LEN]) +{ + struct epoch_tracker et; + struct noise_remote *r; + size_t i; + + rw_wlock(&l->l_identity_lock); memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN); curve25519_clamp_secret(l->l_private); - l->l_has_identity = curve25519_generate_public(l->l_public, private); + l->l_has_identity = curve25519_generate_public(l->l_public, l->l_private); - return l->l_has_identity ? 0 : ENXIO; + NET_EPOCH_ENTER(et); + for (i = 0; i < HT_REMOTE_SIZE; i++) { + CK_LIST_FOREACH(r, &l->l_remote_hash[i], r_entry) { + noise_precompute_ss(l, r); + noise_remote_expire_current(r); + } + } + NET_EPOCH_EXIT(et); + rw_wunlock(&l->l_identity_lock); } int noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN], uint8_t private[NOISE_PUBLIC_KEY_LEN]) { - int ret = 0; - rw_enter_read(&l->l_identity_lock); - if (l->l_has_identity) { + int has_identity; + rw_rlock(&l->l_identity_lock); + if ((has_identity = l->l_has_identity)) { if (public != NULL) memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN); if (private != NULL) memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN); - } else { - ret = ENXIO; } - rw_exit_read(&l->l_identity_lock); - return ret; + rw_runlock(&l->l_identity_lock); + return (has_identity ? 0 : ENXIO); } -void -noise_remote_init(struct noise_remote *r, - const uint8_t public[NOISE_PUBLIC_KEY_LEN], struct noise_local *l) +static void +noise_precompute_ss(struct noise_local *l, struct noise_remote *r) +{ + rw_wlock(&r->r_handshake_lock); + if (!l->l_has_identity || + !curve25519(r->r_ss, l->l_private, r->r_public)) + bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + rw_wunlock(&r->r_handshake_lock); +} + +/* Remote configuration */ +struct noise_remote * +noise_remote_alloc(struct noise_local *l, void *arg, + const uint8_t public[NOISE_PUBLIC_KEY_LEN], + const uint8_t psk[NOISE_PUBLIC_KEY_LEN]) { - bzero(r, sizeof(*r)); + struct noise_remote *r, *ri; + uint64_t idx; + + if ((r = malloc(sizeof(*r), M_NOISE, M_NOWAIT)) == NULL) + return (NULL); + + r->r_index.i_is_keypair = 0; + memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN); + rw_init(&r->r_handshake_lock, "noise_handshake"); + bzero(&r->r_handshake, sizeof(r->r_handshake)); + r->r_handshake_alive = 0; + r->r_handshake_initiator = 0; + r->r_last_sent = TIMER_RESET; + r->r_last_init_recv = TIMER_RESET; + bzero(r->r_timestamp, NOISE_TIMESTAMP_LEN); + noise_remote_set_psk(r, psk); + noise_precompute_ss(l, r); + + refcount_init(&r->r_refcnt, 1); + r->r_local = noise_local_ref(l); + r->r_arg = arg; + rw_init(&r->r_keypair_lock, "noise_keypair"); + r->r_next = r->r_current = r->r_previous = NULL; - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[0], kp_entry); - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[1], kp_entry); - SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[2], kp_entry); + bzero(&r->r_smr, sizeof(r->r_smr)); - KASSERT(l != NULL, ("must provide local")); - r->r_local = l; + /* Insert to hashtable */ + idx = siphash24(&l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; - rw_enter_write(&l->l_identity_lock); - noise_remote_precompute(r); - rw_exit_write(&l->l_identity_lock); + rw_wlock(&l->l_remote_lock); + CK_LIST_FOREACH(ri, &l->l_remote_hash[idx], r_entry) + if (timingsafe_bcmp(ri->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) + goto free; + if (l->l_remote_num < MAX_REMOTE_PER_LOCAL) { + l->l_remote_num++; + CK_LIST_INSERT_HEAD(&l->l_remote_hash[idx], r, r_entry); + } else { +free: + free(r, M_NOISE); + noise_local_put(l); + r = NULL; + } + rw_wunlock(&l->l_remote_lock); + + return (r); } -int +struct noise_remote * +noise_remote_lookup(struct noise_local *l, const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + struct epoch_tracker et; + struct noise_remote *r, *ret = NULL; + uint64_t idx; + + idx = siphash24(&l->l_hash_key, public, NOISE_PUBLIC_KEY_LEN) & HT_REMOTE_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(r, &l->l_remote_hash[idx], r_entry) { + if (timingsafe_bcmp(r->r_public, public, NOISE_PUBLIC_KEY_LEN) == 0) { + if (refcount_acquire_if_not_zero(&r->r_refcnt)) + ret = r; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +static void +noise_remote_index_insert(struct noise_local *l, struct noise_remote *r) +{ + struct noise_index *i, *r_i = &r->r_index; + uint32_t idx; + + noise_remote_index_remove(l, r); + + rw_wlock(&l->l_index_lock); +assign_id: + r_i->i_local_index = arc4random(); + idx = r_i->i_local_index & HT_INDEX_MASK; + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) + if (i->i_local_index == r_i->i_local_index) + goto assign_id; + + CK_LIST_INSERT_HEAD(&l->l_index_hash[idx], r_i, i_entry); + rw_wunlock(&l->l_index_lock); + + r->r_handshake_alive = 1; +} + +struct noise_remote * +noise_remote_index_lookup(struct noise_local *l, uint32_t idx0) +{ + struct epoch_tracker et; + struct noise_index *i; + struct noise_remote *r, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0 && !i->i_is_keypair) { + r = (struct noise_remote *) i; + if (refcount_acquire_if_not_zero(&r->r_refcnt)) + ret = r; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +static int +noise_remote_index_remove(struct noise_local *l, struct noise_remote *r) +{ + rw_assert_wrlock(&r->r_handshake_lock); + if (r->r_handshake_alive) { + rw_wlock(&l->l_index_lock); + CK_LIST_REMOVE(&r->r_index, i_entry); + rw_wunlock(&l->l_index_lock); + r->r_handshake_alive = 0; + return (1); + } + return (0); +} + +struct noise_remote * +noise_remote_ref(struct noise_remote *r) +{ + refcount_acquire(&r->r_refcnt); + return (r); +} + +static void +noise_remote_smr_free(struct epoch_context *smr) +{ + struct noise_remote *r; + r = __containerof(smr, struct noise_remote, r_smr); + if (r->r_cleanup != NULL) + r->r_cleanup(r); + noise_local_put(r->r_local); + explicit_bzero(r, sizeof(*r)); + free(r, M_NOISE); +} + +void +noise_remote_put(struct noise_remote *r) +{ + if (refcount_release(&r->r_refcnt)) + NET_EPOCH_CALL(noise_remote_smr_free, &r->r_smr); +} + +void +noise_remote_free(struct noise_remote *r, void (*cleanup)(struct noise_remote *)) +{ + struct noise_local *l = r->r_local; + + r->r_cleanup = cleanup; + + /* remove from hashtable */ + rw_wlock(&l->l_remote_lock); + CK_LIST_REMOVE(r, r_entry); + l->l_remote_num--; + rw_wunlock(&l->l_remote_lock); + + /* now clear all keypairs and handshakes, then put this reference */ + noise_remote_handshake_clear(r); + noise_remote_keypairs_clear(r); + noise_remote_put(r); +} + +struct noise_local * +noise_remote_local(struct noise_remote *r) +{ + return (noise_local_ref(r->r_local)); +} + +void * +noise_remote_arg(struct noise_remote *r) +{ + return (r->r_arg); +} + +void noise_remote_set_psk(struct noise_remote *r, const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) { - int same; - rw_enter_write(&r->r_handshake_lock); - same = !timingsafe_bcmp(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); - if (!same) { + rw_wlock(&r->r_handshake_lock); + if (psk == NULL) + bzero(r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + else memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); - } - rw_exit_write(&r->r_handshake_lock); - return same ? EEXIST : 0; + rw_wunlock(&r->r_handshake_lock); } int @@ -162,35 +489,406 @@ noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN], if (public != NULL) memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN); - rw_enter_read(&r->r_handshake_lock); + rw_rlock(&r->r_handshake_lock); if (psk != NULL) memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN); - rw_exit_read(&r->r_handshake_lock); + rw_runlock(&r->r_handshake_lock); - /* If r_psk != null_psk return 0, else ENOENT (no psk) */ - return ret ? 0 : ENOENT; + return (ret ? 0 : ENOENT); +} + +int +noise_remote_initiation_expired(struct noise_remote *r) +{ + int expired; + rw_rlock(&r->r_handshake_lock); + expired = noise_timer_expired(&r->r_last_sent, REKEY_TIMEOUT, 0); + rw_runlock(&r->r_handshake_lock); + return (expired); } void -noise_remote_precompute(struct noise_remote *r) +noise_remote_handshake_clear(struct noise_remote *r) { - struct noise_local *l = r->r_local; - rw_assert_wrlock(&l->l_identity_lock); - if (!l->l_has_identity) - bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); - else if (!curve25519(r->r_ss, l->l_private, r->r_public)) - bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + rw_wlock(&r->r_handshake_lock); + if (noise_remote_index_remove(r->r_local, r)) + bzero(&r->r_handshake, sizeof(r->r_handshake)); + r->r_last_sent = TIMER_RESET; + rw_wunlock(&r->r_handshake_lock); +} + +void +noise_remote_keypairs_clear(struct noise_remote *r) +{ + struct noise_keypair *kp; + + rw_wlock(&r->r_keypair_lock); + kp = epoch_ptr_read(&r->r_next); + epoch_ptr_write(&r->r_next, NULL); + noise_keypair_drop(kp); + + kp = epoch_ptr_read(&r->r_current); + epoch_ptr_write(&r->r_current, NULL); + noise_keypair_drop(kp); + + kp = epoch_ptr_read(&r->r_previous); + epoch_ptr_write(&r->r_previous, NULL); + noise_keypair_drop(kp); + rw_wunlock(&r->r_keypair_lock); +} + +static void +noise_remote_expire_current(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *kp; + + noise_remote_handshake_clear(r); + + NET_EPOCH_ENTER(et); + kp = epoch_ptr_read(&r->r_next); + if (kp != NULL) WRITE_ONCE(kp->kp_can_send, 0); + kp = epoch_ptr_read(&r->r_current); + if (kp != NULL) WRITE_ONCE(kp->kp_can_send, 0); + NET_EPOCH_EXIT(et); +} + +/* Keypair functions */ +static void +noise_add_new_keypair(struct noise_local *l, struct noise_remote *r, + struct noise_keypair *kp) +{ + struct noise_keypair *next, *current, *previous; + struct noise_index *r_i = &r->r_index; + + /* Insert into the keypair table */ + rw_wlock(&r->r_keypair_lock); + next = epoch_ptr_read(&r->r_next); + current = epoch_ptr_read(&r->r_current); + previous = epoch_ptr_read(&r->r_previous); + + if (kp->kp_is_initiator) { + if (next != NULL) { + epoch_ptr_write(&r->r_next, NULL); + epoch_ptr_write(&r->r_previous, next); + noise_keypair_drop(current); + } else { + epoch_ptr_write(&r->r_previous, current); + } + noise_keypair_drop(previous); + epoch_ptr_write(&r->r_current, kp); + } else { + epoch_ptr_write(&r->r_next, kp); + noise_keypair_drop(next); + epoch_ptr_write(&r->r_previous, NULL); + noise_keypair_drop(previous); + + } + rw_wunlock(&r->r_keypair_lock); + + /* Insert into index table */ + rw_assert_wrlock(&r->r_handshake_lock); + + kp->kp_index.i_is_keypair = 1; + kp->kp_index.i_local_index = r_i->i_local_index; + kp->kp_index.i_remote_index = r_i->i_remote_index; + + rw_wlock(&l->l_index_lock); + CK_LIST_INSERT_BEFORE(r_i, &kp->kp_index, i_entry); + CK_LIST_REMOVE(r_i, i_entry); + rw_wunlock(&l->l_index_lock); - rw_enter_write(&r->r_handshake_lock); - noise_remote_handshake_index_drop(r); explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); } +static int +noise_received_with(struct noise_keypair *kp) +{ + struct epoch_tracker et; + struct noise_keypair *old; + struct noise_remote *r = kp->kp_remote; + + NET_EPOCH_ENTER(et); + if (kp != epoch_ptr_read(&r->r_next)) { + NET_EPOCH_EXIT(et); + return (0); + } + NET_EPOCH_EXIT(et); + + rw_wlock(&r->r_keypair_lock); + if (kp != epoch_ptr_read(&r->r_next)) { + rw_wunlock(&r->r_keypair_lock); + return (0); + } + + old = epoch_ptr_read(&r->r_previous); + epoch_ptr_write(&r->r_previous, epoch_ptr_read(&r->r_current)); + noise_keypair_drop(old); + epoch_ptr_write(&r->r_current, kp); + epoch_ptr_write(&r->r_next, NULL); + rw_wunlock(&r->r_keypair_lock); + + return (ECONNRESET); +} + +static int +noise_begin_session(struct noise_remote *r) +{ + struct noise_keypair *kp; + + rw_assert_wrlock(&r->r_handshake_lock); + + if ((kp = malloc(sizeof(*kp), M_NOISE, M_NOWAIT)) == NULL) + return (ENOSPC); + + refcount_init(&kp->kp_refcnt, 1); + kp->kp_can_send = 1; + kp->kp_is_initiator = r->r_handshake_initiator; + getnanouptime(&kp->kp_birthdate); + kp->kp_remote = noise_remote_ref(r); + + if (kp->kp_is_initiator) + noise_kdf(kp->kp_send, kp->kp_recv, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + else + noise_kdf(kp->kp_recv, kp->kp_send, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + r->r_handshake.hs_ck); + + rw_init(&kp->kp_nonce_lock, "noise_nonce"); + kp->kp_nonce_send = 0; + kp->kp_nonce_recv = 0; + bzero(kp->kp_backtrack, sizeof(kp->kp_backtrack)); + bzero(&kp->kp_smr, sizeof(kp->kp_smr)); + + noise_add_new_keypair(r->r_local, r, kp); + return (0); +} + +struct noise_keypair * +noise_keypair_lookup(struct noise_local *l, uint32_t idx0) +{ + struct epoch_tracker et; + struct noise_index *i; + struct noise_keypair *kp, *ret = NULL; + uint32_t idx = idx0 & HT_INDEX_MASK; + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) { + if (i->i_local_index == idx0 && i->i_is_keypair) { + kp = (struct noise_keypair *) i; + if (refcount_acquire_if_not_zero(&kp->kp_refcnt)) + ret = kp; + break; + } + } + NET_EPOCH_EXIT(et); + return (ret); +} + +struct noise_keypair * +noise_keypair_current(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *kp, *ret = NULL; + + NET_EPOCH_ENTER(et); + kp = epoch_ptr_read(&r->r_current); + if (kp != NULL && READ_ONCE(kp->kp_can_send)) { + if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + WRITE_ONCE(kp->kp_can_send, 0); + else if (refcount_acquire_if_not_zero(&kp->kp_refcnt)) + ret = kp; + } + NET_EPOCH_EXIT(et); + return (ret); +} + +struct noise_keypair * +noise_keypair_ref(struct noise_keypair *kp) +{ + refcount_acquire(&kp->kp_refcnt); + return (kp); +} + +static void +noise_keypair_smr_free(struct epoch_context *smr) +{ + struct noise_keypair *kp; + kp = __containerof(smr, struct noise_keypair, kp_smr); + noise_remote_put(kp->kp_remote); + explicit_bzero(kp, sizeof(*kp)); + free(kp, M_NOISE); +} + + +void +noise_keypair_put(struct noise_keypair *kp) +{ + if (refcount_release(&kp->kp_refcnt)) + NET_EPOCH_CALL(noise_keypair_smr_free, &kp->kp_smr); +} + +static void +noise_keypair_drop(struct noise_keypair *kp) +{ + struct noise_remote *r; + struct noise_local *l; + + if (kp == NULL) + return; + + r = kp->kp_remote; + l = r->r_local; + + rw_wlock(&l->l_index_lock); + CK_LIST_REMOVE(&kp->kp_index, i_entry); + rw_wunlock(&l->l_index_lock); + + noise_keypair_put(kp); +} + +struct noise_remote * +noise_keypair_remote(struct noise_keypair *kp) +{ + return (noise_remote_ref(kp->kp_remote)); +} + +void * +noise_keypair_remote_arg(struct noise_keypair *kp) +{ + return kp->kp_remote->r_arg; +} + + + +int +noise_keypair_nonce_next(struct noise_keypair *kp, uint64_t *send) +{ +#ifdef __LP64__ + *send = atomic_fetchadd_64(&kp->kp_nonce_send, 1); +#else + rw_wlock(&kp->kp_nonce_lock); + *send = ctr->c_send++; + rw_wunlock(&kp->kp_nonce_lock); +#endif + if (*send < REJECT_AFTER_MESSAGES) + return (0); + WRITE_ONCE(kp->kp_can_send, 0); + return (EINVAL); +} + +int +noise_keypair_nonce_check(struct noise_keypair *kp, uint64_t recv) +{ + uint64_t i, top, index_recv, index_ctr; + unsigned long bit; + int ret = EEXIST; + + rw_wlock(&kp->kp_nonce_lock); + + /* Check that the recv counter is valid */ + if (kp->kp_nonce_recv >= REJECT_AFTER_MESSAGES || + recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* If the packet is out of the window, invalid */ + if (recv + COUNTER_WINDOW_SIZE < kp->kp_nonce_recv) + goto error; + + /* If the new counter is ahead of the current counter, we'll need to + * zero out the bitmap that has previously been used */ + index_recv = recv / COUNTER_BITS; + index_ctr = kp->kp_nonce_recv / COUNTER_BITS; + + if (recv > kp->kp_nonce_recv) { + top = MIN(index_recv - index_ctr, COUNTER_NUM); + for (i = 1; i <= top; i++) + kp->kp_backtrack[ + (i + index_ctr) & (COUNTER_NUM - 1)] = 0; + WRITE_ONCE(kp->kp_nonce_recv, recv); + } + + index_recv %= COUNTER_NUM; + bit = 1ul << (recv % COUNTER_BITS); + + if (kp->kp_backtrack[index_recv] & bit) + goto error; + + kp->kp_backtrack[index_recv] |= bit; + + ret = 0; +error: + rw_wunlock(&kp->kp_nonce_lock); + return (ret); +} + +int +noise_keep_key_fresh_send(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *current; + int keep_key_fresh; + + NET_EPOCH_ENTER(et); + current = epoch_ptr_read(&r->r_current); + keep_key_fresh = current != NULL && READ_ONCE(current->kp_can_send) && ( + READ_ONCE(current->kp_nonce_send) > REKEY_AFTER_MESSAGES || + (current->kp_is_initiator && noise_timer_expired(¤t->kp_birthdate, REKEY_AFTER_TIME, 0))); + NET_EPOCH_EXIT(et); + + return (keep_key_fresh ? ESTALE : 0); +} + +int +noise_keep_key_fresh_recv(struct noise_remote *r) +{ + struct epoch_tracker et; + struct noise_keypair *current; + int keep_key_fresh; + + NET_EPOCH_ENTER(et); + current = epoch_ptr_read(&r->r_current); + keep_key_fresh = current != NULL && READ_ONCE(current->kp_can_send) && + current->kp_is_initiator && noise_timer_expired(¤t->kp_birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT, 0); + NET_EPOCH_EXIT(et); + + return (keep_key_fresh ? ESTALE : 0); +} + +void +noise_keypair_encrypt(struct noise_keypair *kp, uint32_t *r_idx, uint64_t nonce, + uint8_t *buf, size_t buflen) +{ + chacha20poly1305_encrypt(buf, buf, buflen, NULL, 0, nonce, kp->kp_send); + *r_idx = kp->kp_index.i_remote_index; +} + +int +noise_keypair_decrypt(struct noise_keypair *kp, uint64_t nonce, uint8_t *buf, + size_t buflen) +{ + if (READ_ONCE(kp->kp_nonce_recv) >= REJECT_AFTER_MESSAGES || + noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0)) + return (EINVAL); + + if (chacha20poly1305_decrypt(buf, buf, buflen, NULL, 0, nonce, kp->kp_recv) == 0) + return (EINVAL); + + if (noise_received_with(kp) != 0) + return (ECONNRESET); + + return (0); +} + + /* Handshake functions */ int -noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, +noise_create_initiation(struct noise_remote *r, + uint32_t *s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) @@ -200,10 +898,12 @@ noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); - rw_enter_write(&r->r_handshake_lock); + rw_rlock(&l->l_identity_lock); + rw_wlock(&r->r_handshake_lock); if (!l->l_has_identity) goto error; + if (!noise_timer_expired(&r->r_last_sent, REKEY_TIMEOUT, 0)) + goto error; noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public); /* e */ @@ -229,21 +929,22 @@ noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, noise_msg_encrypt(ets, ets, NOISE_TIMESTAMP_LEN, key, hs->hs_hash); - noise_remote_handshake_index_drop(r); - hs->hs_state = CREATED_INITIATION; - hs->hs_local_index = noise_remote_handshake_index_get(r); - *s_idx = hs->hs_local_index; + noise_remote_index_insert(l, r); + getnanouptime(&r->r_last_sent); + *s_idx = r->r_index.i_local_index; + r->r_handshake_initiator = 1; ret = 0; error: - rw_exit_write(&r->r_handshake_lock); - rw_exit_read(&l->l_identity_lock); + rw_wunlock(&r->r_handshake_lock); + rw_runlock(&l->l_identity_lock); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); - return ret; + return (ret); } int noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, - uint32_t s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint32_t s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) { @@ -254,7 +955,7 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, uint8_t timestamp[NOISE_TIMESTAMP_LEN]; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); + rw_rlock(&l->l_identity_lock); if (!l->l_has_identity) goto error; noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public); @@ -272,23 +973,23 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, goto error; /* Lookup the remote we received from */ - if ((r = l->l_upcall.u_remote_get(l->l_upcall.u_arg, r_public)) == NULL) + if ((r = noise_remote_lookup(l, r_public)) == NULL) goto error; /* ss */ if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0) - goto error; + goto error_put; /* {t} */ if (noise_msg_decrypt(timestamp, ets, NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) - goto error; + goto error_put; memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN); /* We have successfully computed the same results, now we ensure that * this is not an initiation replay, or a flood attack */ - rw_enter_write(&r->r_handshake_lock); + rw_wlock(&r->r_handshake_lock); /* Replay */ if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0) @@ -296,41 +997,45 @@ noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, else goto error_set; /* Flood attack */ - if (noise_timer_expired(&r->r_last_init, 0, REJECT_INTERVAL)) - getnanouptime(&r->r_last_init); + if (noise_timer_expired(&r->r_last_init_recv, 0, REJECT_INTERVAL)) + getnanouptime(&r->r_last_init_recv); else goto error_set; /* Ok, we're happy to accept this initiation now */ - noise_remote_handshake_index_drop(r); - hs.hs_state = CONSUMED_INITIATION; - hs.hs_local_index = noise_remote_handshake_index_get(r); - hs.hs_remote_index = s_idx; + noise_remote_index_insert(l, r); + r->r_index.i_remote_index = s_idx; + r->r_handshake_initiator = 0; r->r_handshake = hs; - *rp = r; + *rp = noise_remote_ref(r); ret = 0; error_set: - rw_exit_write(&r->r_handshake_lock); + rw_wunlock(&r->r_handshake_lock); +error_put: + noise_remote_put(r); error: - rw_exit_read(&l->l_identity_lock); + rw_runlock(&l->l_identity_lock); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); explicit_bzero(&hs, sizeof(hs)); - return ret; + return (ret); } int -noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx, - uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]) +noise_create_response(struct noise_remote *r, + uint32_t *s_idx, uint32_t *r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) { struct noise_handshake *hs = &r->r_handshake; + struct noise_local *l = r->r_local; uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; uint8_t e[NOISE_PUBLIC_KEY_LEN]; int ret = EINVAL; - rw_enter_read(&r->r_local->l_identity_lock); - rw_enter_write(&r->r_handshake_lock); + rw_rlock(&l->l_identity_lock); + rw_wlock(&r->r_handshake_lock); - if (hs->hs_state != CONSUMED_INITIATION) + if (!r->r_handshake_alive || r->r_handshake_initiator) goto error; /* e */ @@ -353,51 +1058,57 @@ noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx, /* {} */ noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash); - hs->hs_state = CREATED_RESPONSE; - *r_idx = hs->hs_remote_index; - *s_idx = hs->hs_local_index; - ret = 0; + if ((ret = noise_begin_session(r)) == 0) { + getnanouptime(&r->r_last_sent); + *s_idx = r->r_index.i_local_index; + *r_idx = r->r_index.i_remote_index; + } error: - rw_exit_write(&r->r_handshake_lock); - rw_exit_read(&r->r_local->l_identity_lock); + rw_wunlock(&r->r_handshake_lock); + rw_runlock(&l->l_identity_lock); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); explicit_bzero(e, NOISE_PUBLIC_KEY_LEN); - return ret; + return (ret); } int -noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx, - uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN]) +noise_consume_response(struct noise_local *l, struct noise_remote **rp, + uint32_t s_idx, uint32_t r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]) { - struct noise_local *l = r->r_local; - struct noise_handshake hs; + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN]; uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; - uint8_t preshared_key[NOISE_PUBLIC_KEY_LEN]; + struct noise_handshake hs; + struct noise_remote *r = NULL; int ret = EINVAL; - rw_enter_read(&l->l_identity_lock); + if ((r = noise_remote_index_lookup(l, r_idx)) == NULL) + return (ret); + + rw_rlock(&l->l_identity_lock); if (!l->l_has_identity) goto error; - rw_enter_read(&r->r_handshake_lock); - hs = r->r_handshake; - memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); - rw_exit_read(&r->r_handshake_lock); - - if (hs.hs_state != CREATED_INITIATION || - hs.hs_local_index != r_idx) + rw_rlock(&r->r_handshake_lock); + if (!r->r_handshake_alive || !r->r_handshake_initiator) { + rw_runlock(&r->r_handshake_lock); goto error; + } + memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + hs = r->r_handshake; + rw_runlock(&r->r_handshake_lock); /* e */ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue); /* ee */ if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0) - goto error; + goto error_zero; /* se */ if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0) - goto error; + goto error_zero; /* psk */ noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key); @@ -405,365 +1116,28 @@ noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx, /* {} */ if (noise_msg_decrypt(NULL, en, 0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) - goto error; + goto error_zero; - hs.hs_remote_index = s_idx; - - rw_enter_write(&r->r_handshake_lock); - if (r->r_handshake.hs_state == hs.hs_state && - r->r_handshake.hs_local_index == hs.hs_local_index) { + rw_wlock(&r->r_handshake_lock); + if (r->r_handshake_alive && r->r_handshake_initiator && + r->r_index.i_local_index == r_idx) { r->r_handshake = hs; - r->r_handshake.hs_state = CONSUMED_RESPONSE; - ret = 0; + r->r_index.i_remote_index = s_idx; + ret = noise_begin_session(r); + *rp = noise_remote_ref(r); } - rw_exit_write(&r->r_handshake_lock); -error: - rw_exit_read(&l->l_identity_lock); - explicit_bzero(&hs, sizeof(hs)); + rw_wunlock(&r->r_handshake_lock); +error_zero: + explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN); explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); - return ret; -} - -int -noise_remote_begin_session(struct noise_remote *r) -{ - struct noise_handshake *hs = &r->r_handshake; - struct noise_keypair kp, *next, *current, *previous; - - rw_enter_write(&r->r_handshake_lock); - - /* We now derive the keypair from the handshake */ - if (hs->hs_state == CONSUMED_RESPONSE) { - kp.kp_is_initiator = 1; - noise_kdf(kp.kp_send, kp.kp_recv, NULL, NULL, - NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, - hs->hs_ck); - } else if (hs->hs_state == CREATED_RESPONSE) { - kp.kp_is_initiator = 0; - noise_kdf(kp.kp_recv, kp.kp_send, NULL, NULL, - NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, - hs->hs_ck); - } else { - rw_exit_write(&r->r_handshake_lock); - return EINVAL; - } - - kp.kp_valid = 1; - kp.kp_local_index = hs->hs_local_index; - kp.kp_remote_index = hs->hs_remote_index; - getnanouptime(&kp.kp_birthdate); - bzero(&kp.kp_ctr, sizeof(kp.kp_ctr)); - rw_init(&kp.kp_ctr.c_lock, "noise_counter"); - - /* Now we need to add_new_keypair */ - rw_enter_write(&r->r_keypair_lock); - next = r->r_next; - current = r->r_current; - previous = r->r_previous; - - if (kp.kp_is_initiator) { - if (next != NULL) { - r->r_next = NULL; - r->r_previous = next; - noise_remote_keypair_free(r, current); - } else { - r->r_previous = current; - } - - noise_remote_keypair_free(r, previous); - - r->r_current = noise_remote_keypair_allocate(r); - *r->r_current = kp; - } else { - noise_remote_keypair_free(r, next); - r->r_previous = NULL; - noise_remote_keypair_free(r, previous); - - r->r_next = noise_remote_keypair_allocate(r); - *r->r_next = kp; - } - rw_exit_write(&r->r_keypair_lock); - - explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); - - explicit_bzero(&kp, sizeof(kp)); - return 0; -} - -void -noise_remote_clear(struct noise_remote *r) -{ - rw_enter_write(&r->r_handshake_lock); - noise_remote_handshake_index_drop(r); - explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); - rw_exit_write(&r->r_handshake_lock); - - rw_enter_write(&r->r_keypair_lock); - noise_remote_keypair_free(r, r->r_next); - noise_remote_keypair_free(r, r->r_current); - noise_remote_keypair_free(r, r->r_previous); - r->r_next = NULL; - r->r_current = NULL; - r->r_previous = NULL; - rw_exit_write(&r->r_keypair_lock); -} - -void -noise_remote_expire_current(struct noise_remote *r) -{ - rw_enter_write(&r->r_keypair_lock); - if (r->r_next != NULL) - r->r_next->kp_valid = 0; - if (r->r_current != NULL) - r->r_current->kp_valid = 0; - rw_exit_write(&r->r_keypair_lock); -} - -int -noise_remote_ready(struct noise_remote *r) -{ - struct noise_keypair *kp; - int ret; - - rw_enter_read(&r->r_keypair_lock); - /* kp_ctr isn't locked here, we're happy to accept a racy read. */ - if ((kp = r->r_current) == NULL || - !kp->kp_valid || - noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || - kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES) - ret = EINVAL; - else - ret = 0; - rw_exit_read(&r->r_keypair_lock); - return ret; -} - -int -noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce, - uint8_t *buf, size_t buflen) -{ - struct noise_keypair *kp; - int ret = EINVAL; - - rw_enter_read(&r->r_keypair_lock); - if ((kp = r->r_current) == NULL) - goto error; - - /* We confirm that our values are within our tolerances. We want: - * - a valid keypair - * - our keypair to be less than REJECT_AFTER_TIME seconds old - * - our receive counter to be less than REJECT_AFTER_MESSAGES - * - our send counter to be less than REJECT_AFTER_MESSAGES - * - * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (!kp->kp_valid || - noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || - ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) - goto error; - - /* We encrypt into the same buffer, so the caller must ensure that buf - * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index - * are passed back out to the caller through the provided data pointer. */ - *r_idx = kp->kp_remote_index; - chacha20poly1305_encrypt(buf, buf, buflen, - NULL, 0, *nonce, kp->kp_send); - - /* If our values are still within tolerances, but we are approaching - * the tolerances, we notify the caller with ESTALE that they should - * establish a new keypair. The current keypair can continue to be used - * until the tolerances are hit. We notify if: - * - our send counter is valid and not less than REKEY_AFTER_MESSAGES - * - we're the initiator and our keypair is older than - * REKEY_AFTER_TIME seconds */ - ret = ESTALE; - if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || - (kp->kp_is_initiator && - noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0))) - goto error; - - ret = 0; -error: - rw_exit_read(&r->r_keypair_lock); - return ret; -} - -int -noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce, - uint8_t *buf, size_t buflen) -{ - struct noise_keypair *kp; - int ret = EINVAL; - - /* We retrieve the keypair corresponding to the provided index. We - * attempt the current keypair first as that is most likely. We also - * want to make sure that the keypair is valid as it would be - * catastrophic to decrypt against a zero'ed keypair. */ - rw_enter_read(&r->r_keypair_lock); - - if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) { - kp = r->r_current; - } else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) { - kp = r->r_previous; - } else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) { - kp = r->r_next; - } else { - goto error; - } - - /* We confirm that our values are within our tolerances. These values - * are the same as the encrypt routine. - * - * kp_ctr isn't locked here, we're happy to accept a racy read. */ - if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || - kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* Decrypt, then validate the counter. We don't want to validate the - * counter before decrypting as we do not know the message is authentic - * prior to decryption. */ - if (chacha20poly1305_decrypt(buf, buf, buflen, - NULL, 0, nonce, kp->kp_recv) == 0) - goto error; - - if (noise_counter_recv(&kp->kp_ctr, nonce) != 0) - goto error; - - /* If we've received the handshake confirming data packet then move the - * next keypair into current. If we do slide the next keypair in, then - * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a - * data packet can't confirm a session that we are an INITIATOR of. */ - if (kp == r->r_next) { - rw_exit_read(&r->r_keypair_lock); - rw_enter_write(&r->r_keypair_lock); - if (kp == r->r_next && kp->kp_local_index == r_idx) { - noise_remote_keypair_free(r, r->r_previous); - r->r_previous = r->r_current; - r->r_current = r->r_next; - r->r_next = NULL; - - ret = ECONNRESET; - goto error; - } - rw_enter(&r->r_keypair_lock, RW_DOWNGRADE); - } - - /* Similar to when we encrypt, we want to notify the caller when we - * are approaching our tolerances. We notify if: - * - we're the initiator and the current keypair is older than - * REKEY_AFTER_TIME_RECV seconds. */ - ret = ESTALE; - kp = r->r_current; - if (kp != NULL && - kp->kp_valid && - kp->kp_is_initiator && - noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME_RECV, 0)) - goto error; - - ret = 0; - -error: - rw_exit(&r->r_keypair_lock); - return ret; -} - -/* Private functions - these should not be called outside this file under any - * circumstances. */ -static struct noise_keypair * -noise_remote_keypair_allocate(struct noise_remote *r) -{ - struct noise_keypair *kp; - kp = SLIST_FIRST(&r->r_unused_keypairs); - SLIST_REMOVE_HEAD(&r->r_unused_keypairs, kp_entry); - return kp; -} - -static void -noise_remote_keypair_free(struct noise_remote *r, struct noise_keypair *kp) -{ - struct noise_upcall *u = &r->r_local->l_upcall; - if (kp != NULL) { - SLIST_INSERT_HEAD(&r->r_unused_keypairs, kp, kp_entry); - u->u_index_drop(u->u_arg, kp->kp_local_index); - bzero(kp->kp_send, sizeof(kp->kp_send)); - bzero(kp->kp_recv, sizeof(kp->kp_recv)); - } -} - -static uint32_t -noise_remote_handshake_index_get(struct noise_remote *r) -{ - struct noise_upcall *u = &r->r_local->l_upcall; - return u->u_index_set(u->u_arg, r); -} - -static void -noise_remote_handshake_index_drop(struct noise_remote *r) -{ - struct noise_handshake *hs = &r->r_handshake; - struct noise_upcall *u = &r->r_local->l_upcall; - rw_assert_wrlock(&r->r_handshake_lock); - if (hs->hs_state != HS_ZEROED) - u->u_index_drop(u->u_arg, hs->hs_local_index); -} - -static uint64_t -noise_counter_send(struct noise_counter *ctr) -{ - uint64_t ret; - rw_enter_write(&ctr->c_lock); - ret = ctr->c_send++; - rw_exit_write(&ctr->c_lock); - return ret; -} - -static int -noise_counter_recv(struct noise_counter *ctr, uint64_t recv) -{ - uint64_t i, top, index_recv, index_ctr; - unsigned long bit; - int ret = EEXIST; - - rw_enter_write(&ctr->c_lock); - - /* Check that the recv counter is valid */ - if (ctr->c_recv >= REJECT_AFTER_MESSAGES || - recv >= REJECT_AFTER_MESSAGES) - goto error; - - /* If the packet is out of the window, invalid */ - if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) - goto error; - - /* If the new counter is ahead of the current counter, we'll need to - * zero out the bitmap that has previously been used */ - index_recv = recv / COUNTER_BITS; - index_ctr = ctr->c_recv / COUNTER_BITS; - - if (recv > ctr->c_recv) { - top = MIN(index_recv - index_ctr, COUNTER_NUM); - for (i = 1; i <= top; i++) - ctr->c_backtrack[ - (i + index_ctr) & (COUNTER_NUM - 1)] = 0; - ctr->c_recv = recv; - } - - index_recv %= COUNTER_NUM; - bit = 1ul << (recv % COUNTER_BITS); - - if (ctr->c_backtrack[index_recv] & bit) - goto error; - - ctr->c_backtrack[index_recv] |= bit; - - ret = 0; + explicit_bzero(&hs, sizeof(hs)); error: - rw_exit_write(&ctr->c_lock); - return ret; + rw_runlock(&l->l_identity_lock); + noise_remote_put(r); + return (ret); } +/* Handshake helper functions */ static void noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, size_t a_len, size_t b_len, size_t c_len, size_t x_len, @@ -772,13 +1146,6 @@ noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, uint8_t out[BLAKE2S_HASH_SIZE + 1]; uint8_t sec[BLAKE2S_HASH_SIZE]; -#ifdef DIAGNOSTIC - MPASS(a_len <= BLAKE2S_HASH_SIZE && b_len <= BLAKE2S_HASH_SIZE && - c_len <= BLAKE2S_HASH_SIZE); - MPASS(!(b || b_len || c || c_len) || (a && a_len)); - MPASS(!(c || c_len) || (b && b_len)); -#endif - /* Extract entropy from "x" into sec */ blake2s_hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN); @@ -822,11 +1189,11 @@ noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t dh[NOISE_PUBLIC_KEY_LEN]; if (!curve25519(dh, private, public)) - return EINVAL; + return (EINVAL); noise_kdf(ck, key, NULL, dh, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN); - return 0; + return (0); } static int @@ -835,10 +1202,10 @@ noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], { static uint8_t null_point[NOISE_PUBLIC_KEY_LEN]; if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0) - return ENOENT; + return (ENOENT); noise_kdf(ck, key, NULL, ss, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); - return 0; + return (0); } static void @@ -901,9 +1268,9 @@ noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len, /* Nonce always zero for Noise_IK */ if (!chacha20poly1305_decrypt(dst, src, src_len, hash, NOISE_HASH_LEN, 0, key)) - return EINVAL; + return (EINVAL); noise_mix_hash(hash, src, src_len); - return 0; + return (0); } static void @@ -942,11 +1309,7 @@ noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec) struct timespec uptime; struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec }; - /* We don't really worry about a zeroed birthdate, to avoid the extra - * check on every encrypt/decrypt. This does mean that r_last_init - * check may fail if getnanouptime is < REJECT_INTERVAL from 0. */ - getnanouptime(&uptime); timespecadd(birthdate, &expire, &expire); - return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; + return (timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0); } |