author    Jason A. Donenfeld <Jason@zx2c4.com>  2021-03-17 09:34:21 -0600
committer Jason A. Donenfeld <Jason@zx2c4.com>  2021-03-17 09:35:54 -0600
commit    362884e65029464d97e50c9b660b5b90621e239e (patch)
tree      814c9607aa8ef7d3a4a4e4866071219f959c324d /src
download  wireguard-freebsd-362884e65029464d97e50c9b660b5b90621e239e.tar.xz
          wireguard-freebsd-362884e65029464d97e50c9b660b5b90621e239e.zip
Initial import

There's still more to do with wiring this up properly.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'src')
-rw-r--r--  src/Makefile    |   11
-rw-r--r--  src/crypto.c    | 1694
-rw-r--r--  src/crypto.h    |  103
-rw-r--r--  src/if_wg.c     | 3451
-rw-r--r--  src/if_wg.h     |   37
-rw-r--r--  src/support.h   |   56
-rw-r--r--  src/wg_cookie.c |  427
-rw-r--r--  src/wg_cookie.h |  114
-rw-r--r--  src/wg_noise.c  |  952
-rw-r--r--  src/wg_noise.h  |  180
10 files changed, 7025 insertions(+), 0 deletions(-)
diff --git a/src/Makefile b/src/Makefile
new file mode 100644
index 0000000..d65cf80
--- /dev/null
+++ b/src/Makefile
@@ -0,0 +1,11 @@
+# $FreeBSD$
+
+KMOD= if_wg
+
+.PATH: ${SRCTOP}/sys/dev/if_wg
+
+SRCS= opt_inet.h opt_inet6.h device_if.h bus_if.h ifdi_if.h
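+# (the headers above are generated by the kernel build: the opt_*.h files
+# from configuration options and the *_if.h files from device/bus/iflib
+# interface definitions; they are not checked-in sources)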
+
+SRCS+= if_wg.c wg_noise.c wg_cookie.c crypto.c
+
+.include <bsd.kmod.mk>
diff --git a/src/crypto.c b/src/crypto.c
new file mode 100644
index 0000000..97bef62
--- /dev/null
+++ b/src/crypto.c
@@ -0,0 +1,1694 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include <sys/types.h>
+#include <sys/endian.h>
+#include <sys/systm.h>
+
+#include "crypto.h"
+
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
+#endif
+#ifndef noinline
+#define noinline __attribute__((noinline))
+#endif
+#ifndef __aligned
+#define __aligned(x) __attribute__((aligned(x)))
+#endif
+#ifndef DIV_ROUND_UP
+#define DIV_ROUND_UP(n,d) (((n) + (d) - 1) / (d))
+#endif
+
+#define le32_to_cpup(a) le32toh(*(a))
+#define le64_to_cpup(a) le64toh(*(a))
+#define cpu_to_le32(a) htole32(a)
+#define cpu_to_le64(a) htole64(a)
+
+static inline uint32_t get_unaligned_le32(const uint8_t *a)
+{
+ uint32_t l;
+ __builtin_memcpy(&l, a, sizeof(l));
+ return le32_to_cpup(&l);
+}
+static inline uint64_t get_unaligned_le64(const uint8_t *a)
+{
+ uint64_t l;
+ __builtin_memcpy(&l, a, sizeof(l));
+ return le64_to_cpup(&l);
+}
+static inline void put_unaligned_le32(uint32_t s, uint8_t *d)
+{
+ uint32_t l = cpu_to_le32(s);
+ __builtin_memcpy(d, &l, sizeof(l));
+}
+static inline void cpu_to_le32_array(uint32_t *buf, unsigned int words)
+{
+ while (words--) {
+ *buf = cpu_to_le32(*buf);
+ ++buf;
+ }
+}
+static inline void le32_to_cpu_array(uint32_t *buf, unsigned int words)
+{
+ while (words--) {
+ *buf = le32_to_cpup(buf);
+ ++buf;
+ }
+}
+
+static inline uint32_t rol32(uint32_t word, unsigned int shift)
+{
+ return (word << (shift & 31)) | (word >> ((-shift) & 31));
+}
+static inline uint32_t ror32(uint32_t word, unsigned int shift)
+{
+ return (word >> (shift & 31)) | (word << ((-shift) & 31));
+}
+
+static void xor_cpy(uint8_t *dst, const uint8_t *src1, const uint8_t *src2,
+ size_t len)
+{
+ size_t i;
+
+ for (i = 0; i < len; ++i)
+ dst[i] = src1[i] ^ src2[i];
+}
+
+#define QUARTER_ROUND(x, a, b, c, d) ( \
+ x[a] += x[b], \
+ x[d] = rol32((x[d] ^ x[a]), 16), \
+ x[c] += x[d], \
+ x[b] = rol32((x[b] ^ x[c]), 12), \
+ x[a] += x[b], \
+ x[d] = rol32((x[d] ^ x[a]), 8), \
+ x[c] += x[d], \
+ x[b] = rol32((x[b] ^ x[c]), 7) \
+)
+
+#define C(i, j) (i * 4 + j)
+
+#define DOUBLE_ROUND(x) ( \
+ /* Column Round */ \
+ QUARTER_ROUND(x, C(0, 0), C(1, 0), C(2, 0), C(3, 0)), \
+ QUARTER_ROUND(x, C(0, 1), C(1, 1), C(2, 1), C(3, 1)), \
+ QUARTER_ROUND(x, C(0, 2), C(1, 2), C(2, 2), C(3, 2)), \
+ QUARTER_ROUND(x, C(0, 3), C(1, 3), C(2, 3), C(3, 3)), \
+ /* Diagonal Round */ \
+ QUARTER_ROUND(x, C(0, 0), C(1, 1), C(2, 2), C(3, 3)), \
+ QUARTER_ROUND(x, C(0, 1), C(1, 2), C(2, 3), C(3, 0)), \
+ QUARTER_ROUND(x, C(0, 2), C(1, 3), C(2, 0), C(3, 1)), \
+ QUARTER_ROUND(x, C(0, 3), C(1, 0), C(2, 1), C(3, 2)) \
+)
+
+#define TWENTY_ROUNDS(x) ( \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x), \
+ DOUBLE_ROUND(x) \
+)
+
+enum chacha20_lengths {
+ CHACHA20_NONCE_SIZE = 16,
+ CHACHA20_KEY_SIZE = 32,
+ CHACHA20_KEY_WORDS = CHACHA20_KEY_SIZE / sizeof(uint32_t),
+ CHACHA20_BLOCK_SIZE = 64,
+ CHACHA20_BLOCK_WORDS = CHACHA20_BLOCK_SIZE / sizeof(uint32_t),
+ HCHACHA20_NONCE_SIZE = CHACHA20_NONCE_SIZE,
+ HCHACHA20_KEY_SIZE = CHACHA20_KEY_SIZE
+};
+
+enum chacha20_constants { /* expand 32-byte k */
+ CHACHA20_CONSTANT_EXPA = 0x61707865U,
+ CHACHA20_CONSTANT_ND_3 = 0x3320646eU,
+ CHACHA20_CONSTANT_2_BY = 0x79622d32U,
+ CHACHA20_CONSTANT_TE_K = 0x6b206574U
+};
+
+struct chacha20_ctx {
+ union {
+ uint32_t state[16];
+ struct {
+ uint32_t constant[4];
+ uint32_t key[8];
+ uint32_t counter[4];
+ };
+ };
+};
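+
+/* The 16-word state uses the standard ChaCha layout: the four "expand
+ * 32-byte k" constants, eight key words, a 32-bit block counter, and a
+ * 96-bit nonce. chacha20_init() composes that nonce as 32 zero bits
+ * followed by the 64-bit little-endian counter value, matching the
+ * WireGuard construction; chacha20_block() increments only counter[0].
+ */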
+
+static void chacha20_init(struct chacha20_ctx *ctx,
+ const uint8_t key[CHACHA20_KEY_SIZE],
+ const uint64_t nonce)
+{
+ ctx->constant[0] = CHACHA20_CONSTANT_EXPA;
+ ctx->constant[1] = CHACHA20_CONSTANT_ND_3;
+ ctx->constant[2] = CHACHA20_CONSTANT_2_BY;
+ ctx->constant[3] = CHACHA20_CONSTANT_TE_K;
+ ctx->key[0] = get_unaligned_le32(key + 0);
+ ctx->key[1] = get_unaligned_le32(key + 4);
+ ctx->key[2] = get_unaligned_le32(key + 8);
+ ctx->key[3] = get_unaligned_le32(key + 12);
+ ctx->key[4] = get_unaligned_le32(key + 16);
+ ctx->key[5] = get_unaligned_le32(key + 20);
+ ctx->key[6] = get_unaligned_le32(key + 24);
+ ctx->key[7] = get_unaligned_le32(key + 28);
+ ctx->counter[0] = 0;
+ ctx->counter[1] = 0;
+ ctx->counter[2] = nonce & 0xffffffffU;
+ ctx->counter[3] = nonce >> 32;
+}
+
+static void chacha20_block(struct chacha20_ctx *ctx, uint32_t *stream)
+{
+ uint32_t x[CHACHA20_BLOCK_WORDS];
+ int i;
+
+ for (i = 0; i < ARRAY_SIZE(x); ++i)
+ x[i] = ctx->state[i];
+
+ TWENTY_ROUNDS(x);
+
+ for (i = 0; i < ARRAY_SIZE(x); ++i)
+ stream[i] = cpu_to_le32(x[i] + ctx->state[i]);
+
+ ctx->counter[0] += 1;
+}
+
+static void chacha20(struct chacha20_ctx *ctx, uint8_t *out, const uint8_t *in,
+ uint32_t len)
+{
+ uint32_t buf[CHACHA20_BLOCK_WORDS];
+
+ while (len >= CHACHA20_BLOCK_SIZE) {
+ chacha20_block(ctx, buf);
+ xor_cpy(out, in, (uint8_t *)buf, CHACHA20_BLOCK_SIZE);
+ len -= CHACHA20_BLOCK_SIZE;
+ out += CHACHA20_BLOCK_SIZE;
+ in += CHACHA20_BLOCK_SIZE;
+ }
+ if (len) {
+ chacha20_block(ctx, buf);
+ xor_cpy(out, in, (uint8_t *)buf, len);
+ }
+}
+
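+/* HChaCha20 runs the twenty rounds over the constants, the key, and a
+ * 16-byte nonce, then returns words 0-3 and 12-15 of the final state
+ * without the feed-forward addition. The xchacha20poly1305_*() functions
+ * use it to derive a per-message subkey from the first 16 bytes of the
+ * 24-byte XChaCha20 nonce; the remaining 8 bytes become the 64-bit nonce.
+ */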
+static void hchacha20(uint32_t derived_key[CHACHA20_KEY_WORDS],
+ const uint8_t nonce[HCHACHA20_NONCE_SIZE],
+ const uint8_t key[HCHACHA20_KEY_SIZE])
+{
+ uint32_t x[] = { CHACHA20_CONSTANT_EXPA,
+ CHACHA20_CONSTANT_ND_3,
+ CHACHA20_CONSTANT_2_BY,
+ CHACHA20_CONSTANT_TE_K,
+ get_unaligned_le32(key + 0),
+ get_unaligned_le32(key + 4),
+ get_unaligned_le32(key + 8),
+ get_unaligned_le32(key + 12),
+ get_unaligned_le32(key + 16),
+ get_unaligned_le32(key + 20),
+ get_unaligned_le32(key + 24),
+ get_unaligned_le32(key + 28),
+ get_unaligned_le32(nonce + 0),
+ get_unaligned_le32(nonce + 4),
+ get_unaligned_le32(nonce + 8),
+ get_unaligned_le32(nonce + 12)
+ };
+
+ TWENTY_ROUNDS(x);
+
+ memcpy(derived_key + 0, x + 0, sizeof(uint32_t) * 4);
+ memcpy(derived_key + 4, x + 12, sizeof(uint32_t) * 4);
+}
+
+enum poly1305_lengths {
+ POLY1305_BLOCK_SIZE = 16,
+ POLY1305_KEY_SIZE = 32,
+ POLY1305_MAC_SIZE = 16
+};
+
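+/* Donna-style radix-2^26 representation: h is the running accumulator and
+ * r the clamped key, each split into five 26-bit limbs so the 64-bit
+ * products in the block loop cannot overflow; s[i] caches 5*r[i+1],
+ * folding the 2^130 == 5 (mod p) reduction into the multiplication.
+ */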
+struct poly1305_internal {
+ uint32_t h[5];
+ uint32_t r[5];
+ uint32_t s[4];
+};
+
+struct poly1305_ctx {
+ struct poly1305_internal state;
+ uint32_t nonce[4];
+ uint8_t data[POLY1305_BLOCK_SIZE];
+ size_t num;
+};
+
+static void poly1305_init_core(struct poly1305_internal *st,
+ const uint8_t key[16])
+{
+ /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
+ st->r[0] = (get_unaligned_le32(&key[0])) & 0x3ffffff;
+ st->r[1] = (get_unaligned_le32(&key[3]) >> 2) & 0x3ffff03;
+ st->r[2] = (get_unaligned_le32(&key[6]) >> 4) & 0x3ffc0ff;
+ st->r[3] = (get_unaligned_le32(&key[9]) >> 6) & 0x3f03fff;
+ st->r[4] = (get_unaligned_le32(&key[12]) >> 8) & 0x00fffff;
+
+ /* s = 5*r */
+ st->s[0] = st->r[1] * 5;
+ st->s[1] = st->r[2] * 5;
+ st->s[2] = st->r[3] * 5;
+ st->s[3] = st->r[4] * 5;
+
+ /* h = 0 */
+ st->h[0] = 0;
+ st->h[1] = 0;
+ st->h[2] = 0;
+ st->h[3] = 0;
+ st->h[4] = 0;
+}
+
+static void poly1305_blocks_core(struct poly1305_internal *st,
+ const uint8_t *input, size_t len,
+ const uint32_t padbit)
+{
+ const uint32_t hibit = padbit << 24;
+ uint32_t r0, r1, r2, r3, r4;
+ uint32_t s1, s2, s3, s4;
+ uint32_t h0, h1, h2, h3, h4;
+ uint64_t d0, d1, d2, d3, d4;
+ uint32_t c;
+
+ r0 = st->r[0];
+ r1 = st->r[1];
+ r2 = st->r[2];
+ r3 = st->r[3];
+ r4 = st->r[4];
+
+ s1 = st->s[0];
+ s2 = st->s[1];
+ s3 = st->s[2];
+ s4 = st->s[3];
+
+ h0 = st->h[0];
+ h1 = st->h[1];
+ h2 = st->h[2];
+ h3 = st->h[3];
+ h4 = st->h[4];
+
+ while (len >= POLY1305_BLOCK_SIZE) {
+ /* h += m[i] */
+ h0 += (get_unaligned_le32(&input[0])) & 0x3ffffff;
+ h1 += (get_unaligned_le32(&input[3]) >> 2) & 0x3ffffff;
+ h2 += (get_unaligned_le32(&input[6]) >> 4) & 0x3ffffff;
+ h3 += (get_unaligned_le32(&input[9]) >> 6) & 0x3ffffff;
+ h4 += (get_unaligned_le32(&input[12]) >> 8) | hibit;
+
+ /* h *= r */
+ d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) +
+ ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) +
+ ((uint64_t)h4 * s1);
+ d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) +
+ ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) +
+ ((uint64_t)h4 * s2);
+ d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) +
+ ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) +
+ ((uint64_t)h4 * s3);
+ d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) +
+ ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) +
+ ((uint64_t)h4 * s4);
+ d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) +
+ ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) +
+ ((uint64_t)h4 * r0);
+
+ /* (partial) h %= p */
+ c = (uint32_t)(d0 >> 26);
+ h0 = (uint32_t)d0 & 0x3ffffff;
+ d1 += c;
+ c = (uint32_t)(d1 >> 26);
+ h1 = (uint32_t)d1 & 0x3ffffff;
+ d2 += c;
+ c = (uint32_t)(d2 >> 26);
+ h2 = (uint32_t)d2 & 0x3ffffff;
+ d3 += c;
+ c = (uint32_t)(d3 >> 26);
+ h3 = (uint32_t)d3 & 0x3ffffff;
+ d4 += c;
+ c = (uint32_t)(d4 >> 26);
+ h4 = (uint32_t)d4 & 0x3ffffff;
+ h0 += c * 5;
+ c = (h0 >> 26);
+ h0 = h0 & 0x3ffffff;
+ h1 += c;
+
+ input += POLY1305_BLOCK_SIZE;
+ len -= POLY1305_BLOCK_SIZE;
+ }
+
+ st->h[0] = h0;
+ st->h[1] = h1;
+ st->h[2] = h2;
+ st->h[3] = h3;
+ st->h[4] = h4;
+}
+
+static void poly1305_emit_core(struct poly1305_internal *st, uint8_t mac[16],
+ const uint32_t nonce[4])
+{
+ uint32_t h0, h1, h2, h3, h4, c;
+ uint32_t g0, g1, g2, g3, g4;
+ uint64_t f;
+ uint32_t mask;
+
+ /* fully carry h */
+ h0 = st->h[0];
+ h1 = st->h[1];
+ h2 = st->h[2];
+ h3 = st->h[3];
+ h4 = st->h[4];
+
+ c = h1 >> 26;
+ h1 = h1 & 0x3ffffff;
+ h2 += c;
+ c = h2 >> 26;
+ h2 = h2 & 0x3ffffff;
+ h3 += c;
+ c = h3 >> 26;
+ h3 = h3 & 0x3ffffff;
+ h4 += c;
+ c = h4 >> 26;
+ h4 = h4 & 0x3ffffff;
+ h0 += c * 5;
+ c = h0 >> 26;
+ h0 = h0 & 0x3ffffff;
+ h1 += c;
+
+ /* compute h + -p */
+ g0 = h0 + 5;
+ c = g0 >> 26;
+ g0 &= 0x3ffffff;
+ g1 = h1 + c;
+ c = g1 >> 26;
+ g1 &= 0x3ffffff;
+ g2 = h2 + c;
+ c = g2 >> 26;
+ g2 &= 0x3ffffff;
+ g3 = h3 + c;
+ c = g3 >> 26;
+ g3 &= 0x3ffffff;
+ g4 = h4 + c - (1UL << 26);
+
+ /* select h if h < p, or h + -p if h >= p */
+ mask = (g4 >> ((sizeof(uint32_t) * 8) - 1)) - 1;
+ g0 &= mask;
+ g1 &= mask;
+ g2 &= mask;
+ g3 &= mask;
+ g4 &= mask;
+ mask = ~mask;
+
+ h0 = (h0 & mask) | g0;
+ h1 = (h1 & mask) | g1;
+ h2 = (h2 & mask) | g2;
+ h3 = (h3 & mask) | g3;
+ h4 = (h4 & mask) | g4;
+
+ /* h = h % (2^128) */
+ h0 = ((h0) | (h1 << 26)) & 0xffffffff;
+ h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff;
+ h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
+ h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff;
+
+ /* mac = (h + nonce) % (2^128) */
+ f = (uint64_t)h0 + nonce[0];
+ h0 = (uint32_t)f;
+ f = (uint64_t)h1 + nonce[1] + (f >> 32);
+ h1 = (uint32_t)f;
+ f = (uint64_t)h2 + nonce[2] + (f >> 32);
+ h2 = (uint32_t)f;
+ f = (uint64_t)h3 + nonce[3] + (f >> 32);
+ h3 = (uint32_t)f;
+
+ put_unaligned_le32(h0, &mac[0]);
+ put_unaligned_le32(h1, &mac[4]);
+ put_unaligned_le32(h2, &mac[8]);
+ put_unaligned_le32(h3, &mac[12]);
+}
+
+static void poly1305_init(struct poly1305_ctx *ctx,
+ const uint8_t key[POLY1305_KEY_SIZE])
+{
+ ctx->nonce[0] = get_unaligned_le32(&key[16]);
+ ctx->nonce[1] = get_unaligned_le32(&key[20]);
+ ctx->nonce[2] = get_unaligned_le32(&key[24]);
+ ctx->nonce[3] = get_unaligned_le32(&key[28]);
+
+ poly1305_init_core(&ctx->state, key);
+
+ ctx->num = 0;
+}
+
+static void poly1305_update(struct poly1305_ctx *ctx, const uint8_t *input,
+ size_t len)
+{
+ const size_t num = ctx->num;
+ size_t rem;
+
+ if (num) {
+ rem = POLY1305_BLOCK_SIZE - num;
+ if (len < rem) {
+ memcpy(ctx->data + num, input, len);
+ ctx->num = num + len;
+ return;
+ }
+ memcpy(ctx->data + num, input, rem);
+ poly1305_blocks_core(&ctx->state, ctx->data,
+ POLY1305_BLOCK_SIZE, 1);
+ input += rem;
+ len -= rem;
+ }
+
+ rem = len % POLY1305_BLOCK_SIZE;
+ len -= rem;
+
+ if (len >= POLY1305_BLOCK_SIZE) {
+ poly1305_blocks_core(&ctx->state, input, len, 1);
+ input += len;
+ }
+
+ if (rem)
+ memcpy(ctx->data, input, rem);
+
+ ctx->num = rem;
+}
+
+static void poly1305_final(struct poly1305_ctx *ctx,
+ uint8_t mac[POLY1305_MAC_SIZE])
+{
+ size_t num = ctx->num;
+
+ if (num) {
+ ctx->data[num++] = 1;
+ while (num < POLY1305_BLOCK_SIZE)
+ ctx->data[num++] = 0;
+ poly1305_blocks_core(&ctx->state, ctx->data,
+ POLY1305_BLOCK_SIZE, 0);
+ }
+
+ poly1305_emit_core(&ctx->state, mac, ctx->nonce);
+
+ explicit_bzero(ctx, sizeof(*ctx));
+}
+
+
+static const uint8_t pad0[16] = { 0 };
+
+void
+chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+ const uint8_t *ad, const size_t ad_len,
+ const uint64_t nonce,
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+ struct poly1305_ctx poly1305_state;
+ struct chacha20_ctx chacha20_state;
+ union {
+ uint8_t block0[POLY1305_KEY_SIZE];
+ uint64_t lens[2];
+ } b = { { 0 } };
+
+ chacha20_init(&chacha20_state, key, nonce);
+ chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0));
+ poly1305_init(&poly1305_state, b.block0);
+
+ poly1305_update(&poly1305_state, ad, ad_len);
+ poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf);
+
+ chacha20(&chacha20_state, dst, src, src_len);
+
+ poly1305_update(&poly1305_state, dst, src_len);
+ poly1305_update(&poly1305_state, pad0, (0x10 - src_len) & 0xf);
+
+ b.lens[0] = cpu_to_le64(ad_len);
+ b.lens[1] = cpu_to_le64(src_len);
+ poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens));
+
+ poly1305_final(&poly1305_state, dst + src_len);
+
+ explicit_bzero(&chacha20_state, sizeof(chacha20_state));
+ explicit_bzero(&b, sizeof(b));
+}
+
+bool
+chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+ const uint8_t *ad, const size_t ad_len,
+ const uint64_t nonce,
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+ struct poly1305_ctx poly1305_state;
+ struct chacha20_ctx chacha20_state;
+ bool ret;
+ size_t dst_len;
+ union {
+ uint8_t block0[POLY1305_KEY_SIZE];
+ uint8_t mac[POLY1305_MAC_SIZE];
+ uint64_t lens[2];
+ } b = { { 0 } };
+
+ if (src_len < POLY1305_MAC_SIZE)
+ return false;
+
+ chacha20_init(&chacha20_state, key, nonce);
+ chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0));
+ poly1305_init(&poly1305_state, b.block0);
+
+ poly1305_update(&poly1305_state, ad, ad_len);
+ poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf);
+
+ dst_len = src_len - POLY1305_MAC_SIZE;
+ poly1305_update(&poly1305_state, src, dst_len);
+ poly1305_update(&poly1305_state, pad0, (0x10 - dst_len) & 0xf);
+
+ b.lens[0] = cpu_to_le64(ad_len);
+ b.lens[1] = cpu_to_le64(dst_len);
+ poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens));
+
+ poly1305_final(&poly1305_state, b.mac);
+
+ ret = timingsafe_bcmp(b.mac, src + dst_len, POLY1305_MAC_SIZE) == 0;
+ if (ret)
+ chacha20(&chacha20_state, dst, src, dst_len);
+
+ explicit_bzero(&chacha20_state, sizeof(chacha20_state));
+ explicit_bzero(&b, sizeof(b));
+
+ return ret;
+}
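+
+/* An illustrative round trip through the two functions above (a sketch,
+ * not part of the module itself): the ciphertext buffer carries the
+ * 16-byte Poly1305 tag after the payload, a (key, nonce) pair must never
+ * be reused, and on authentication failure no plaintext is written.
+ *
+ *   uint8_t key[CHACHA20POLY1305_KEY_SIZE];
+ *   uint8_t pt[64];
+ *   uint8_t ct[sizeof(pt) + CHACHA20POLY1305_AUTHTAG_SIZE];
+ *   uint64_t counter = 1;
+ *
+ *   chacha20poly1305_encrypt(ct, pt, sizeof(pt), NULL, 0, counter, key);
+ *   if (!chacha20poly1305_decrypt(pt, ct, sizeof(ct), NULL, 0, counter, key))
+ *       ; // tag mismatch: reject the message; pt was not modified
+ */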
+
+void
+xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src,
+ const size_t src_len, const uint8_t *ad,
+ const size_t ad_len,
+ const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+ uint32_t derived_key[CHACHA20_KEY_WORDS];
+
+ hchacha20(derived_key, nonce, key);
+ cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key));
+ chacha20poly1305_encrypt(dst, src, src_len, ad, ad_len,
+ get_unaligned_le64(nonce + 16),
+ (uint8_t *)derived_key);
+ explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE);
+}
+
+bool
+xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src,
+ const size_t src_len, const uint8_t *ad,
+ const size_t ad_len,
+ const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE])
+{
+ bool ret;
+ uint32_t derived_key[CHACHA20_KEY_WORDS];
+
+ hchacha20(derived_key, nonce, key);
+ cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key));
+ ret = chacha20poly1305_decrypt(dst, src, src_len, ad, ad_len,
+ get_unaligned_le64(nonce + 16),
+ (uint8_t *)derived_key);
+ explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE);
+ return ret;
+}
+
+
+static const uint32_t blake2s_iv[8] = {
+ 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL,
+ 0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL
+};
+
+static const uint8_t blake2s_sigma[10][16] = {
+ { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 },
+ { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 },
+ { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 },
+ { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 },
+ { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 },
+ { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 },
+ { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 },
+ { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 },
+ { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 },
+ { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 },
+};
+
+static inline void blake2s_set_lastblock(struct blake2s_state *state)
+{
+ state->f[0] = -1;
+}
+
+static inline void blake2s_increment_counter(struct blake2s_state *state,
+ const uint32_t inc)
+{
+ state->t[0] += inc;
+ state->t[1] += (state->t[0] < inc);
+}
+
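+/* The parameter word XORed into h[0] packs the first four bytes of the
+ * BLAKE2s parameter block: digest length, key length, fanout = 1 and
+ * depth = 1, i.e. 0x01010000 | keylen << 8 | outlen in little-endian.
+ */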
+static inline void blake2s_init_param(struct blake2s_state *state,
+ const uint32_t param)
+{
+ int i;
+
+ memset(state, 0, sizeof(*state));
+ for (i = 0; i < 8; ++i)
+ state->h[i] = blake2s_iv[i];
+ state->h[0] ^= param;
+}
+
+void blake2s_init(struct blake2s_state *state, const size_t outlen)
+{
+ blake2s_init_param(state, 0x01010000 | outlen);
+ state->outlen = outlen;
+}
+
+void blake2s_init_key(struct blake2s_state *state, const size_t outlen,
+ const uint8_t *key, const size_t keylen)
+{
+ uint8_t block[BLAKE2S_BLOCK_SIZE] = { 0 };
+
+ blake2s_init_param(state, 0x01010000 | keylen << 8 | outlen);
+ state->outlen = outlen;
+ memcpy(block, key, keylen);
+ blake2s_update(state, block, BLAKE2S_BLOCK_SIZE);
+ explicit_bzero(block, BLAKE2S_BLOCK_SIZE);
+}
+
+static inline void blake2s_compress(struct blake2s_state *state,
+ const uint8_t *block, size_t nblocks,
+ const uint32_t inc)
+{
+ uint32_t m[16];
+ uint32_t v[16];
+ int i;
+
+ while (nblocks > 0) {
+ blake2s_increment_counter(state, inc);
+ memcpy(m, block, BLAKE2S_BLOCK_SIZE);
+ le32_to_cpu_array(m, ARRAY_SIZE(m));
+ memcpy(v, state->h, 32);
+ v[ 8] = blake2s_iv[0];
+ v[ 9] = blake2s_iv[1];
+ v[10] = blake2s_iv[2];
+ v[11] = blake2s_iv[3];
+ v[12] = blake2s_iv[4] ^ state->t[0];
+ v[13] = blake2s_iv[5] ^ state->t[1];
+ v[14] = blake2s_iv[6] ^ state->f[0];
+ v[15] = blake2s_iv[7] ^ state->f[1];
+
+#define G(r, i, a, b, c, d) do { \
+ a += b + m[blake2s_sigma[r][2 * i + 0]]; \
+ d = ror32(d ^ a, 16); \
+ c += d; \
+ b = ror32(b ^ c, 12); \
+ a += b + m[blake2s_sigma[r][2 * i + 1]]; \
+ d = ror32(d ^ a, 8); \
+ c += d; \
+ b = ror32(b ^ c, 7); \
+} while (0)
+
+#define ROUND(r) do { \
+ G(r, 0, v[0], v[ 4], v[ 8], v[12]); \
+ G(r, 1, v[1], v[ 5], v[ 9], v[13]); \
+ G(r, 2, v[2], v[ 6], v[10], v[14]); \
+ G(r, 3, v[3], v[ 7], v[11], v[15]); \
+ G(r, 4, v[0], v[ 5], v[10], v[15]); \
+ G(r, 5, v[1], v[ 6], v[11], v[12]); \
+ G(r, 6, v[2], v[ 7], v[ 8], v[13]); \
+ G(r, 7, v[3], v[ 4], v[ 9], v[14]); \
+} while (0)
+ ROUND(0);
+ ROUND(1);
+ ROUND(2);
+ ROUND(3);
+ ROUND(4);
+ ROUND(5);
+ ROUND(6);
+ ROUND(7);
+ ROUND(8);
+ ROUND(9);
+
+#undef G
+#undef ROUND
+
+ for (i = 0; i < 8; ++i)
+ state->h[i] ^= v[i] ^ v[i + 8];
+
+ block += BLAKE2S_BLOCK_SIZE;
+ --nblocks;
+ }
+}
+
+void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen)
+{
+ const size_t fill = BLAKE2S_BLOCK_SIZE - state->buflen;
+
+ if (!inlen)
+ return;
+ if (inlen > fill) {
+ memcpy(state->buf + state->buflen, in, fill);
+ blake2s_compress(state, state->buf, 1, BLAKE2S_BLOCK_SIZE);
+ state->buflen = 0;
+ in += fill;
+ inlen -= fill;
+ }
+ if (inlen > BLAKE2S_BLOCK_SIZE) {
+ const size_t nblocks = DIV_ROUND_UP(inlen, BLAKE2S_BLOCK_SIZE);
+ /* Hash one less (full) block than strictly possible */
+ blake2s_compress(state, in, nblocks - 1, BLAKE2S_BLOCK_SIZE);
+ in += BLAKE2S_BLOCK_SIZE * (nblocks - 1);
+ inlen -= BLAKE2S_BLOCK_SIZE * (nblocks - 1);
+ }
+ memcpy(state->buf + state->buflen, in, inlen);
+ state->buflen += inlen;
+}
+
+void blake2s_final(struct blake2s_state *state, uint8_t *out)
+{
+ blake2s_set_lastblock(state);
+ memset(state->buf + state->buflen, 0,
+ BLAKE2S_BLOCK_SIZE - state->buflen); /* Padding */
+ blake2s_compress(state, state->buf, 1, state->buflen);
+ cpu_to_le32_array(state->h, ARRAY_SIZE(state->h));
+ memcpy(out, state->h, state->outlen);
+ explicit_bzero(state, sizeof(*state));
+}
+
+void blake2s(uint8_t *out, const uint8_t *in, const uint8_t *key,
+ const size_t outlen, const size_t inlen, const size_t keylen)
+{
+ struct blake2s_state state;
+
+ if (keylen)
+ blake2s_init_key(&state, outlen, key, keylen);
+ else
+ blake2s_init(&state, outlen);
+
+ blake2s_update(&state, in, inlen);
+ blake2s_final(&state, out);
+}
+
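+/* HMAC (RFC 2104) instantiated with BLAKE2s. BLAKE2s has a native keyed
+ * mode, but Noise's HKDF is specified in terms of HMAC, so the ipad/opad
+ * construction is implemented here; XORing with 0x5c ^ 0x36 converts the
+ * already-ipad-masked key directly into the opad-masked key.
+ */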
+void blake2s_hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, const size_t outlen,
+ const size_t inlen, const size_t keylen)
+{
+ struct blake2s_state state;
+ uint8_t x_key[BLAKE2S_BLOCK_SIZE] __aligned(sizeof(uint32_t)) = { 0 };
+ uint8_t i_hash[BLAKE2S_HASH_SIZE] __aligned(sizeof(uint32_t));
+ int i;
+
+ if (keylen > BLAKE2S_BLOCK_SIZE) {
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, key, keylen);
+ blake2s_final(&state, x_key);
+ } else
+ memcpy(x_key, key, keylen);
+
+ for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
+ x_key[i] ^= 0x36;
+
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&state, in, inlen);
+ blake2s_final(&state, i_hash);
+
+ for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
+ x_key[i] ^= 0x5c ^ 0x36;
+
+ blake2s_init(&state, BLAKE2S_HASH_SIZE);
+ blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE);
+ blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE);
+ blake2s_final(&state, i_hash);
+
+ memcpy(out, i_hash, outlen);
+ explicit_bzero(x_key, BLAKE2S_BLOCK_SIZE);
+ explicit_bzero(i_hash, BLAKE2S_HASH_SIZE);
+}
+
+
+/* Below here is fiat's implementation of x25519.
+ *
+ * Copyright (C) 2015-2016 The fiat-crypto Authors.
+ * Copyright (C) 2018-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ *
+ * This is a machine-generated formally verified implementation of Curve25519
+ * ECDH from: <https://github.com/mit-plv/fiat-crypto>. Though originally
+ * machine generated, it has been tweaked to be suitable for use in the kernel.
+ * It is optimized for 32-bit machines and machines that cannot work efficiently
+ * with 128-bit integer types.
+ */
+
+/* fe means field element. Here the field is \Z/(2^255-19). An element t,
+ * entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77
+ * t[3]+2^102 t[4]+...+2^230 t[9].
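+ * (Limbs alternate 26 and 25 bits, so the bit offsets are 0, 26, 51, 77,
+ * 102, 128, 153, 179, 204, 230, and ten limbs cover the full 255 bits.)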
+ * fe limbs are bounded by 1.125*2^26,1.125*2^25,1.125*2^26,1.125*2^25,etc.
+ * Multiplication and carrying produce fe from fe_loose.
+ */
+typedef struct fe { uint32_t v[10]; } fe;
+
+/* fe_loose limbs are bounded by 3.375*2^26,3.375*2^25,3.375*2^26,3.375*2^25,etc
+ * Addition and subtraction produce fe_loose from (fe, fe).
+ */
+typedef struct fe_loose { uint32_t v[10]; } fe_loose;
+
+static inline void fe_frombytes_impl(uint32_t h[10], const uint8_t *s)
+{
+ /* Ignores top bit of s. */
+ uint32_t a0 = get_unaligned_le32(s);
+ uint32_t a1 = get_unaligned_le32(s+4);
+ uint32_t a2 = get_unaligned_le32(s+8);
+ uint32_t a3 = get_unaligned_le32(s+12);
+ uint32_t a4 = get_unaligned_le32(s+16);
+ uint32_t a5 = get_unaligned_le32(s+20);
+ uint32_t a6 = get_unaligned_le32(s+24);
+ uint32_t a7 = get_unaligned_le32(s+28);
+ h[0] = a0&((1<<26)-1); /* 26 used, 32-26 left. 26 */
+ h[1] = (a0>>26) | ((a1&((1<<19)-1))<< 6); /* (32-26) + 19 = 6+19 = 25 */
+ h[2] = (a1>>19) | ((a2&((1<<13)-1))<<13); /* (32-19) + 13 = 13+13 = 26 */
+ h[3] = (a2>>13) | ((a3&((1<< 6)-1))<<19); /* (32-13) + 6 = 19+ 6 = 25 */
+ h[4] = (a3>> 6); /* (32- 6) = 26 */
+ h[5] = a4&((1<<25)-1); /* 25 */
+ h[6] = (a4>>25) | ((a5&((1<<19)-1))<< 7); /* (32-25) + 19 = 7+19 = 26 */
+ h[7] = (a5>>19) | ((a6&((1<<12)-1))<<13); /* (32-19) + 12 = 13+12 = 25 */
+ h[8] = (a6>>12) | ((a7&((1<< 6)-1))<<20); /* (32-12) + 6 = 20+ 6 = 26 */
+ h[9] = (a7>> 6)&((1<<25)-1); /* 25 */
+}
+
+static inline void fe_frombytes(fe *h, const uint8_t *s)
+{
+ fe_frombytes_impl(h->v, s);
+}
+
+static inline uint8_t /*bool*/
+addcarryx_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low)
+{
+ /* This function extracts 25 bits of result and 1 bit of carry
+ * (26 total), so a 32-bit intermediate is sufficient.
+ */
+ uint32_t x = a + b + c;
+ *low = x & ((1 << 25) - 1);
+ return (x >> 25) & 1;
+}
+
+static inline uint8_t /*bool*/
+addcarryx_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low)
+{
+ /* This function extracts 26 bits of result and 1 bit of carry
+ * (27 total), so a 32-bit intermediate is sufficient.
+ */
+ uint32_t x = a + b + c;
+ *low = x & ((1 << 26) - 1);
+ return (x >> 26) & 1;
+}
+
+static inline uint8_t /*bool*/
+subborrow_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low)
+{
+ /* This function extracts 25 bits of result and 1 bit of borrow
+ * (26 total), so a 32-bit intermediate is sufficient.
+ */
+ uint32_t x = a - b - c;
+ *low = x & ((1 << 25) - 1);
+ return x >> 31;
+}
+
+static inline uint8_t /*bool*/
+subborrow_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low)
+{
+ /* This function extracts 26 bits of result and 1 bit of borrow
+	 * (27 total), so a 32-bit intermediate is sufficient.
+ */
+ uint32_t x = a - b - c;
+ *low = x & ((1 << 26) - 1);
+ return x >> 31;
+}
+
+static inline uint32_t cmovznz32(uint32_t t, uint32_t z, uint32_t nz)
+{
+ t = -!!t; /* all set if nonzero, 0 if 0 */
+ return (t&nz) | ((~t)&z);
+}
+
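+/* Fully reduce mod p = 2^255 - 19: subtract p with a borrow chain, then
+ * add p back under an all-ones mask that is set only when the subtraction
+ * underflowed (input already < p), yielding the canonical representative.
+ */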
+static inline void fe_freeze(uint32_t out[10], const uint32_t in1[10])
+{
+ const uint32_t x17 = in1[9];
+ const uint32_t x18 = in1[8];
+ const uint32_t x16 = in1[7];
+ const uint32_t x14 = in1[6];
+ const uint32_t x12 = in1[5];
+ const uint32_t x10 = in1[4];
+ const uint32_t x8 = in1[3];
+ const uint32_t x6 = in1[2];
+ const uint32_t x4 = in1[1];
+ const uint32_t x2 = in1[0];
+ uint32_t x20; uint8_t/*bool*/ x21 = subborrow_u26(0x0, x2, 0x3ffffed, &x20);
+ uint32_t x23; uint8_t/*bool*/ x24 = subborrow_u25(x21, x4, 0x1ffffff, &x23);
+ uint32_t x26; uint8_t/*bool*/ x27 = subborrow_u26(x24, x6, 0x3ffffff, &x26);
+ uint32_t x29; uint8_t/*bool*/ x30 = subborrow_u25(x27, x8, 0x1ffffff, &x29);
+ uint32_t x32; uint8_t/*bool*/ x33 = subborrow_u26(x30, x10, 0x3ffffff, &x32);
+ uint32_t x35; uint8_t/*bool*/ x36 = subborrow_u25(x33, x12, 0x1ffffff, &x35);
+ uint32_t x38; uint8_t/*bool*/ x39 = subborrow_u26(x36, x14, 0x3ffffff, &x38);
+ uint32_t x41; uint8_t/*bool*/ x42 = subborrow_u25(x39, x16, 0x1ffffff, &x41);
+ uint32_t x44; uint8_t/*bool*/ x45 = subborrow_u26(x42, x18, 0x3ffffff, &x44);
+ uint32_t x47; uint8_t/*bool*/ x48 = subborrow_u25(x45, x17, 0x1ffffff, &x47);
+ uint32_t x49 = cmovznz32(x48, 0x0, 0xffffffff);
+ uint32_t x50 = (x49 & 0x3ffffed);
+ uint32_t x52; uint8_t/*bool*/ x53 = addcarryx_u26(0x0, x20, x50, &x52);
+ uint32_t x54 = (x49 & 0x1ffffff);
+ uint32_t x56; uint8_t/*bool*/ x57 = addcarryx_u25(x53, x23, x54, &x56);
+ uint32_t x58 = (x49 & 0x3ffffff);
+ uint32_t x60; uint8_t/*bool*/ x61 = addcarryx_u26(x57, x26, x58, &x60);
+ uint32_t x62 = (x49 & 0x1ffffff);
+ uint32_t x64; uint8_t/*bool*/ x65 = addcarryx_u25(x61, x29, x62, &x64);
+ uint32_t x66 = (x49 & 0x3ffffff);
+ uint32_t x68; uint8_t/*bool*/ x69 = addcarryx_u26(x65, x32, x66, &x68);
+ uint32_t x70 = (x49 & 0x1ffffff);
+ uint32_t x72; uint8_t/*bool*/ x73 = addcarryx_u25(x69, x35, x70, &x72);
+ uint32_t x74 = (x49 & 0x3ffffff);
+ uint32_t x76; uint8_t/*bool*/ x77 = addcarryx_u26(x73, x38, x74, &x76);
+ uint32_t x78 = (x49 & 0x1ffffff);
+ uint32_t x80; uint8_t/*bool*/ x81 = addcarryx_u25(x77, x41, x78, &x80);
+ uint32_t x82 = (x49 & 0x3ffffff);
+ uint32_t x84; uint8_t/*bool*/ x85 = addcarryx_u26(x81, x44, x82, &x84);
+ uint32_t x86 = (x49 & 0x1ffffff);
+ uint32_t x88; addcarryx_u25(x85, x47, x86, &x88);
+ out[0] = x52;
+ out[1] = x56;
+ out[2] = x60;
+ out[3] = x64;
+ out[4] = x68;
+ out[5] = x72;
+ out[6] = x76;
+ out[7] = x80;
+ out[8] = x84;
+ out[9] = x88;
+}
+
+static inline void fe_tobytes(uint8_t s[32], const fe *f)
+{
+ uint32_t h[10];
+ fe_freeze(h, f->v);
+ s[0] = h[0] >> 0;
+ s[1] = h[0] >> 8;
+ s[2] = h[0] >> 16;
+ s[3] = (h[0] >> 24) | (h[1] << 2);
+ s[4] = h[1] >> 6;
+ s[5] = h[1] >> 14;
+ s[6] = (h[1] >> 22) | (h[2] << 3);
+ s[7] = h[2] >> 5;
+ s[8] = h[2] >> 13;
+ s[9] = (h[2] >> 21) | (h[3] << 5);
+ s[10] = h[3] >> 3;
+ s[11] = h[3] >> 11;
+ s[12] = (h[3] >> 19) | (h[4] << 6);
+ s[13] = h[4] >> 2;
+ s[14] = h[4] >> 10;
+ s[15] = h[4] >> 18;
+ s[16] = h[5] >> 0;
+ s[17] = h[5] >> 8;
+ s[18] = h[5] >> 16;
+ s[19] = (h[5] >> 24) | (h[6] << 1);
+ s[20] = h[6] >> 7;
+ s[21] = h[6] >> 15;
+ s[22] = (h[6] >> 23) | (h[7] << 3);
+ s[23] = h[7] >> 5;
+ s[24] = h[7] >> 13;
+ s[25] = (h[7] >> 21) | (h[8] << 4);
+ s[26] = h[8] >> 4;
+ s[27] = h[8] >> 12;
+ s[28] = (h[8] >> 20) | (h[9] << 6);
+ s[29] = h[9] >> 2;
+ s[30] = h[9] >> 10;
+ s[31] = h[9] >> 18;
+}
+
+/* h = f */
+static inline void fe_copy(fe *h, const fe *f)
+{
+ memmove(h, f, sizeof(uint32_t) * 10);
+}
+
+static inline void fe_copy_lt(fe_loose *h, const fe *f)
+{
+ memmove(h, f, sizeof(uint32_t) * 10);
+}
+
+/* h = 0 */
+static inline void fe_0(fe *h)
+{
+ memset(h, 0, sizeof(uint32_t) * 10);
+}
+
+/* h = 1 */
+static inline void fe_1(fe *h)
+{
+ memset(h, 0, sizeof(uint32_t) * 10);
+ h->v[0] = 1;
+}
+
+static void fe_add_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10])
+{
+ const uint32_t x20 = in1[9];
+ const uint32_t x21 = in1[8];
+ const uint32_t x19 = in1[7];
+ const uint32_t x17 = in1[6];
+ const uint32_t x15 = in1[5];
+ const uint32_t x13 = in1[4];
+ const uint32_t x11 = in1[3];
+ const uint32_t x9 = in1[2];
+ const uint32_t x7 = in1[1];
+ const uint32_t x5 = in1[0];
+ const uint32_t x38 = in2[9];
+ const uint32_t x39 = in2[8];
+ const uint32_t x37 = in2[7];
+ const uint32_t x35 = in2[6];
+ const uint32_t x33 = in2[5];
+ const uint32_t x31 = in2[4];
+ const uint32_t x29 = in2[3];
+ const uint32_t x27 = in2[2];
+ const uint32_t x25 = in2[1];
+ const uint32_t x23 = in2[0];
+ out[0] = (x5 + x23);
+ out[1] = (x7 + x25);
+ out[2] = (x9 + x27);
+ out[3] = (x11 + x29);
+ out[4] = (x13 + x31);
+ out[5] = (x15 + x33);
+ out[6] = (x17 + x35);
+ out[7] = (x19 + x37);
+ out[8] = (x21 + x39);
+ out[9] = (x20 + x38);
+}
+
+/* h = f + g
+ * Can overlap h with f or g.
+ */
+static inline void fe_add(fe_loose *h, const fe *f, const fe *g)
+{
+ fe_add_impl(h->v, f->v, g->v);
+}
+
+static void fe_sub_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10])
+{
+ const uint32_t x20 = in1[9];
+ const uint32_t x21 = in1[8];
+ const uint32_t x19 = in1[7];
+ const uint32_t x17 = in1[6];
+ const uint32_t x15 = in1[5];
+ const uint32_t x13 = in1[4];
+ const uint32_t x11 = in1[3];
+ const uint32_t x9 = in1[2];
+ const uint32_t x7 = in1[1];
+ const uint32_t x5 = in1[0];
+ const uint32_t x38 = in2[9];
+ const uint32_t x39 = in2[8];
+ const uint32_t x37 = in2[7];
+ const uint32_t x35 = in2[6];
+ const uint32_t x33 = in2[5];
+ const uint32_t x31 = in2[4];
+ const uint32_t x29 = in2[3];
+ const uint32_t x27 = in2[2];
+ const uint32_t x25 = in2[1];
+ const uint32_t x23 = in2[0];
+ out[0] = ((0x7ffffda + x5) - x23);
+ out[1] = ((0x3fffffe + x7) - x25);
+ out[2] = ((0x7fffffe + x9) - x27);
+ out[3] = ((0x3fffffe + x11) - x29);
+ out[4] = ((0x7fffffe + x13) - x31);
+ out[5] = ((0x3fffffe + x15) - x33);
+ out[6] = ((0x7fffffe + x17) - x35);
+ out[7] = ((0x3fffffe + x19) - x37);
+ out[8] = ((0x7fffffe + x21) - x39);
+ out[9] = ((0x3fffffe + x20) - x38);
+}
+
+/* h = f - g
+ * Can overlap h with f or g.
+ */
+static inline void fe_sub(fe_loose *h, const fe *f, const fe *g)
+{
+ fe_sub_impl(h->v, f->v, g->v);
+}
+
+static void fe_mul_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10])
+{
+ const uint32_t x20 = in1[9];
+ const uint32_t x21 = in1[8];
+ const uint32_t x19 = in1[7];
+ const uint32_t x17 = in1[6];
+ const uint32_t x15 = in1[5];
+ const uint32_t x13 = in1[4];
+ const uint32_t x11 = in1[3];
+ const uint32_t x9 = in1[2];
+ const uint32_t x7 = in1[1];
+ const uint32_t x5 = in1[0];
+ const uint32_t x38 = in2[9];
+ const uint32_t x39 = in2[8];
+ const uint32_t x37 = in2[7];
+ const uint32_t x35 = in2[6];
+ const uint32_t x33 = in2[5];
+ const uint32_t x31 = in2[4];
+ const uint32_t x29 = in2[3];
+ const uint32_t x27 = in2[2];
+ const uint32_t x25 = in2[1];
+ const uint32_t x23 = in2[0];
+ uint64_t x40 = ((uint64_t)x23 * x5);
+ uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5));
+ uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5));
+ uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5));
+ uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5));
+ uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5));
+ uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5));
+ uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5));
+ uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5));
+ uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5));
+ uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9));
+ uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9));
+ uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13));
+ uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13));
+ uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17));
+ uint64_t x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17));
+ uint64_t x56 = (((uint64_t)x39 * x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19))));
+ uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21));
+ uint64_t x58 = ((uint64_t)(0x2 * x38) * x20);
+ uint64_t x59 = (x48 + (x58 << 0x4));
+ uint64_t x60 = (x59 + (x58 << 0x1));
+ uint64_t x61 = (x60 + x58);
+ uint64_t x62 = (x47 + (x57 << 0x4));
+ uint64_t x63 = (x62 + (x57 << 0x1));
+ uint64_t x64 = (x63 + x57);
+ uint64_t x65 = (x46 + (x56 << 0x4));
+ uint64_t x66 = (x65 + (x56 << 0x1));
+ uint64_t x67 = (x66 + x56);
+ uint64_t x68 = (x45 + (x55 << 0x4));
+ uint64_t x69 = (x68 + (x55 << 0x1));
+ uint64_t x70 = (x69 + x55);
+ uint64_t x71 = (x44 + (x54 << 0x4));
+ uint64_t x72 = (x71 + (x54 << 0x1));
+ uint64_t x73 = (x72 + x54);
+ uint64_t x74 = (x43 + (x53 << 0x4));
+ uint64_t x75 = (x74 + (x53 << 0x1));
+ uint64_t x76 = (x75 + x53);
+ uint64_t x77 = (x42 + (x52 << 0x4));
+ uint64_t x78 = (x77 + (x52 << 0x1));
+ uint64_t x79 = (x78 + x52);
+ uint64_t x80 = (x41 + (x51 << 0x4));
+ uint64_t x81 = (x80 + (x51 << 0x1));
+ uint64_t x82 = (x81 + x51);
+ uint64_t x83 = (x40 + (x50 << 0x4));
+ uint64_t x84 = (x83 + (x50 << 0x1));
+ uint64_t x85 = (x84 + x50);
+ uint64_t x86 = (x85 >> 0x1a);
+ uint32_t x87 = ((uint32_t)x85 & 0x3ffffff);
+ uint64_t x88 = (x86 + x82);
+ uint64_t x89 = (x88 >> 0x19);
+ uint32_t x90 = ((uint32_t)x88 & 0x1ffffff);
+ uint64_t x91 = (x89 + x79);
+ uint64_t x92 = (x91 >> 0x1a);
+ uint32_t x93 = ((uint32_t)x91 & 0x3ffffff);
+ uint64_t x94 = (x92 + x76);
+ uint64_t x95 = (x94 >> 0x19);
+ uint32_t x96 = ((uint32_t)x94 & 0x1ffffff);
+ uint64_t x97 = (x95 + x73);
+ uint64_t x98 = (x97 >> 0x1a);
+ uint32_t x99 = ((uint32_t)x97 & 0x3ffffff);
+ uint64_t x100 = (x98 + x70);
+ uint64_t x101 = (x100 >> 0x19);
+ uint32_t x102 = ((uint32_t)x100 & 0x1ffffff);
+ uint64_t x103 = (x101 + x67);
+ uint64_t x104 = (x103 >> 0x1a);
+ uint32_t x105 = ((uint32_t)x103 & 0x3ffffff);
+ uint64_t x106 = (x104 + x64);
+ uint64_t x107 = (x106 >> 0x19);
+ uint32_t x108 = ((uint32_t)x106 & 0x1ffffff);
+ uint64_t x109 = (x107 + x61);
+ uint64_t x110 = (x109 >> 0x1a);
+ uint32_t x111 = ((uint32_t)x109 & 0x3ffffff);
+ uint64_t x112 = (x110 + x49);
+ uint64_t x113 = (x112 >> 0x19);
+ uint32_t x114 = ((uint32_t)x112 & 0x1ffffff);
+ uint64_t x115 = (x87 + (0x13 * x113));
+ uint32_t x116 = (uint32_t) (x115 >> 0x1a);
+ uint32_t x117 = ((uint32_t)x115 & 0x3ffffff);
+ uint32_t x118 = (x116 + x90);
+ uint32_t x119 = (x118 >> 0x19);
+ uint32_t x120 = (x118 & 0x1ffffff);
+ out[0] = x117;
+ out[1] = x120;
+ out[2] = (x119 + x93);
+ out[3] = x96;
+ out[4] = x99;
+ out[5] = x102;
+ out[6] = x105;
+ out[7] = x108;
+ out[8] = x111;
+ out[9] = x114;
+}
+
+static inline void fe_mul_ttt(fe *h, const fe *f, const fe *g)
+{
+ fe_mul_impl(h->v, f->v, g->v);
+}
+
+static inline void fe_mul_tlt(fe *h, const fe_loose *f, const fe *g)
+{
+ fe_mul_impl(h->v, f->v, g->v);
+}
+
+static inline void
+fe_mul_tll(fe *h, const fe_loose *f, const fe_loose *g)
+{
+ fe_mul_impl(h->v, f->v, g->v);
+}
+
+static void fe_sqr_impl(uint32_t out[10], const uint32_t in1[10])
+{
+ const uint32_t x17 = in1[9];
+ const uint32_t x18 = in1[8];
+ const uint32_t x16 = in1[7];
+ const uint32_t x14 = in1[6];
+ const uint32_t x12 = in1[5];
+ const uint32_t x10 = in1[4];
+ const uint32_t x8 = in1[3];
+ const uint32_t x6 = in1[2];
+ const uint32_t x4 = in1[1];
+ const uint32_t x2 = in1[0];
+ uint64_t x19 = ((uint64_t)x2 * x2);
+ uint64_t x20 = ((uint64_t)(0x2 * x2) * x4);
+ uint64_t x21 = (0x2 * (((uint64_t)x4 * x4) + ((uint64_t)x2 * x6)));
+ uint64_t x22 = (0x2 * (((uint64_t)x4 * x6) + ((uint64_t)x2 * x8)));
+ uint64_t x23 = ((((uint64_t)x6 * x6) + ((uint64_t)(0x4 * x4) * x8)) + ((uint64_t)(0x2 * x2) * x10));
+ uint64_t x24 = (0x2 * ((((uint64_t)x6 * x8) + ((uint64_t)x4 * x10)) + ((uint64_t)x2 * x12)));
+ uint64_t x25 = (0x2 * (((((uint64_t)x8 * x8) + ((uint64_t)x6 * x10)) + ((uint64_t)x2 * x14)) + ((uint64_t)(0x2 * x4) * x12)));
+ uint64_t x26 = (0x2 * (((((uint64_t)x8 * x10) + ((uint64_t)x6 * x12)) + ((uint64_t)x4 * x14)) + ((uint64_t)x2 * x16)));
+ uint64_t x27 = (((uint64_t)x10 * x10) + (0x2 * ((((uint64_t)x6 * x14) + ((uint64_t)x2 * x18)) + (0x2 * (((uint64_t)x4 * x16) + ((uint64_t)x8 * x12))))));
+ uint64_t x28 = (0x2 * ((((((uint64_t)x10 * x12) + ((uint64_t)x8 * x14)) + ((uint64_t)x6 * x16)) + ((uint64_t)x4 * x18)) + ((uint64_t)x2 * x17)));
+ uint64_t x29 = (0x2 * (((((uint64_t)x12 * x12) + ((uint64_t)x10 * x14)) + ((uint64_t)x6 * x18)) + (0x2 * (((uint64_t)x8 * x16) + ((uint64_t)x4 * x17)))));
+ uint64_t x30 = (0x2 * (((((uint64_t)x12 * x14) + ((uint64_t)x10 * x16)) + ((uint64_t)x8 * x18)) + ((uint64_t)x6 * x17)));
+ uint64_t x31 = (((uint64_t)x14 * x14) + (0x2 * (((uint64_t)x10 * x18) + (0x2 * (((uint64_t)x12 * x16) + ((uint64_t)x8 * x17))))));
+ uint64_t x32 = (0x2 * ((((uint64_t)x14 * x16) + ((uint64_t)x12 * x18)) + ((uint64_t)x10 * x17)));
+ uint64_t x33 = (0x2 * ((((uint64_t)x16 * x16) + ((uint64_t)x14 * x18)) + ((uint64_t)(0x2 * x12) * x17)));
+ uint64_t x34 = (0x2 * (((uint64_t)x16 * x18) + ((uint64_t)x14 * x17)));
+ uint64_t x35 = (((uint64_t)x18 * x18) + ((uint64_t)(0x4 * x16) * x17));
+ uint64_t x36 = ((uint64_t)(0x2 * x18) * x17);
+ uint64_t x37 = ((uint64_t)(0x2 * x17) * x17);
+ uint64_t x38 = (x27 + (x37 << 0x4));
+ uint64_t x39 = (x38 + (x37 << 0x1));
+ uint64_t x40 = (x39 + x37);
+ uint64_t x41 = (x26 + (x36 << 0x4));
+ uint64_t x42 = (x41 + (x36 << 0x1));
+ uint64_t x43 = (x42 + x36);
+ uint64_t x44 = (x25 + (x35 << 0x4));
+ uint64_t x45 = (x44 + (x35 << 0x1));
+ uint64_t x46 = (x45 + x35);
+ uint64_t x47 = (x24 + (x34 << 0x4));
+ uint64_t x48 = (x47 + (x34 << 0x1));
+ uint64_t x49 = (x48 + x34);
+ uint64_t x50 = (x23 + (x33 << 0x4));
+ uint64_t x51 = (x50 + (x33 << 0x1));
+ uint64_t x52 = (x51 + x33);
+ uint64_t x53 = (x22 + (x32 << 0x4));
+ uint64_t x54 = (x53 + (x32 << 0x1));
+ uint64_t x55 = (x54 + x32);
+ uint64_t x56 = (x21 + (x31 << 0x4));
+ uint64_t x57 = (x56 + (x31 << 0x1));
+ uint64_t x58 = (x57 + x31);
+ uint64_t x59 = (x20 + (x30 << 0x4));
+ uint64_t x60 = (x59 + (x30 << 0x1));
+ uint64_t x61 = (x60 + x30);
+ uint64_t x62 = (x19 + (x29 << 0x4));
+ uint64_t x63 = (x62 + (x29 << 0x1));
+ uint64_t x64 = (x63 + x29);
+ uint64_t x65 = (x64 >> 0x1a);
+ uint32_t x66 = ((uint32_t)x64 & 0x3ffffff);
+ uint64_t x67 = (x65 + x61);
+ uint64_t x68 = (x67 >> 0x19);
+ uint32_t x69 = ((uint32_t)x67 & 0x1ffffff);
+ uint64_t x70 = (x68 + x58);
+ uint64_t x71 = (x70 >> 0x1a);
+ uint32_t x72 = ((uint32_t)x70 & 0x3ffffff);
+ uint64_t x73 = (x71 + x55);
+ uint64_t x74 = (x73 >> 0x19);
+ uint32_t x75 = ((uint32_t)x73 & 0x1ffffff);
+ uint64_t x76 = (x74 + x52);
+ uint64_t x77 = (x76 >> 0x1a);
+ uint32_t x78 = ((uint32_t)x76 & 0x3ffffff);
+ uint64_t x79 = (x77 + x49);
+ uint64_t x80 = (x79 >> 0x19);
+ uint32_t x81 = ((uint32_t)x79 & 0x1ffffff);
+ uint64_t x82 = (x80 + x46);
+ uint64_t x83 = (x82 >> 0x1a);
+ uint32_t x84 = ((uint32_t)x82 & 0x3ffffff);
+ uint64_t x85 = (x83 + x43);
+ uint64_t x86 = (x85 >> 0x19);
+ uint32_t x87 = ((uint32_t)x85 & 0x1ffffff);
+ uint64_t x88 = (x86 + x40);
+ uint64_t x89 = (x88 >> 0x1a);
+ uint32_t x90 = ((uint32_t)x88 & 0x3ffffff);
+ uint64_t x91 = (x89 + x28);
+ uint64_t x92 = (x91 >> 0x19);
+ uint32_t x93 = ((uint32_t)x91 & 0x1ffffff);
+ uint64_t x94 = (x66 + (0x13 * x92));
+ uint32_t x95 = (uint32_t) (x94 >> 0x1a);
+ uint32_t x96 = ((uint32_t)x94 & 0x3ffffff);
+ uint32_t x97 = (x95 + x69);
+ uint32_t x98 = (x97 >> 0x19);
+ uint32_t x99 = (x97 & 0x1ffffff);
+ out[0] = x96;
+ out[1] = x99;
+ out[2] = (x98 + x72);
+ out[3] = x75;
+ out[4] = x78;
+ out[5] = x81;
+ out[6] = x84;
+ out[7] = x87;
+ out[8] = x90;
+ out[9] = x93;
+}
+
+static inline void fe_sq_tl(fe *h, const fe_loose *f)
+{
+ fe_sqr_impl(h->v, f->v);
+}
+
+static inline void fe_sq_tt(fe *h, const fe *f)
+{
+ fe_sqr_impl(h->v, f->v);
+}
+
+static inline void fe_loose_invert(fe *out, const fe_loose *z)
+{
+ fe t0;
+ fe t1;
+ fe t2;
+ fe t3;
+ int i;
+
+ fe_sq_tl(&t0, z);
+ fe_sq_tt(&t1, &t0);
+ for (i = 1; i < 2; ++i)
+ fe_sq_tt(&t1, &t1);
+ fe_mul_tlt(&t1, z, &t1);
+ fe_mul_ttt(&t0, &t0, &t1);
+ fe_sq_tt(&t2, &t0);
+ fe_mul_ttt(&t1, &t1, &t2);
+ fe_sq_tt(&t2, &t1);
+ for (i = 1; i < 5; ++i)
+ fe_sq_tt(&t2, &t2);
+ fe_mul_ttt(&t1, &t2, &t1);
+ fe_sq_tt(&t2, &t1);
+ for (i = 1; i < 10; ++i)
+ fe_sq_tt(&t2, &t2);
+ fe_mul_ttt(&t2, &t2, &t1);
+ fe_sq_tt(&t3, &t2);
+ for (i = 1; i < 20; ++i)
+ fe_sq_tt(&t3, &t3);
+ fe_mul_ttt(&t2, &t3, &t2);
+ fe_sq_tt(&t2, &t2);
+ for (i = 1; i < 10; ++i)
+ fe_sq_tt(&t2, &t2);
+ fe_mul_ttt(&t1, &t2, &t1);
+ fe_sq_tt(&t2, &t1);
+ for (i = 1; i < 50; ++i)
+ fe_sq_tt(&t2, &t2);
+ fe_mul_ttt(&t2, &t2, &t1);
+ fe_sq_tt(&t3, &t2);
+ for (i = 1; i < 100; ++i)
+ fe_sq_tt(&t3, &t3);
+ fe_mul_ttt(&t2, &t3, &t2);
+ fe_sq_tt(&t2, &t2);
+ for (i = 1; i < 50; ++i)
+ fe_sq_tt(&t2, &t2);
+ fe_mul_ttt(&t1, &t2, &t1);
+ fe_sq_tt(&t1, &t1);
+ for (i = 1; i < 5; ++i)
+ fe_sq_tt(&t1, &t1);
+ fe_mul_ttt(out, &t1, &t0);
+}
+
+static inline void fe_invert(fe *out, const fe *z)
+{
+ fe_loose l;
+ fe_copy_lt(&l, z);
+ fe_loose_invert(out, &l);
+}
+
+/* Replace (f,g) with (g,f) if b == 1;
+ * replace (f,g) with (f,g) if b == 0.
+ *
+ * Preconditions: b in {0,1}
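+ *
+ * The swap is branch-free: b is stretched into an all-ones or all-zero
+ * mask, so timing and memory access patterns do not depend on the bit.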
+ */
+static inline void fe_cswap(fe *f, fe *g, unsigned int b)
+{
+ unsigned i;
+ b = 0 - b;
+ for (i = 0; i < 10; i++) {
+ uint32_t x = f->v[i] ^ g->v[i];
+ x &= b;
+ f->v[i] ^= x;
+ g->v[i] ^= x;
+ }
+}
+
+/* NOTE: based on fiat-crypto fe_mul, edited for in2=121666, 0, 0.*/
+static inline void fe_mul_121666_impl(uint32_t out[10], const uint32_t in1[10])
+{
+ const uint32_t x20 = in1[9];
+ const uint32_t x21 = in1[8];
+ const uint32_t x19 = in1[7];
+ const uint32_t x17 = in1[6];
+ const uint32_t x15 = in1[5];
+ const uint32_t x13 = in1[4];
+ const uint32_t x11 = in1[3];
+ const uint32_t x9 = in1[2];
+ const uint32_t x7 = in1[1];
+ const uint32_t x5 = in1[0];
+ const uint32_t x38 = 0;
+ const uint32_t x39 = 0;
+ const uint32_t x37 = 0;
+ const uint32_t x35 = 0;
+ const uint32_t x33 = 0;
+ const uint32_t x31 = 0;
+ const uint32_t x29 = 0;
+ const uint32_t x27 = 0;
+ const uint32_t x25 = 0;
+ const uint32_t x23 = 121666;
+ uint64_t x40 = ((uint64_t)x23 * x5);
+ uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5));
+ uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5));
+ uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5));
+ uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5));
+ uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5));
+ uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5));
+ uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5));
+ uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5));
+ uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5));
+ uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9));
+ uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9));
+ uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13));
+ uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13));
+ uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17));
+ uint64_t x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17));
+ uint64_t x56 = (((uint64_t)x39 * x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19))));
+ uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21));
+ uint64_t x58 = ((uint64_t)(0x2 * x38) * x20);
+ uint64_t x59 = (x48 + (x58 << 0x4));
+ uint64_t x60 = (x59 + (x58 << 0x1));
+ uint64_t x61 = (x60 + x58);
+ uint64_t x62 = (x47 + (x57 << 0x4));
+ uint64_t x63 = (x62 + (x57 << 0x1));
+ uint64_t x64 = (x63 + x57);
+ uint64_t x65 = (x46 + (x56 << 0x4));
+ uint64_t x66 = (x65 + (x56 << 0x1));
+ uint64_t x67 = (x66 + x56);
+ uint64_t x68 = (x45 + (x55 << 0x4));
+ uint64_t x69 = (x68 + (x55 << 0x1));
+ uint64_t x70 = (x69 + x55);
+ uint64_t x71 = (x44 + (x54 << 0x4));
+ uint64_t x72 = (x71 + (x54 << 0x1));
+ uint64_t x73 = (x72 + x54);
+ uint64_t x74 = (x43 + (x53 << 0x4));
+ uint64_t x75 = (x74 + (x53 << 0x1));
+ uint64_t x76 = (x75 + x53);
+ uint64_t x77 = (x42 + (x52 << 0x4));
+ uint64_t x78 = (x77 + (x52 << 0x1));
+ uint64_t x79 = (x78 + x52);
+ uint64_t x80 = (x41 + (x51 << 0x4));
+ uint64_t x81 = (x80 + (x51 << 0x1));
+ uint64_t x82 = (x81 + x51);
+ uint64_t x83 = (x40 + (x50 << 0x4));
+ uint64_t x84 = (x83 + (x50 << 0x1));
+ uint64_t x85 = (x84 + x50);
+ uint64_t x86 = (x85 >> 0x1a);
+ uint32_t x87 = ((uint32_t)x85 & 0x3ffffff);
+ uint64_t x88 = (x86 + x82);
+ uint64_t x89 = (x88 >> 0x19);
+ uint32_t x90 = ((uint32_t)x88 & 0x1ffffff);
+ uint64_t x91 = (x89 + x79);
+ uint64_t x92 = (x91 >> 0x1a);
+ uint32_t x93 = ((uint32_t)x91 & 0x3ffffff);
+ uint64_t x94 = (x92 + x76);
+ uint64_t x95 = (x94 >> 0x19);
+ uint32_t x96 = ((uint32_t)x94 & 0x1ffffff);
+ uint64_t x97 = (x95 + x73);
+ uint64_t x98 = (x97 >> 0x1a);
+ uint32_t x99 = ((uint32_t)x97 & 0x3ffffff);
+ uint64_t x100 = (x98 + x70);
+ uint64_t x101 = (x100 >> 0x19);
+ uint32_t x102 = ((uint32_t)x100 & 0x1ffffff);
+ uint64_t x103 = (x101 + x67);
+ uint64_t x104 = (x103 >> 0x1a);
+ uint32_t x105 = ((uint32_t)x103 & 0x3ffffff);
+ uint64_t x106 = (x104 + x64);
+ uint64_t x107 = (x106 >> 0x19);
+ uint32_t x108 = ((uint32_t)x106 & 0x1ffffff);
+ uint64_t x109 = (x107 + x61);
+ uint64_t x110 = (x109 >> 0x1a);
+ uint32_t x111 = ((uint32_t)x109 & 0x3ffffff);
+ uint64_t x112 = (x110 + x49);
+ uint64_t x113 = (x112 >> 0x19);
+ uint32_t x114 = ((uint32_t)x112 & 0x1ffffff);
+ uint64_t x115 = (x87 + (0x13 * x113));
+ uint32_t x116 = (uint32_t) (x115 >> 0x1a);
+ uint32_t x117 = ((uint32_t)x115 & 0x3ffffff);
+ uint32_t x118 = (x116 + x90);
+ uint32_t x119 = (x118 >> 0x19);
+ uint32_t x120 = (x118 & 0x1ffffff);
+ out[0] = x117;
+ out[1] = x120;
+ out[2] = (x119 + x93);
+ out[3] = x96;
+ out[4] = x99;
+ out[5] = x102;
+ out[6] = x105;
+ out[7] = x108;
+ out[8] = x111;
+ out[9] = x114;
+}
+
+static inline void fe_mul121666(fe *h, const fe_loose *f)
+{
+ fe_mul_121666_impl(h->v, f->v);
+}
+
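+/* An all-zero X25519 output indicates a zero or small-order input point;
+ * curve25519() compares its result against this and returns false so that
+ * callers reject the resulting shared secret.
+ */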
+static const uint8_t curve25519_null_point[CURVE25519_KEY_SIZE];
+
+bool curve25519(uint8_t out[CURVE25519_KEY_SIZE],
+ const uint8_t scalar[CURVE25519_KEY_SIZE],
+ const uint8_t point[CURVE25519_KEY_SIZE])
+{
+ fe x1, x2, z2, x3, z3;
+ fe_loose x2l, z2l, x3l;
+ unsigned swap = 0;
+ int pos;
+ uint8_t e[32];
+
+ memcpy(e, scalar, 32);
+ curve25519_clamp_secret(e);
+
+ /* The following implementation was transcribed to Coq and proven to
+ * correspond to unary scalar multiplication in affine coordinates given
+ * that x1 != 0 is the x coordinate of some point on the curve. It was
+ * also checked in Coq that doing a ladderstep with x1 = x3 = 0 gives
+ * z2' = z3' = 0, and z2 = z3 = 0 gives z2' = z3' = 0. The statement was
+ * quantified over the underlying field, so it applies to Curve25519
+ * itself and the quadratic twist of Curve25519. It was not proven in
+ * Coq that prime-field arithmetic correctly simulates extension-field
+ * arithmetic on prime-field values. The decoding of the byte array
+ * representation of e was not considered.
+ *
+ * Specification of Montgomery curves in affine coordinates:
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Spec/MontgomeryCurve.v#L27>
+ *
+ * Proof that these form a group that is isomorphic to a Weierstrass
+ * curve:
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/AffineProofs.v#L35>
+ *
+ * Coq transcription and correctness proof of the loop
+ * (where scalarbits=255):
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZ.v#L118>
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L278>
+ * preconditions: 0 <= e < 2^255 (not necessarily e < order),
+ * fe_invert(0) = 0
+ */
+ fe_frombytes(&x1, point);
+ fe_1(&x2);
+ fe_0(&z2);
+ fe_copy(&x3, &x1);
+ fe_1(&z3);
+
+ for (pos = 254; pos >= 0; --pos) {
+ fe tmp0, tmp1;
+ fe_loose tmp0l, tmp1l;
+ /* loop invariant as of right before the test, for the case
+ * where x1 != 0:
+ * pos >= -1; if z2 = 0 then x2 is nonzero; if z3 = 0 then x3
+ * is nonzero
+ * let r := e >> (pos+1) in the following equalities of
+ * projective points:
+ * to_xz (r*P) === if swap then (x3, z3) else (x2, z2)
+ * to_xz ((r+1)*P) === if swap then (x2, z2) else (x3, z3)
+ * x1 is the nonzero x coordinate of the nonzero
+ * point (r*P-(r+1)*P)
+ */
+ unsigned b = 1 & (e[pos / 8] >> (pos & 7));
+ swap ^= b;
+ fe_cswap(&x2, &x3, swap);
+ fe_cswap(&z2, &z3, swap);
+ swap = b;
+ /* Coq transcription of ladderstep formula (called from
+ * transcribed loop):
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZ.v#L89>
+ * <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L131>
+ * x1 != 0 <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L217>
+ * x1 = 0 <https://github.com/mit-plv/fiat-crypto/blob/2456d821825521f7e03e65882cc3521795b0320f/src/Curves/Montgomery/XZProofs.v#L147>
+ */
+ fe_sub(&tmp0l, &x3, &z3);
+ fe_sub(&tmp1l, &x2, &z2);
+ fe_add(&x2l, &x2, &z2);
+ fe_add(&z2l, &x3, &z3);
+ fe_mul_tll(&z3, &tmp0l, &x2l);
+ fe_mul_tll(&z2, &z2l, &tmp1l);
+ fe_sq_tl(&tmp0, &tmp1l);
+ fe_sq_tl(&tmp1, &x2l);
+ fe_add(&x3l, &z3, &z2);
+ fe_sub(&z2l, &z3, &z2);
+ fe_mul_ttt(&x2, &tmp1, &tmp0);
+ fe_sub(&tmp1l, &tmp1, &tmp0);
+ fe_sq_tl(&z2, &z2l);
+ fe_mul121666(&z3, &tmp1l);
+ fe_sq_tl(&x3, &x3l);
+ fe_add(&tmp0l, &tmp0, &z3);
+ fe_mul_ttt(&z3, &x1, &z2);
+ fe_mul_tll(&z2, &tmp1l, &tmp0l);
+ }
+ /* here pos=-1, so r=e, so to_xz (e*P) === if swap then (x3, z3)
+ * else (x2, z2)
+ */
+ fe_cswap(&x2, &x3, swap);
+ fe_cswap(&z2, &z3, swap);
+
+ fe_invert(&z2, &z2);
+ fe_mul_ttt(&x2, &x2, &z2);
+ fe_tobytes(out, &x2);
+
+ explicit_bzero(&x1, sizeof(x1));
+ explicit_bzero(&x2, sizeof(x2));
+ explicit_bzero(&z2, sizeof(z2));
+ explicit_bzero(&x3, sizeof(x3));
+ explicit_bzero(&z3, sizeof(z3));
+ explicit_bzero(&x2l, sizeof(x2l));
+ explicit_bzero(&z2l, sizeof(z2l));
+ explicit_bzero(&x3l, sizeof(x3l));
+ explicit_bzero(&e, sizeof(e));
+
+ return timingsafe_bcmp(out, curve25519_null_point, CURVE25519_KEY_SIZE) != 0;
+}
diff --git a/src/crypto.h b/src/crypto.h
new file mode 100644
index 0000000..0ac23f9
--- /dev/null
+++ b/src/crypto.h
@@ -0,0 +1,103 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#ifndef _WG_CRYPTO
+#define _WG_CRYPTO
+
+#include <sys/types.h>
+
+enum chacha20poly1305_lengths {
+ XCHACHA20POLY1305_NONCE_SIZE = 24,
+ CHACHA20POLY1305_KEY_SIZE = 32,
+ CHACHA20POLY1305_AUTHTAG_SIZE = 16
+};
+
+void
+chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+ const uint8_t *ad, const size_t ad_len,
+ const uint64_t nonce,
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+bool
+chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len,
+ const uint8_t *ad, const size_t ad_len,
+ const uint64_t nonce,
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+void
+xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src,
+ const size_t src_len, const uint8_t *ad,
+ const size_t ad_len,
+ const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
+
+bool
+xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src,
+ const size_t src_len, const uint8_t *ad,
+ const size_t ad_len,
+ const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE],
+ const uint8_t key[CHACHA20POLY1305_KEY_SIZE]);
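+
+/*
+ * Illustrative use of the AEAD interface above. This sketch is editorial,
+ * not part of the API, and every buffer and value in it is made up. On
+ * encrypt, dst must have room for src_len + CHACHA20POLY1305_AUTHTAG_SIZE
+ * bytes; on decrypt, src_len includes the tag, and a false return means
+ * authentication failed and dst must be discarded:
+ *
+ *     uint8_t key[CHACHA20POLY1305_KEY_SIZE] = { 0 };
+ *     uint8_t pt[64] = { 0 }, ct[64 + CHACHA20POLY1305_AUTHTAG_SIZE];
+ *     chacha20poly1305_encrypt(ct, pt, sizeof(pt), NULL, 0, 1, key);
+ *     if (!chacha20poly1305_decrypt(pt, ct, sizeof(ct), NULL, 0, 1, key))
+ *             the packet was forged or corrupted and must be dropped.
+ */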
+
+enum blake2s_lengths {
+ BLAKE2S_BLOCK_SIZE = 64,
+ BLAKE2S_HASH_SIZE = 32,
+ BLAKE2S_KEY_SIZE = 32
+};
+
+struct blake2s_state {
+ uint32_t h[8];
+ uint32_t t[2];
+ uint32_t f[2];
+ uint8_t buf[BLAKE2S_BLOCK_SIZE];
+ unsigned int buflen;
+ unsigned int outlen;
+};
+
+void blake2s_init(struct blake2s_state *state, const size_t outlen);
+
+void blake2s_init_key(struct blake2s_state *state, const size_t outlen,
+ const uint8_t *key, const size_t keylen);
+
+void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen);
+
+void blake2s_final(struct blake2s_state *state, uint8_t *out);
+
+void blake2s(uint8_t *out, const uint8_t *in, const uint8_t *key,
+ const size_t outlen, const size_t inlen, const size_t keylen);
+
+void blake2s_hmac(uint8_t *out, const uint8_t *in, const uint8_t *key,
+ const size_t outlen, const size_t inlen, const size_t keylen);
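+
+/*
+ * Illustrative, editorial sketch: the streaming interface above yields the
+ * same digest as the one-shot blake2s() over the concatenated input
+ * (part1/part2 and their lengths are made-up names):
+ *
+ *     struct blake2s_state st;
+ *     uint8_t digest[BLAKE2S_HASH_SIZE];
+ *     blake2s_init(&st, BLAKE2S_HASH_SIZE);
+ *     blake2s_update(&st, part1, part1_len);
+ *     blake2s_update(&st, part2, part2_len);
+ *     blake2s_final(&st, digest);
+ */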
+
+enum curve25519_lengths {
+ CURVE25519_KEY_SIZE = 32
+};
+
+bool curve25519(uint8_t mypublic[static CURVE25519_KEY_SIZE],
+ const uint8_t secret[static CURVE25519_KEY_SIZE],
+ const uint8_t basepoint[static CURVE25519_KEY_SIZE]);
+
+static inline bool
+curve25519_generate_public(uint8_t pub[static CURVE25519_KEY_SIZE],
+ const uint8_t secret[static CURVE25519_KEY_SIZE])
+{
+ static const uint8_t basepoint[CURVE25519_KEY_SIZE] = { 9 };
+
+ return curve25519(pub, secret, basepoint);
+}
+
+static inline void curve25519_clamp_secret(uint8_t secret[static CURVE25519_KEY_SIZE])
+{
+ secret[0] &= 248;
+ secret[31] = (secret[31] & 127) | 64;
+}
+
+static inline void curve25519_generate_secret(uint8_t secret[CURVE25519_KEY_SIZE])
+{
+ arc4random_buf(secret, CURVE25519_KEY_SIZE);
+ curve25519_clamp_secret(secret);
+}
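+
+/*
+ * Illustrative, editorial X25519 round trip using the helpers above; all
+ * variable names are made up. Both sides derive the same 32-byte shared
+ * secret, and curve25519() returns false when the result is the all-zero
+ * point, which callers must treat as a failure:
+ *
+ *     uint8_t a_sec[CURVE25519_KEY_SIZE], a_pub[CURVE25519_KEY_SIZE];
+ *     uint8_t b_sec[CURVE25519_KEY_SIZE], b_pub[CURVE25519_KEY_SIZE];
+ *     uint8_t ss1[CURVE25519_KEY_SIZE], ss2[CURVE25519_KEY_SIZE];
+ *     curve25519_generate_secret(a_sec);
+ *     curve25519_generate_public(a_pub, a_sec);
+ *     curve25519_generate_secret(b_sec);
+ *     curve25519_generate_public(b_pub, b_sec);
+ *     if (curve25519(ss1, a_sec, b_pub) && curve25519(ss2, b_sec, a_pub))
+ *             ss1 and ss2 now hold the same shared secret.
+ */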
+
+#endif
diff --git a/src/if_wg.c b/src/if_wg.c
new file mode 100644
index 0000000..1fa7b4c
--- /dev/null
+++ b/src/if_wg.c
@@ -0,0 +1,3451 @@
+/* SPDX-License-Identifier: BSD-2-Clause-FreeBSD
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
+ * Copyright (c) 2021 Kyle Evans <kevans@FreeBSD.org>
+ */
+
+/* TODO audit imports */
+#include "opt_inet.h"
+#include "opt_inet6.h"
+
+#include <sys/cdefs.h>
+__FBSDID("$FreeBSD$");
+
+#include <sys/param.h>
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <vm/uma.h>
+
+#include <sys/mbuf.h>
+#include <sys/socket.h>
+#include <sys/kernel.h>
+
+#include <sys/sockio.h>
+#include <sys/socketvar.h>
+#include <sys/errno.h>
+#include <sys/jail.h>
+#include <sys/priv.h>
+#include <sys/proc.h>
+#include <sys/lock.h>
+#include <sys/rwlock.h>
+#include <sys/rmlock.h>
+#include <sys/protosw.h>
+#include <sys/module.h>
+#include <sys/endian.h>
+#include <sys/kdb.h>
+#include <sys/sx.h>
+#include <sys/sysctl.h>
+#include <sys/gtaskqueue.h>
+#include <sys/smp.h>
+#include <sys/nv.h>
+
+#include <net/bpf.h>
+
+#include <sys/syslog.h>
+
+#include <net/if.h>
+#include <net/if_var.h>
+#include <net/if_clone.h>
+#include <net/if_types.h>
+#include <net/ethernet.h>
+#include <net/radix.h>
+
+#include <netinet/in.h>
+#include <netinet/in_var.h>
+#include <netinet/ip.h>
+#include <netinet/ip_var.h>
+#include <netinet/ip6.h>
+#include <netinet6/ip6_var.h>
+#include <netinet6/scope6_var.h>
+#include <netinet/udp.h>
+#include <netinet/ip_icmp.h>
+#include <netinet/icmp6.h>
+#include <netinet/in_pcb.h>
+#include <netinet6/in6_pcb.h>
+#include <netinet/udp_var.h>
+
+#include <machine/in_cksum.h>
+
+#include "support.h"
+#include "wg_noise.h"
+#include "wg_cookie.h"
+#include "if_wg.h"
+
+/* It'd be nice to use IF_MAXMTU, but that means more complicated mbuf allocations,
+ * so instead just do the biggest mbuf we can easily allocate minus the usual maximum
+ * IPv6 overhead of 80 bytes. If somebody wants bigger frames, we can revisit this. */
+#define MAX_MTU (MJUM16BYTES - 80)
+
+#define DEFAULT_MTU 1420
+
+#define MAX_STAGED_PKT 128
+#define MAX_QUEUED_PKT 1024
+#define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1)
+
+#define MAX_QUEUED_HANDSHAKES 4096
+
+#define HASHTABLE_PEER_SIZE (1 << 11)
+#define HASHTABLE_INDEX_SIZE (1 << 13)
+#define MAX_PEERS_PER_IFACE (1 << 20)
+
+#define REKEY_TIMEOUT 5
+#define REKEY_TIMEOUT_JITTER 334 /* msec; ~1/3 second, rounded up for arc4random_uniform() */
+#define KEEPALIVE_TIMEOUT 10
+#define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT)
+#define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
+#define UNDERLOAD_TIMEOUT 1
+
+#define DPRINTF(sc, ...) do { if (wireguard_debug) if_printf(sc->sc_ifp, ##__VA_ARGS__); } while (0)
+
+/* First byte indicating packet type on the wire */
+#define WG_PKT_INITIATION htole32(1)
+#define WG_PKT_RESPONSE htole32(2)
+#define WG_PKT_COOKIE htole32(3)
+#define WG_PKT_DATA htole32(4)
+
+#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1)))
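+/* Rounds n up to the next multiple of 16, the WireGuard padding quantum:
+ * WG_PKT_WITH_PADDING(0) == 0, WG_PKT_WITH_PADDING(1) == 16 and
+ * WG_PKT_WITH_PADDING(1420) == 1424 (editorial note; padding keeps exact
+ * plaintext lengths from being visible on the wire). */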
+#define WG_KEY_SIZE 32
+
+struct wg_pkt_initiation {
+ uint32_t t;
+ uint32_t s_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_response {
+ uint32_t t;
+ uint32_t s_idx;
+ uint32_t r_idx;
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN];
+ uint8_t en[0 + NOISE_AUTHTAG_LEN];
+ struct cookie_macs m;
+};
+
+struct wg_pkt_cookie {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[COOKIE_NONCE_SIZE];
+ uint8_t ec[COOKIE_ENCRYPTED_SIZE];
+};
+
+struct wg_pkt_data {
+ uint32_t t;
+ uint32_t r_idx;
+ uint8_t nonce[sizeof(uint64_t)];
+ uint8_t buf[];
+};
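+
+/*
+ * Editorial summary of the wire formats above: every message begins with a
+ * little-endian 32-bit type (WG_PKT_*). A data packet is the 16-byte
+ * header (type, receiver index, 64-bit little-endian counter that doubles
+ * as the AEAD nonce) followed by the ciphertext, so an encapsulated packet
+ * occupies sizeof(struct wg_pkt_data) + padded_len + NOISE_AUTHTAG_LEN
+ * bytes; see wg_encap() below.
+ */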
+
+struct wg_endpoint {
+ union {
+ struct sockaddr r_sa;
+ struct sockaddr_in r_sin;
+#ifdef INET6
+ struct sockaddr_in6 r_sin6;
+#endif
+ } e_remote;
+ union {
+ struct in_addr l_in;
+#ifdef INET6
+ struct in6_pktinfo l_pktinfo6;
+#define l_in6 l_pktinfo6.ipi6_addr
+#endif
+ } e_local;
+};
+
+struct wg_tag {
+ struct m_tag t_tag;
+ struct wg_endpoint t_endpoint;
+ struct wg_peer *t_peer;
+ struct mbuf *t_mbuf;
+ int t_done;
+ int t_mtu;
+};
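+
+/*
+ * Editorial note: a wg_tag rides on its mbuf through the parallel
+ * encrypt/decrypt stage. t_peer, t_endpoint and t_mtu carry context into
+ * the crypto workers, t_mbuf receives the transformed packet (left NULL on
+ * failure), and t_done flags completion so the serial wg_deliver_{in,out}
+ * paths can consume packets in their original order.
+ */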
+
+struct wg_index {
+ LIST_ENTRY(wg_index) i_entry;
+ SLIST_ENTRY(wg_index) i_unused_entry;
+ uint32_t i_key;
+ struct noise_remote *i_value;
+};
+
+struct wg_timers {
+ /* t_lock is for blocking wg_timers_event_* when setting t_disabled. */
+ struct rwlock t_lock;
+
+ int t_disabled;
+ int t_need_another_keepalive;
+ uint16_t t_persistent_keepalive_interval;
+ struct callout t_new_handshake;
+ struct callout t_send_keepalive;
+ struct callout t_retry_handshake;
+ struct callout t_zero_key_material;
+ struct callout t_persistent_keepalive;
+
+ struct mtx t_handshake_mtx;
+ struct timespec t_handshake_last_sent;
+ struct timespec t_handshake_complete;
+ volatile int t_handshake_retries;
+};
+
+struct wg_aip {
+ struct radix_node r_nodes[2];
+ CK_LIST_ENTRY(wg_aip) r_entry;
+ struct sockaddr_storage r_addr;
+ struct sockaddr_storage r_mask;
+ struct wg_peer *r_peer;
+};
+
+struct wg_queue {
+ struct mtx q_mtx;
+ struct mbufq q;
+};
+
+struct wg_peer {
+ CK_LIST_ENTRY(wg_peer) p_hash_entry;
+ CK_LIST_ENTRY(wg_peer) p_entry;
+ uint64_t p_id;
+ struct wg_softc *p_sc;
+
+ struct noise_remote p_remote;
+ struct cookie_maker p_cookie;
+ struct wg_timers p_timers;
+
+ struct rwlock p_endpoint_lock;
+ struct wg_endpoint p_endpoint;
+
+ SLIST_HEAD(,wg_index) p_unused_index;
+ struct wg_index p_index[3];
+
+ struct wg_queue p_stage_queue;
+ struct wg_queue p_encap_queue;
+ struct wg_queue p_decap_queue;
+
+ struct grouptask p_clear_secrets;
+ struct grouptask p_send_initiation;
+ struct grouptask p_send_keepalive;
+ struct grouptask p_send;
+ struct grouptask p_recv;
+
+ counter_u64_t p_tx_bytes;
+ counter_u64_t p_rx_bytes;
+
+ CK_LIST_HEAD(, wg_aip) p_aips;
+ struct mtx p_lock;
+ struct epoch_context p_ctx;
+};
+
+enum route_direction {
+ /* TODO OpenBSD doesn't use IN/OUT, instead passes the address buffer
+ * directly to route_lookup. */
+ IN,
+ OUT,
+};
+
+struct wg_aip_table {
+ size_t t_count;
+ struct radix_node_head *t_ip;
+ struct radix_node_head *t_ip6;
+};
+
+struct wg_allowedip {
+ uint16_t family;
+ union {
+ struct in_addr ip4;
+ struct in6_addr ip6;
+ };
+ uint8_t cidr;
+};
+
+struct wg_hashtable {
+ struct mtx h_mtx;
+ SIPHASH_KEY h_secret;
+ CK_LIST_HEAD(, wg_peer) h_peers_list;
+ CK_LIST_HEAD(, wg_peer) *h_peers;
+ u_long h_peers_mask;
+ size_t h_num_peers;
+};
+
+struct wg_socket {
+ struct mtx so_mtx;
+ struct socket *so_so4;
+ struct socket *so_so6;
+ uint32_t so_user_cookie;
+ in_port_t so_port;
+};
+
+struct wg_softc {
+ LIST_ENTRY(wg_softc) sc_entry;
+ struct ifnet *sc_ifp;
+ int sc_flags;
+
+ struct ucred *sc_ucred;
+ struct wg_socket sc_socket;
+ struct wg_hashtable sc_hashtable;
+ struct wg_aip_table sc_aips;
+
+ struct mbufq sc_handshake_queue;
+ struct grouptask sc_handshake;
+
+ struct noise_local sc_local;
+ struct cookie_checker sc_cookie;
+
+ struct buf_ring *sc_encap_ring;
+ struct buf_ring *sc_decap_ring;
+
+ struct grouptask *sc_encrypt;
+ struct grouptask *sc_decrypt;
+
+ struct rwlock sc_index_lock;
+ LIST_HEAD(,wg_index) *sc_index;
+ u_long sc_index_mask;
+
+ struct sx sc_lock;
+ volatile u_int sc_peer_count;
+};
+
+#define WGF_DYING 0x0001
+
+/* TODO the following defines are FreeBSD-specific; we should see what is
+ * necessary and clean up from there (I suspect a lot can be junked). */
+
+#ifndef ENOKEY
+#define ENOKEY ENOTCAPABLE
+#endif
+
+#if __FreeBSD_version > 1300000
+typedef void timeout_t (void *);
+#endif
+
+#define GROUPTASK_DRAIN(gtask) \
+ gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
+
+#define MTAG_WIREGUARD 0xBEAD
+#define M_ENQUEUED M_PROTO1
+
+static int clone_count;
+static uma_zone_t ratelimit_zone;
+static int wireguard_debug;
+static volatile unsigned long peer_counter = 0;
+static const char wgname[] = "wg";
+static unsigned wg_osd_jail_slot;
+
+static struct sx wg_sx;
+SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
+
+static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
+
+SYSCTL_NODE(_net, OID_AUTO, wg, CTLFLAG_RW, 0, "WireGuard");
+SYSCTL_INT(_net_wg, OID_AUTO, debug, CTLFLAG_RWTUN, &wireguard_debug, 0,
+ "enable debug logging");
+
+TASKQGROUP_DECLARE(if_io_tqg);
+
+MALLOC_DEFINE(M_WG, "WG", "wireguard");
+VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
+
+#define V_wg_cloner VNET(wg_cloner)
+#define WG_CAPS IFCAP_LINKSTATE
+#define ph_family PH_loc.eight[5]
+
+struct wg_timespec64 {
+ uint64_t tv_sec;
+ uint64_t tv_nsec;
+};
+
+struct wg_peer_export {
+ struct sockaddr_storage endpoint;
+ struct timespec last_handshake;
+ uint8_t public_key[WG_KEY_SIZE];
+ uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
+ size_t endpoint_sz;
+ struct wg_allowedip *aip;
+ uint64_t rx_bytes;
+ uint64_t tx_bytes;
+ int aip_count;
+ uint16_t persistent_keepalive;
+};
+
+static struct wg_tag *wg_tag_get(struct mbuf *);
+static struct wg_endpoint *wg_mbuf_endpoint_get(struct mbuf *);
+static int wg_socket_init(struct wg_softc *, in_port_t);
+static int wg_socket_bind(struct socket *, struct socket *, in_port_t *);
+static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
+static void wg_socket_uninit(struct wg_softc *);
+static void wg_socket_set_cookie(struct wg_softc *, uint32_t);
+static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
+static void wg_timers_event_data_sent(struct wg_timers *);
+static void wg_timers_event_data_received(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_sent(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_received(struct wg_timers *);
+static void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *);
+static void wg_timers_event_handshake_initiated(struct wg_timers *);
+static void wg_timers_event_handshake_responded(struct wg_timers *);
+static void wg_timers_event_handshake_complete(struct wg_timers *);
+static void wg_timers_event_session_derived(struct wg_timers *);
+static void wg_timers_event_want_initiation(struct wg_timers *);
+static void wg_timers_event_reset_handshake_last_sent(struct wg_timers *);
+static void wg_timers_run_send_initiation(struct wg_timers *, int);
+static void wg_timers_run_retry_handshake(struct wg_timers *);
+static void wg_timers_run_send_keepalive(struct wg_timers *);
+static void wg_timers_run_new_handshake(struct wg_timers *);
+static void wg_timers_run_zero_key_material(struct wg_timers *);
+static void wg_timers_run_persistent_keepalive(struct wg_timers *);
+static void wg_timers_init(struct wg_timers *);
+static void wg_timers_enable(struct wg_timers *);
+static void wg_timers_disable(struct wg_timers *);
+static void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t);
+static void wg_timers_get_last_handshake(struct wg_timers *, struct timespec *);
+static int wg_timers_expired_handshake_last_sent(struct wg_timers *);
+static int wg_timers_check_handshake_last_sent(struct wg_timers *);
+static void wg_queue_init(struct wg_queue *, const char *);
+static void wg_queue_deinit(struct wg_queue *);
+static void wg_queue_purge(struct wg_queue *);
+static struct mbuf *wg_queue_dequeue(struct wg_queue *, struct wg_tag **);
+static int wg_queue_len(struct wg_queue *);
+static int wg_queue_in(struct wg_peer *, struct mbuf *);
+static void wg_queue_out(struct wg_peer *);
+static void wg_queue_stage(struct wg_peer *, struct mbuf *);
+static int wg_aip_init(struct wg_aip_table *);
+static void wg_aip_destroy(struct wg_aip_table *);
+static void wg_aip_populate_aip4(struct wg_aip *, const struct in_addr *, uint8_t);
+static void wg_aip_populate_aip6(struct wg_aip *, const struct in6_addr *, uint8_t);
+static int wg_aip_add(struct wg_aip_table *, struct wg_peer *, const struct wg_allowedip *);
+static int wg_peer_remove(struct radix_node *, void *);
+static void wg_peer_remove_all(struct wg_softc *);
+static int wg_aip_delete(struct wg_aip_table *, struct wg_peer *);
+static struct wg_peer *wg_aip_lookup(struct wg_aip_table *, struct mbuf *, enum route_direction);
+static void wg_hashtable_init(struct wg_hashtable *);
+static void wg_hashtable_destroy(struct wg_hashtable *);
+static void wg_hashtable_peer_insert(struct wg_hashtable *, struct wg_peer *);
+static struct wg_peer *wg_peer_lookup(struct wg_softc *, const uint8_t [32]);
+static void wg_hashtable_peer_remove(struct wg_hashtable *, struct wg_peer *);
+static int wg_cookie_validate_packet(struct cookie_checker *, struct mbuf *, int);
+static struct wg_peer *wg_peer_alloc(struct wg_softc *);
+static void wg_peer_free_deferred(epoch_context_t);
+static void wg_peer_destroy(struct wg_peer *);
+static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
+static void wg_send_initiation(struct wg_peer *);
+static void wg_send_response(struct wg_peer *);
+static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct mbuf *);
+static void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *);
+static void wg_peer_clear_src(struct wg_peer *);
+static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
+static void wg_deliver_out(struct wg_peer *);
+static void wg_deliver_in(struct wg_peer *);
+static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
+static void wg_send_keepalive(struct wg_peer *);
+static void wg_handshake(struct wg_softc *, struct mbuf *);
+static void wg_encap(struct wg_softc *, struct mbuf *);
+static void wg_decap(struct wg_softc *, struct mbuf *);
+static void wg_softc_handshake_receive(struct wg_softc *);
+static void wg_softc_decrypt(struct wg_softc *);
+static void wg_softc_encrypt(struct wg_softc *);
+static struct noise_remote *wg_remote_get(struct wg_softc *, uint8_t [NOISE_PUBLIC_KEY_LEN]);
+static uint32_t wg_index_set(struct wg_softc *, struct noise_remote *);
+static struct noise_remote *wg_index_get(struct wg_softc *, uint32_t);
+static void wg_index_drop(struct wg_softc *, uint32_t);
+static int wg_update_endpoint_addrs(struct wg_endpoint *, const struct sockaddr *, struct ifnet *);
+static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
+static void wg_encrypt_dispatch(struct wg_softc *);
+static void wg_decrypt_dispatch(struct wg_softc *);
+static void crypto_taskq_setup(struct wg_softc *);
+static void crypto_taskq_destroy(struct wg_softc *);
+static int wg_clone_create(struct if_clone *, int, caddr_t);
+static void wg_qflush(struct ifnet *);
+static int wg_transmit(struct ifnet *, struct mbuf *);
+static int wg_output(struct ifnet *, struct mbuf *, const struct sockaddr *, struct route *);
+static void wg_clone_destroy(struct ifnet *);
+static int wg_peer_to_export(struct wg_peer *, struct wg_peer_export *);
+static bool wgc_privileged(struct wg_softc *);
+static int wgc_get(struct wg_softc *, struct wg_data_io *);
+static int wgc_set(struct wg_softc *, struct wg_data_io *);
+static int wg_up(struct wg_softc *);
+static void wg_down(struct wg_softc *);
+static void wg_reassign(struct ifnet *, struct vnet *, char *unused);
+static void wg_init(void *);
+static int wg_ioctl(struct ifnet *, u_long, caddr_t);
+static void vnet_wg_init(const void *);
+static void vnet_wg_uninit(const void *);
+static void wg_module_init(void);
+static void wg_module_deinit(void);
+
+/* TODO Peer */
+static struct wg_peer *
+wg_peer_alloc(struct wg_softc *sc)
+{
+ struct wg_peer *peer;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ peer = malloc(sizeof(*peer), M_WG, M_WAITOK|M_ZERO);
+ peer->p_sc = sc;
+ peer->p_id = peer_counter++;
+ CK_LIST_INIT(&peer->p_aips);
+
+ rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
+ wg_queue_init(&peer->p_stage_queue, "stageq");
+ wg_queue_init(&peer->p_encap_queue, "txq");
+ wg_queue_init(&peer->p_decap_queue, "rxq");
+
+ GROUPTASK_INIT(&peer->p_send_initiation, 0, (gtask_fn_t *)wg_send_initiation, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_initiation, peer, NULL, NULL, "wg initiation");
+ GROUPTASK_INIT(&peer->p_send_keepalive, 0, (gtask_fn_t *)wg_send_keepalive, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_keepalive, peer, NULL, NULL, "wg keepalive");
+ GROUPTASK_INIT(&peer->p_clear_secrets, 0, (gtask_fn_t *)noise_remote_clear, &peer->p_remote);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_clear_secrets,
+ &peer->p_remote, NULL, NULL, "wg clear secrets");
+
+ GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
+ GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
+ taskqgroup_attach(qgroup_if_io_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
+
+ wg_timers_init(&peer->p_timers);
+
+ peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
+ peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
+
+ SLIST_INIT(&peer->p_unused_index);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1],
+ i_unused_entry);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2],
+ i_unused_entry);
+
+ return (peer);
+}
+
+#define WG_HASHTABLE_PEER_FOREACH(peer, i, ht) \
+ for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \
+ LIST_FOREACH(peer, &(ht)->h_peers[i], p_hash_entry)
+#define WG_HASHTABLE_PEER_FOREACH_SAFE(peer, i, ht, tpeer) \
+ for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \
+ CK_LIST_FOREACH_SAFE(peer, &(ht)->h_peers[i], p_hash_entry, tpeer)
+static void
+wg_hashtable_init(struct wg_hashtable *ht)
+{
+ mtx_init(&ht->h_mtx, "hash lock", NULL, MTX_DEF);
+ arc4random_buf(&ht->h_secret, sizeof(ht->h_secret));
+ ht->h_num_peers = 0;
+ ht->h_peers = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF,
+ &ht->h_peers_mask);
+}
+
+static void
+wg_hashtable_destroy(struct wg_hashtable *ht)
+{
+ MPASS(ht->h_num_peers == 0);
+ mtx_destroy(&ht->h_mtx);
+ hashdestroy(ht->h_peers, M_DEVBUF, ht->h_peers_mask);
+}
+
+static void
+wg_hashtable_peer_insert(struct wg_hashtable *ht, struct wg_peer *peer)
+{
+ uint64_t key;
+
+ key = siphash24(&ht->h_secret, peer->p_remote.r_public,
+ sizeof(peer->p_remote.r_public));
+
+ mtx_lock(&ht->h_mtx);
+ ht->h_num_peers++;
+ CK_LIST_INSERT_HEAD(&ht->h_peers[key & ht->h_peers_mask], peer, p_hash_entry);
+ CK_LIST_INSERT_HEAD(&ht->h_peers_list, peer, p_entry);
+ mtx_unlock(&ht->h_mtx);
+}
+
+static struct wg_peer *
+wg_peer_lookup(struct wg_softc *sc,
+ const uint8_t pubkey[WG_KEY_SIZE])
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ uint64_t key;
+ struct wg_peer *i = NULL;
+
+ key = siphash24(&ht->h_secret, pubkey, WG_KEY_SIZE);
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(i, &ht->h_peers[key & ht->h_peers_mask], p_hash_entry) {
+ if (timingsafe_bcmp(i->p_remote.r_public, pubkey,
+ WG_KEY_SIZE) == 0)
+ break;
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ return i;
+}
+
+static void
+wg_hashtable_peer_remove(struct wg_hashtable *ht, struct wg_peer *peer)
+{
+ mtx_lock(&ht->h_mtx);
+ ht->h_num_peers--;
+ CK_LIST_REMOVE(peer, p_hash_entry);
+ CK_LIST_REMOVE(peer, p_entry);
+ mtx_unlock(&ht->h_mtx);
+}
+
+static void
+wg_peer_free_deferred(epoch_context_t ctx)
+{
+ struct wg_peer *peer = __containerof(ctx, struct wg_peer, p_ctx);
+ counter_u64_free(peer->p_tx_bytes);
+ counter_u64_free(peer->p_rx_bytes);
+ rw_destroy(&peer->p_timers.t_lock);
+ rw_destroy(&peer->p_endpoint_lock);
+ free(peer, M_WG);
+}
+
+static void
+wg_peer_destroy(struct wg_peer *peer)
+{
+ /* Callers should already have called:
+ * wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ */
+ wg_aip_delete(&peer->p_sc->sc_aips, peer);
+ MPASS(CK_LIST_EMPTY(&peer->p_aips));
+
+ /* We disable all timers, so we can't call the following tasks. */
+ wg_timers_disable(&peer->p_timers);
+
+ /* Ensure the tasks have finished running */
+ GROUPTASK_DRAIN(&peer->p_clear_secrets);
+ GROUPTASK_DRAIN(&peer->p_send_initiation);
+ GROUPTASK_DRAIN(&peer->p_send_keepalive);
+ GROUPTASK_DRAIN(&peer->p_recv);
+ GROUPTASK_DRAIN(&peer->p_send);
+
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_clear_secrets);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_initiation);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_keepalive);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_recv);
+ taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send);
+
+ wg_queue_deinit(&peer->p_decap_queue);
+ wg_queue_deinit(&peer->p_encap_queue);
+ wg_queue_deinit(&peer->p_stage_queue);
+
+ /* Final cleanup */
+ --peer->p_sc->sc_peer_count;
+ noise_remote_clear(&peer->p_remote);
+ DPRINTF(peer->p_sc, "Peer %llu destroyed\n", (unsigned long long)peer->p_id);
+ NET_EPOCH_CALL(wg_peer_free_deferred, &peer->p_ctx);
+}
+
+static void
+wg_peer_set_endpoint_from_tag(struct wg_peer *peer, struct wg_tag *t)
+{
+ struct wg_endpoint *e = &t->t_endpoint;
+
+ MPASS(e->e_remote.r_sa.sa_family != 0);
+ if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
+ return;
+
+ peer->p_endpoint = *e;
+}
+
+static void
+wg_peer_clear_src(struct wg_peer *peer)
+{
+ rw_rlock(&peer->p_endpoint_lock);
+ bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
+ rw_runlock(&peer->p_endpoint_lock);
+}
+
+static void
+wg_peer_get_endpoint(struct wg_peer *p, struct wg_endpoint *e)
+{
+ memcpy(e, &p->p_endpoint, sizeof(*e));
+}
+
+/* Allowed IP */
+static int
+wg_aip_init(struct wg_aip_table *tbl)
+{
+ int rc;
+
+ tbl->t_count = 0;
+ rc = rn_inithead((void **)&tbl->t_ip,
+ offsetof(struct sockaddr_in, sin_addr) * NBBY);
+
+ if (rc == 0)
+ return (ENOMEM);
+ RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip);
+#ifdef INET6
+ rc = rn_inithead((void **)&tbl->t_ip6,
+ offsetof(struct sockaddr_in6, sin6_addr) * NBBY);
+ if (rc == 0) {
+ free(tbl->t_ip, M_RTABLE);
+ return (ENOMEM);
+ }
+ RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip6);
+#endif
+ return (0);
+}
+
+static void
+wg_aip_destroy(struct wg_aip_table *tbl)
+{
+ RADIX_NODE_HEAD_DESTROY(tbl->t_ip);
+ free(tbl->t_ip, M_RTABLE);
+#ifdef INET6
+ RADIX_NODE_HEAD_DESTROY(tbl->t_ip6);
+ free(tbl->t_ip6, M_RTABLE);
+#endif
+}
+
+static void
+wg_aip_populate_aip4(struct wg_aip *aip, const struct in_addr *addr,
+ uint8_t mask)
+{
+ struct sockaddr_in *raddr, *rmask;
+ uint8_t *p;
+ unsigned int i;
+
+ raddr = (struct sockaddr_in *)&aip->r_addr;
+ rmask = (struct sockaddr_in *)&aip->r_mask;
+
+ raddr->sin_len = sizeof(*raddr);
+ raddr->sin_family = AF_INET;
+ raddr->sin_addr = *addr;
+
+ rmask->sin_len = sizeof(*rmask);
+ p = (uint8_t *)&rmask->sin_addr.s_addr;
+ for (i = 0; i < mask / NBBY; i++)
+ p[i] = 0xff;
+ if ((mask % NBBY) != 0)
+ p[i] = (0xff00 >> (mask % NBBY)) & 0xff;
+ raddr->sin_addr.s_addr &= rmask->sin_addr.s_addr;
+}
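+
+/*
+ * Worked example for the mask construction above (editorial; the /20 is
+ * chosen arbitrarily): mask = 20 sets the first 20/8 = 2 bytes to 0xff and
+ * the partial byte to (0xff00 >> 4) & 0xff = 0xf0, i.e. the netmask
+ * 255.255.240.0. The address is then canonicalized by ANDing with the
+ * mask, so equal prefixes always compare equal in the radix tree.
+ */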
+
+static void
+wg_aip_populate_aip6(struct wg_aip *aip, const struct in6_addr *addr,
+ uint8_t mask)
+{
+ struct sockaddr_in6 *raddr, *rmask;
+
+ raddr = (struct sockaddr_in6 *)&aip->r_addr;
+ rmask = (struct sockaddr_in6 *)&aip->r_mask;
+
+ raddr->sin6_len = sizeof(*raddr);
+ raddr->sin6_family = AF_INET6;
+ raddr->sin6_addr = *addr;
+
+ rmask->sin6_len = sizeof(*rmask);
+ in6_prefixlen2mask(&rmask->sin6_addr, mask);
+ for (int i = 0; i < 4; ++i)
+ raddr->sin6_addr.__u6_addr.__u6_addr32[i] &= rmask->sin6_addr.__u6_addr.__u6_addr32[i];
+}
+
+/* wg_aip_take assumes that the caller guarantees the allowed-ip exists. */
+static void
+wg_aip_take(struct radix_node_head *root, struct wg_peer *peer,
+ struct wg_aip *route)
+{
+ struct radix_node *node;
+ struct wg_peer *ppeer;
+
+ RADIX_NODE_HEAD_LOCK_ASSERT(root);
+
+ node = root->rnh_lookup(&route->r_addr, &route->r_mask,
+ &root->rh);
+ MPASS(node != NULL);
+
+ route = (struct wg_aip *)node;
+ ppeer = route->r_peer;
+ if (ppeer != peer) {
+ route->r_peer = peer;
+
+ CK_LIST_REMOVE(route, r_entry);
+ CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry);
+ }
+}
+
+static int
+wg_aip_add(struct wg_aip_table *tbl, struct wg_peer *peer,
+ const struct wg_allowedip *aip)
+{
+ struct radix_node *node;
+ struct radix_node_head *root;
+ struct wg_aip *route;
+ sa_family_t family;
+ bool needfree = false;
+
+ family = aip->family;
+ if (family != AF_INET && family != AF_INET6) {
+ return (EINVAL);
+ }
+
+ route = malloc(sizeof(*route), M_WG, M_WAITOK|M_ZERO);
+ switch (family) {
+ case AF_INET:
+ root = tbl->t_ip;
+
+ wg_aip_populate_aip4(route, &aip->ip4, aip->cidr);
+ break;
+ case AF_INET6:
+ root = tbl->t_ip6;
+
+ wg_aip_populate_aip6(route, &aip->ip6, aip->cidr);
+ break;
+ }
+
+ route->r_peer = peer;
+
+ RADIX_NODE_HEAD_LOCK(root);
+ node = root->rnh_addaddr(&route->r_addr, &route->r_mask, &root->rh,
+ route->r_nodes);
+ if (node == route->r_nodes) {
+ tbl->t_count++;
+ CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry);
+ } else {
+ needfree = true;
+ wg_aip_take(root, peer, route);
+ }
+ RADIX_NODE_HEAD_UNLOCK(root);
+ if (needfree) {
+ free(route, M_WG);
+ }
+ return (0);
+}
+
+static struct wg_peer *
+wg_aip_lookup(struct wg_aip_table *tbl, struct mbuf *m,
+ enum route_direction dir)
+{
+ RADIX_NODE_HEAD_RLOCK_TRACKER;
+ struct ip *iphdr;
+ struct ip6_hdr *ip6hdr;
+ struct radix_node_head *root;
+ struct radix_node *node;
+ struct wg_peer *peer = NULL;
+ struct sockaddr_in sin;
+ struct sockaddr_in6 sin6;
+ void *addr;
+ int version;
+
+ NET_EPOCH_ASSERT();
+ iphdr = mtod(m, struct ip *);
+ version = iphdr->ip_v;
+
+ if (__predict_false(dir != IN && dir != OUT))
+ return NULL;
+
+ if (version == 4) {
+ root = tbl->t_ip;
+ memset(&sin, 0, sizeof(sin));
+ sin.sin_len = sizeof(struct sockaddr_in);
+ if (dir == IN)
+ sin.sin_addr = iphdr->ip_src;
+ else
+ sin.sin_addr = iphdr->ip_dst;
+ addr = &sin;
+ } else if (version == 6) {
+ ip6hdr = mtod(m, struct ip6_hdr *);
+ memset(&sin6, 0, sizeof(sin6));
+ sin6.sin6_len = sizeof(struct sockaddr_in6);
+
+ root = tbl->t_ip6;
+ if (dir == IN)
+ addr = &ip6hdr->ip6_src;
+ else
+ addr = &ip6hdr->ip6_dst;
+ memcpy(&sin6.sin6_addr, addr, sizeof(sin6.sin6_addr));
+ addr = &sin6;
+ } else {
+ return (NULL);
+ }
+ RADIX_NODE_HEAD_RLOCK(root);
+ if ((node = root->rnh_matchaddr(addr, &root->rh)) != NULL) {
+ peer = ((struct wg_aip *) node)->r_peer;
+ }
+ RADIX_NODE_HEAD_RUNLOCK(root);
+ return (peer);
+}
+
+struct peer_del_arg {
+ struct radix_node_head * pda_head;
+ struct wg_peer *pda_peer;
+ struct wg_aip_table *pda_tbl;
+};
+
+static int
+wg_peer_remove(struct radix_node *rn, void *arg)
+{
+ struct peer_del_arg *pda = arg;
+ struct wg_peer *peer = pda->pda_peer;
+ struct radix_node_head * rnh = pda->pda_head;
+ struct wg_aip_table *tbl = pda->pda_tbl;
+ struct wg_aip *route = (struct wg_aip *)rn;
+ struct radix_node *x;
+
+ if (route->r_peer != peer)
+ return (0);
+ x = (struct radix_node *)rnh->rnh_deladdr(&route->r_addr,
+ &route->r_mask, &rnh->rh);
+ if (x != NULL) {
+ tbl->t_count--;
+ CK_LIST_REMOVE(route, r_entry);
+ free(route, M_WG);
+ }
+ return (0);
+}
+
+static void
+wg_peer_remove_all(struct wg_softc *sc)
+{
+ struct wg_peer *peer, *tpeer;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ CK_LIST_FOREACH_SAFE(peer, &sc->sc_hashtable.h_peers_list,
+ p_entry, tpeer) {
+ wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ wg_peer_destroy(peer);
+ }
+}
+
+static int
+wg_aip_delete(struct wg_aip_table *tbl, struct wg_peer *peer)
+{
+ struct peer_del_arg pda;
+
+ pda.pda_peer = peer;
+ pda.pda_tbl = tbl;
+ RADIX_NODE_HEAD_LOCK(tbl->t_ip);
+ pda.pda_head = tbl->t_ip;
+ rn_walktree(&tbl->t_ip->rh, wg_peer_remove, &pda);
+ RADIX_NODE_HEAD_UNLOCK(tbl->t_ip);
+
+ RADIX_NODE_HEAD_LOCK(tbl->t_ip6);
+ pda.pda_head = tbl->t_ip6;
+ rn_walktree(&tbl->t_ip6->rh, wg_peer_remove, &pda);
+ RADIX_NODE_HEAD_UNLOCK(tbl->t_ip6);
+ return (0);
+}
+
+static int
+wg_socket_init(struct wg_softc *sc, in_port_t port)
+{
+ struct thread *td;
+ struct ucred *cred;
+ struct socket *so4, *so6;
+ int rc;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ so4 = so6 = NULL;
+ td = curthread;
+ if ((cred = sc->sc_ucred) == NULL)
+ return (EBUSY);
+
+ /*
+ * For socket creation, we use the creds of the thread that created the
+ * tunnel rather than the current thread to maintain the semantics that
+ * WireGuard has on Linux with network namespaces -- that the sockets
+ * are created in their home vnet so that they can be configured and
+ * functionally attached to a foreign vnet as the jail's only interface
+ * to the network.
+ */
+ rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, td);
+ if (rc != 0)
+ goto out;
+
+ rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
+ /*
+ * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
+ * This should never happen with a new socket.
+ */
+ MPASS(rc == 0);
+
+ rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, td);
+ if (rc != 0)
+ goto out;
+ rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
+ MPASS(rc == 0);
+
+ so4->so_user_cookie = so6->so_user_cookie = sc->sc_socket.so_user_cookie;
+
+ rc = wg_socket_bind(so4, so6, &port);
+ if (rc == 0) {
+ sc->sc_socket.so_port = port;
+ wg_socket_set(sc, so4, so6);
+ }
+out:
+ if (rc != 0) {
+ if (so4 != NULL)
+ soclose(so4);
+ if (so6 != NULL)
+ soclose(so6);
+ }
+ return (rc);
+}
+
+static void
+wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
+{
+ struct wg_socket *so = &sc->sc_socket;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ so->so_user_cookie = user_cookie;
+ if (so->so_so4)
+ so->so_so4->so_user_cookie = user_cookie;
+ if (so->so_so6)
+ so->so_so6->so_user_cookie = user_cookie;
+}
+
+static void
+wg_socket_uninit(struct wg_softc *sc)
+{
+ wg_socket_set(sc, NULL, NULL);
+}
+
+static void
+wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
+{
+ struct wg_socket *so = &sc->sc_socket;
+ struct socket *so4, *so6;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ so4 = atomic_load_ptr(&so->so_so4);
+ so6 = atomic_load_ptr(&so->so_so6);
+ atomic_store_ptr(&so->so_so4, new_so4);
+ atomic_store_ptr(&so->so_so6, new_so6);
+
+ if (!so4 && !so6)
+ return;
+ NET_EPOCH_WAIT();
+ if (so4)
+ soclose(so4);
+ if (so6)
+ soclose(so6);
+}
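+
+/*
+ * Editorial note: wg_socket_set() above is the only writer of so_so4 and
+ * so_so6. Readers such as wg_send() load the pointers inside a net epoch
+ * section, so one NET_EPOCH_WAIT() after the swap guarantees that no
+ * reader still holds the old sockets by the time they are closed.
+ */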
+
+union wg_sockaddr {
+ struct sockaddr sa;
+ struct sockaddr_in in4;
+ struct sockaddr_in6 in6;
+};
+
+static int
+wg_socket_bind(struct socket *so4, struct socket *so6, in_port_t *requested_port)
+{
+ int rc;
+ struct thread *td;
+ union wg_sockaddr laddr;
+ struct sockaddr_in *sin;
+ struct sockaddr_in6 *sin6;
+ in_port_t port = *requested_port;
+
+ td = curthread;
+ bzero(&laddr, sizeof(laddr));
+ sin = &laddr.in4;
+ sin->sin_len = sizeof(laddr.in4);
+ sin->sin_family = AF_INET;
+ sin->sin_port = htons(port);
+ sin->sin_addr = (struct in_addr) { 0 };
+
+ if ((rc = sobind(so4, &laddr.sa, td)) != 0)
+ return (rc);
+
+ if (port == 0) {
+ rc = sogetsockaddr(so4, (struct sockaddr **)&sin);
+ if (rc != 0)
+ return (rc);
+ port = ntohs(sin->sin_port);
+ free(sin, M_SONAME);
+ }
+
+ sin6 = &laddr.in6;
+ sin6->sin6_len = sizeof(laddr.in6);
+ sin6->sin6_family = AF_INET6;
+ sin6->sin6_port = htons(port);
+ sin6->sin6_addr = (struct in6_addr) { .s6_addr = { 0 } };
+ rc = sobind(so6, &laddr.sa, td);
+ if (rc != 0)
+ return (rc);
+ *requested_port = port;
+ return (0);
+}
+
+static int
+wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
+{
+ struct epoch_tracker et;
+ struct sockaddr *sa;
+ struct wg_socket *so = &sc->sc_socket;
+ struct socket *so4, *so6;
+ struct mbuf *control = NULL;
+ int ret = 0;
+ size_t len = m->m_pkthdr.len;
+
+ /* Get local control address before locking */
+ if (e->e_remote.r_sa.sa_family == AF_INET) {
+ if (e->e_local.l_in.s_addr != INADDR_ANY)
+ control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
+ sizeof(struct in_addr), IP_SENDSRCADDR,
+ IPPROTO_IP);
+#ifdef INET6
+ } else if (e->e_remote.r_sa.sa_family == AF_INET6) {
+ if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
+ control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
+ sizeof(struct in6_pktinfo), IPV6_PKTINFO,
+ IPPROTO_IPV6);
+#endif
+ } else {
+ m_freem(m);
+ return (EAFNOSUPPORT);
+ }
+
+ /* Get remote address */
+ sa = &e->e_remote.r_sa;
+
+ NET_EPOCH_ENTER(et);
+ so4 = atomic_load_ptr(&so->so_so4);
+ so6 = atomic_load_ptr(&so->so_so6);
+ if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
+ ret = sosend(so4, sa, NULL, m, control, 0, curthread);
+ else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
+ ret = sosend(so6, sa, NULL, m, control, 0, curthread);
+ else {
+ ret = ENOTCONN;
+ m_freem(control);
+ m_freem(m);
+ }
+ NET_EPOCH_EXIT(et);
+ if (ret == 0) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
+ }
+ return (ret);
+}
+
+static void
+wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf,
+ size_t len)
+{
+ struct mbuf *m;
+ int ret = 0;
+
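+ /*
+ * Editorial note: "ret" doubles as a retried-already flag. On the first
+ * pass ret == 0, so an EADDRNOTAVAIL from wg_send() (a stale e_local
+ * source address) clears e_local and jumps back to retry; on the second
+ * pass ret != 0, so the else branch sends once more without retrying,
+ * bounding the loop at two attempts.
+ */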
+retry:
+ m = m_gethdr(M_WAITOK, MT_DATA);
+ m->m_len = 0;
+ m_copyback(m, 0, len, buf);
+
+ if (ret == 0) {
+ ret = wg_send(sc, e, m);
+ /* Retry if we couldn't bind to e->e_local */
+ if (ret == EADDRNOTAVAIL) {
+ bzero(&e->e_local, sizeof(e->e_local));
+ goto retry;
+ }
+ } else {
+ ret = wg_send(sc, e, m);
+ }
+ if (ret)
+ DPRINTF(sc, "Unable to send packet: %d\n", ret);
+}
+
+/* TODO Tag */
+static struct wg_tag *
+wg_tag_get(struct mbuf *m)
+{
+ struct m_tag *tag;
+
+ tag = m_tag_find(m, MTAG_WIREGUARD, NULL);
+ if (tag == NULL) {
+ tag = m_tag_get(MTAG_WIREGUARD, sizeof(struct wg_tag), M_NOWAIT|M_ZERO);
+ /* m_tag_get can fail under M_NOWAIT; don't prepend a NULL tag. */
+ if (tag == NULL)
+ return (NULL);
+ m_tag_prepend(m, tag);
+ MPASS(!SLIST_EMPTY(&m->m_pkthdr.tags));
+ MPASS(m_tag_locate(m, MTAG_ABI_COMPAT, MTAG_WIREGUARD, NULL) == tag);
+ }
+ return (struct wg_tag *)tag;
+}
+
+static struct wg_endpoint *
+wg_mbuf_endpoint_get(struct mbuf *m)
+{
+ struct wg_tag *hdr;
+
+ if ((hdr = wg_tag_get(m)) == NULL)
+ return (NULL);
+
+ return (&hdr->t_endpoint);
+}
+
+/* Timers */
+static void
+wg_timers_init(struct wg_timers *t)
+{
+ bzero(t, sizeof(*t));
+
+ t->t_disabled = 1;
+ rw_init(&t->t_lock, "wg peer timers");
+ callout_init(&t->t_retry_handshake, true);
+ callout_init(&t->t_send_keepalive, true);
+ callout_init(&t->t_new_handshake, true);
+ callout_init(&t->t_zero_key_material, true);
+ callout_init(&t->t_persistent_keepalive, true);
+}
+
+static void
+wg_timers_enable(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_disabled = 0;
+ rw_wunlock(&t->t_lock);
+ wg_timers_run_persistent_keepalive(t);
+}
+
+static void
+wg_timers_disable(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_disabled = 1;
+ t->t_need_another_keepalive = 0;
+ rw_wunlock(&t->t_lock);
+
+ callout_stop(&t->t_retry_handshake);
+ callout_stop(&t->t_send_keepalive);
+ callout_stop(&t->t_new_handshake);
+ callout_stop(&t->t_zero_key_material);
+ callout_stop(&t->t_persistent_keepalive);
+}
+
+static void
+wg_timers_set_persistent_keepalive(struct wg_timers *t, uint16_t interval)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ t->t_persistent_keepalive_interval = interval;
+ wg_timers_run_persistent_keepalive(t);
+ }
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time)
+{
+ rw_rlock(&t->t_lock);
+ time->tv_sec = t->t_handshake_complete.tv_sec;
+ time->tv_nsec = t->t_handshake_complete.tv_nsec;
+ rw_runlock(&t->t_lock);
+}
+
+static int
+wg_timers_expired_handshake_last_sent(struct wg_timers *t)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 };
+
+ getnanouptime(&uptime);
+ timespecadd(&t->t_handshake_last_sent, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+static int
+wg_timers_check_handshake_last_sent(struct wg_timers *t)
+{
+ int ret;
+
+ rw_wlock(&t->t_lock);
+ if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT)
+ getnanouptime(&t->t_handshake_last_sent);
+ rw_wunlock(&t->t_lock);
+ return (ret);
+}
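+
+/*
+ * Editorial note: the check-and-set above is what rate limits handshake
+ * initiations. wg_send_initiation() proceeds only when at least
+ * REKEY_TIMEOUT seconds have passed since t_handshake_last_sent, and
+ * wg_timers_event_reset_handshake_last_sent() forces the next check to
+ * pass by backdating that timestamp by REKEY_TIMEOUT + 1 seconds.
+ */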
+
+/* Should be called after an authenticated data packet is sent. */
+static void
+wg_timers_event_data_sent(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled && !callout_pending(&t->t_new_handshake))
+ callout_reset(&t->t_new_handshake, MSEC_2_TICKS(
+ NEW_HANDSHAKE_TIMEOUT * 1000 +
+ arc4random_uniform(REKEY_TIMEOUT_JITTER)),
+ (timeout_t *)wg_timers_run_new_handshake, t);
+ rw_runlock(&t->t_lock);
+}
+
+/* Should be called after an authenticated data packet is received. */
+static void
+wg_timers_event_data_received(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ if (!callout_pending(&t->t_send_keepalive)) {
+ callout_reset(&t->t_send_keepalive,
+ MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
+ (timeout_t *)wg_timers_run_send_keepalive, t);
+ } else {
+ t->t_need_another_keepalive = 1;
+ }
+ }
+ rw_runlock(&t->t_lock);
+}
+
+/*
+ * Should be called after any type of authenticated packet is sent, whether
+ * keepalive, data, or handshake.
+ */
+static void
+wg_timers_event_any_authenticated_packet_sent(struct wg_timers *t)
+{
+ callout_stop(&t->t_send_keepalive);
+}
+
+/*
+ * Should be called after any type of authenticated packet is received, whether
+ * keepalive, data, or handshake.
+ */
+static void
+wg_timers_event_any_authenticated_packet_received(struct wg_timers *t)
+{
+ callout_stop(&t->t_new_handshake);
+}
+
+/*
+ * Should be called before any authenticated packet (keepalive, data, or
+ * handshake) is sent, or after one is received.
+ */
+static void
+wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled && t->t_persistent_keepalive_interval > 0)
+ callout_reset(&t->t_persistent_keepalive,
+ MSEC_2_TICKS(t->t_persistent_keepalive_interval * 1000),
+ (timeout_t *)wg_timers_run_persistent_keepalive, t);
+ rw_runlock(&t->t_lock);
+}
+
+/* Should be called after a handshake initiation message is sent. */
+static void
+wg_timers_event_handshake_initiated(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled)
+ callout_reset(&t->t_retry_handshake, MSEC_2_TICKS(
+ REKEY_TIMEOUT * 1000 +
+ arc4random_uniform(REKEY_TIMEOUT_JITTER)),
+ (timeout_t *)wg_timers_run_retry_handshake, t);
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_handshake_responded(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ getnanouptime(&t->t_handshake_last_sent);
+ rw_wunlock(&t->t_lock);
+}
+
+/*
+ * Should be called after a handshake response message is received and
+ * processed, or when key confirmation arrives via the first data message.
+ */
+static void
+wg_timers_event_handshake_complete(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ if (!t->t_disabled) {
+ callout_stop(&t->t_retry_handshake);
+ t->t_handshake_retries = 0;
+ getnanotime(&t->t_handshake_complete);
+ wg_timers_run_send_keepalive(t);
+ }
+ rw_wunlock(&t->t_lock);
+}
+
+/*
+ * Should be called after an ephemeral key is created, which is before sending a
+ * handshake response or after receiving a handshake response.
+ */
+static void
+wg_timers_event_session_derived(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled) {
+ callout_reset(&t->t_zero_key_material,
+ MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
+ (timeout_t *)wg_timers_run_zero_key_material, t);
+ }
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_want_initiation(struct wg_timers *t)
+{
+ rw_rlock(&t->t_lock);
+ if (!t->t_disabled)
+ wg_timers_run_send_initiation(t, 0);
+ rw_runlock(&t->t_lock);
+}
+
+static void
+wg_timers_event_reset_handshake_last_sent(struct wg_timers *t)
+{
+ rw_wlock(&t->t_lock);
+ t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1);
+ rw_wunlock(&t->t_lock);
+}
+
+static void
+wg_timers_run_send_initiation(struct wg_timers *t, int is_retry)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+ if (!is_retry)
+ t->t_handshake_retries = 0;
+ if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT)
+ GROUPTASK_ENQUEUE(&peer->p_send_initiation);
+}
+
+static void
+wg_timers_run_retry_handshake(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ rw_wlock(&t->t_lock);
+ if (t->t_handshake_retries <= MAX_TIMER_HANDSHAKES) {
+ int retries = ++t->t_handshake_retries;
+ rw_wunlock(&t->t_lock);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d seconds, retrying (try %d)\n",
+ (unsigned long long)peer->p_id,
+ REKEY_TIMEOUT, retries + 1);
+ wg_peer_clear_src(peer);
+ wg_timers_run_send_initiation(t, 1);
+ } else {
+ rw_wunlock(&t->t_lock);
+
+ DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete "
+ "after %d retries, giving up\n",
+ (unsigned long long) peer->p_id, MAX_TIMER_HANDSHAKES + 2);
+
+ callout_stop(&t->t_send_keepalive);
+ wg_queue_purge(&peer->p_stage_queue);
+ if (!callout_pending(&t->t_zero_key_material))
+ callout_reset(&t->t_zero_key_material,
+ MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
+ (timeout_t *)wg_timers_run_zero_key_material, t);
+ }
+}
+
+static void
+wg_timers_run_send_keepalive(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ GROUPTASK_ENQUEUE(&peer->p_send_keepalive);
+ if (t->t_need_another_keepalive) {
+ t->t_need_another_keepalive = 0;
+ callout_reset(&t->t_send_keepalive,
+ MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
+ (timeout_t *)wg_timers_run_send_keepalive, t);
+ }
+}
+
+static void
+wg_timers_run_new_handshake(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Retrying handshake with peer %llu because we "
+ "stopped hearing back after %d seconds\n",
+ (unsigned long long)peer->p_id, NEW_HANDSHAKE_TIMEOUT);
+ wg_peer_clear_src(peer);
+
+ wg_timers_run_send_initiation(t, 0);
+}
+
+static void
+wg_timers_run_zero_key_material(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ DPRINTF(peer->p_sc, "Zeroing out all keys for peer %llu, since we "
+ "haven't received a new one in %d seconds\n",
+ (unsigned long long)peer->p_id, REJECT_AFTER_TIME * 3);
+ GROUPTASK_ENQUEUE(&peer->p_clear_secrets);
+}
+
+static void
+wg_timers_run_persistent_keepalive(struct wg_timers *t)
+{
+ struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers);
+
+ if (t->t_persistent_keepalive_interval != 0)
+ GROUPTASK_ENQUEUE(&peer->p_send_keepalive);
+}
+
+/* TODO Handshake */
+static void
+wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
+{
+ struct wg_endpoint endpoint;
+
+ counter_u64_add(peer->p_tx_bytes, len);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(&peer->p_timers);
+ wg_peer_get_endpoint(peer, &endpoint);
+ wg_send_buf(peer->p_sc, &endpoint, buf, len);
+}
+
+static void
+wg_send_initiation(struct wg_peer *peer)
+{
+ struct wg_pkt_initiation pkt;
+ struct epoch_tracker et;
+
+ if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT)
+ return;
+ DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n",
+ (unsigned long long)peer->p_id);
+
+ NET_EPOCH_ENTER(et);
+ if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue,
+ pkt.es, pkt.ets) != 0)
+ goto out;
+ pkt.t = WG_PKT_INITIATION;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
+ wg_timers_event_handshake_initiated(&peer->p_timers);
+out:
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_send_response(struct wg_peer *peer)
+{
+ struct wg_pkt_response pkt;
+ struct epoch_tracker et;
+
+ NET_EPOCH_ENTER(et);
+
+ DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n",
+ (unsigned long long)peer->p_id);
+
+ if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx,
+ pkt.ue, pkt.en) != 0)
+ goto out;
+ if (noise_remote_begin_session(&peer->p_remote) != 0)
+ goto out;
+
+ wg_timers_event_session_derived(&peer->p_timers);
+ pkt.t = WG_PKT_RESPONSE;
+ cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
+ sizeof(pkt)-sizeof(pkt.m));
+ wg_timers_event_handshake_responded(&peer->p_timers);
+ wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
+out:
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
+ struct mbuf *m)
+{
+ struct wg_pkt_cookie pkt;
+ struct wg_endpoint *e;
+
+ DPRINTF(sc, "Sending cookie response for denied handshake message\n");
+
+ pkt.t = WG_PKT_COOKIE;
+ pkt.r_idx = idx;
+
+ e = wg_mbuf_endpoint_get(m);
+ cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
+ pkt.ec, &e->e_remote.r_sa);
+ wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
+}
+
+static void
+wg_send_keepalive(struct wg_peer *peer)
+{
+ struct mbuf *m = NULL;
+ struct wg_tag *t;
+ struct epoch_tracker et;
+
+ if (wg_queue_len(&peer->p_stage_queue) != 0) {
+ NET_EPOCH_ENTER(et);
+ goto send;
+ }
+ if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
+ return;
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ return;
+ }
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = 0; /* MTU == 0 OK for keepalive */
+
+ NET_EPOCH_ENTER(et);
+ wg_queue_stage(peer, m);
+send:
+ wg_queue_out(peer);
+ NET_EPOCH_EXIT(et);
+}
+
+static int
+wg_cookie_validate_packet(struct cookie_checker *checker, struct mbuf *m,
+ int under_load)
+{
+ struct wg_pkt_initiation *init;
+ struct wg_pkt_response *resp;
+ struct cookie_macs *macs;
+ struct wg_endpoint *e;
+ int type, size;
+ void *data;
+
+ type = *mtod(m, uint32_t *);
+ data = m->m_data;
+ e = wg_mbuf_endpoint_get(m);
+ if (type == WG_PKT_INITIATION) {
+ init = mtod(m, struct wg_pkt_initiation *);
+ macs = &init->m;
+ size = sizeof(*init) - sizeof(*macs);
+ } else if (type == WG_PKT_RESPONSE) {
+ resp = mtod(m, struct wg_pkt_response *);
+ macs = &resp->m;
+ size = sizeof(*resp) - sizeof(*macs);
+ } else
+ return 0;
+
+ return (cookie_checker_validate_macs(checker, macs, data, size,
+ under_load, &e->e_remote.r_sa));
+}
+
+static void
+wg_handshake(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_initiation *init;
+ struct wg_pkt_response *resp;
+ struct noise_remote *remote;
+ struct wg_pkt_cookie *cook;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+
+ /* This is global, so that our load calculation applies to the whole
+ * system. We don't care about races with it at all.
+ */
+ static struct timeval wg_last_underload;
+ static const struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 };
+ bool packet_needs_cookie = false;
+ int underload, res;
+
+ underload = mbufq_len(&sc->sc_handshake_queue) >=
+ MAX_QUEUED_HANDSHAKES / 8;
+ if (underload)
+ getmicrouptime(&wg_last_underload);
+ else if (wg_last_underload.tv_sec != 0) {
+ if (!ratecheck(&wg_last_underload, &underload_interval))
+ underload = 1;
+ else
+ bzero(&wg_last_underload, sizeof(wg_last_underload));
+ }
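+ /*
+ * Editorial note on the hysteresis above: once the handshake queue
+ * reaches MAX_QUEUED_HANDSHAKES / 8 (512) entries the system is
+ * considered under load and the time is recorded; it then stays under
+ * load until UNDERLOAD_TIMEOUT (1 second) passes without the queue
+ * refilling, so cookie enforcement does not flap on bursty traffic.
+ */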
+
+ res = wg_cookie_validate_packet(&sc->sc_cookie, m, underload);
+
+ if (res == EINVAL) {
+ DPRINTF(sc, "Invalid initiation MAC\n");
+ goto free;
+ } else if (res == ECONNREFUSED) {
+ DPRINTF(sc, "Handshake ratelimited\n");
+ goto free;
+ } else if (res == EAGAIN) {
+ packet_needs_cookie = true;
+ } else if (res != 0) {
+ DPRINTF(sc, "Unexpected handshake ratelimit response: %d\n", res);
+ goto free;
+ }
+
+ t = wg_tag_get(m);
+ switch (*mtod(m, uint32_t *)) {
+ case WG_PKT_INITIATION:
+ init = mtod(m, struct wg_pkt_initiation *);
+
+ if (packet_needs_cookie) {
+ wg_send_cookie(sc, &init->m, init->s_idx, m);
+ goto free;
+ }
+ if (noise_consume_initiation(&sc->sc_local, &remote,
+ init->s_idx, init->ue, init->es, init->ets) != 0) {
+ DPRINTF(sc, "Invalid handshake initiation");
+ goto free;
+ }
+
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ DPRINTF(sc, "Receiving handshake initiation from peer %llu\n",
+ (unsigned long long)peer->p_id);
+ counter_u64_add(peer->p_rx_bytes, sizeof(*init));
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*init));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ wg_send_response(peer);
+ break;
+ case WG_PKT_RESPONSE:
+ resp = mtod(m, struct wg_pkt_response *);
+
+ if (packet_needs_cookie) {
+ wg_send_cookie(sc, &resp->m, resp->s_idx, m);
+ goto free;
+ }
+
+ if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown handshake response\n");
+ goto free;
+ }
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ if (noise_consume_response(remote, resp->s_idx, resp->r_idx,
+ resp->ue, resp->en) != 0) {
+ DPRINTF(sc, "Invalid handshake response\n");
+ goto free;
+ }
+
+ DPRINTF(sc, "Receiving handshake response from peer %llu\n",
+ (unsigned long long)peer->p_id);
+ counter_u64_add(peer->p_rx_bytes, sizeof(*resp));
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*resp));
+ wg_peer_set_endpoint_from_tag(peer, t);
+ if (noise_remote_begin_session(&peer->p_remote) == 0) {
+ wg_timers_event_session_derived(&peer->p_timers);
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ }
+ break;
+ case WG_PKT_COOKIE:
+ cook = mtod(m, struct wg_pkt_cookie *);
+
+ if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) {
+ DPRINTF(sc, "Unknown cookie index\n");
+ goto free;
+ }
+
+ peer = __containerof(remote, struct wg_peer, p_remote);
+
+ if (cookie_maker_consume_payload(&peer->p_cookie,
+ cook->nonce, cook->ec) != 0) {
+ DPRINTF(sc, "Could not decrypt cookie response\n");
+ goto free;
+ }
+
+ DPRINTF(sc, "Receiving cookie response\n");
+ goto free;
+ default:
+ goto free;
+ }
+ MPASS(peer != NULL);
+ wg_timers_event_any_authenticated_packet_received(&peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers);
+
+free:
+ m_freem(m);
+}
+
+static void
+wg_softc_handshake_receive(struct wg_softc *sc)
+{
+ struct mbuf *m;
+
+ while ((m = mbufq_dequeue(&sc->sc_handshake_queue)) != NULL)
+ wg_handshake(sc, m);
+}
+
+/* TODO Encrypt */
+static void
+wg_encap(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_data *data;
+ size_t padding_len, plaintext_len, out_len;
+ struct mbuf *mc;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ uint64_t nonce;
+ int res, allocation_order;
+
+ NET_EPOCH_ASSERT();
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ plaintext_len = MIN(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu);
+ padding_len = plaintext_len - m->m_pkthdr.len;
+ out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN;
+
+ if (out_len <= MCLBYTES)
+ allocation_order = MCLBYTES;
+ else if (out_len <= MJUMPAGESIZE)
+ allocation_order = MJUMPAGESIZE;
+ else if (out_len <= MJUM9BYTES)
+ allocation_order = MJUM9BYTES;
+ else if (out_len <= MJUM16BYTES)
+ allocation_order = MJUM16BYTES;
+ else
+ goto error;
+
+ if ((mc = m_getjcl(M_NOWAIT, MT_DATA, M_PKTHDR, allocation_order)) == NULL)
+ goto error;
+
+ data = mtod(mc, struct wg_pkt_data *);
+ m_copydata(m, 0, m->m_pkthdr.len, data->buf);
+ bzero(data->buf + m->m_pkthdr.len, padding_len);
+
+ data->t = WG_PKT_DATA;
+
+ res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce,
+ data->buf, plaintext_len);
+ nonce = htole64(nonce); /* Wire format is little endian. */
+ memcpy(data->nonce, &nonce, sizeof(data->nonce));
+
+ if (__predict_false(res)) {
+ if (res == EINVAL) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ m_freem(mc);
+ goto error;
+ } else if (res == ESTALE) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else {
+ m_freem(mc);
+ goto error;
+ }
+ }
+
+ /* A packet with length 0 is a keepalive packet */
+ if (m->m_pkthdr.len == 0)
+ DPRINTF(sc, "Sending keepalive packet to peer %llu\n",
+ (unsigned long long)peer->p_id);
+ /*
+ * Set the correct output value here since it will be copied
+ * when we move the pkthdr in send.
+ */
+ mc->m_len = mc->m_pkthdr.len = out_len;
+ mc->m_flags &= ~(M_MCAST | M_BCAST);
+
+ t->t_mbuf = mc;
+error:
+ /* XXX membar ? */
+ t->t_done = 1;
+ GROUPTASK_ENQUEUE(&peer->p_send);
+}
+
+static void
+wg_decap(struct wg_softc *sc, struct mbuf *m)
+{
+ struct wg_pkt_data *data;
+ struct wg_peer *peer, *routed_peer;
+ struct wg_tag *t;
+ size_t plaintext_len;
+ uint8_t version;
+ uint64_t nonce;
+ int res;
+
+ NET_EPOCH_ASSERT();
+ data = mtod(m, struct wg_pkt_data *);
+ plaintext_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data);
+
+ t = wg_tag_get(m);
+ peer = t->t_peer;
+
+ memcpy(&nonce, data->nonce, sizeof(nonce));
+ nonce = le64toh(nonce); /* Wire format is little endian. */
+
+ res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce,
+ data->buf, plaintext_len);
+
+ if (__predict_false(res)) {
+ if (res == EINVAL) {
+ goto error;
+ } else if (res == ECONNRESET) {
+ wg_timers_event_handshake_complete(&peer->p_timers);
+ } else if (res == ESTALE) {
+ wg_timers_event_want_initiation(&peer->p_timers);
+ } else {
+ panic("unexpected response: %d\n", res);
+ }
+ }
+ wg_peer_set_endpoint_from_tag(peer, t);
+
+ /* Remove the data header and the crypto MAC tail from the packet */
+ m_adj(m, sizeof(struct wg_pkt_data));
+ m_adj(m, -NOISE_AUTHTAG_LEN);
+
+ /* A packet with length 0 is a keepalive packet */
+ if (m->m_pkthdr.len == 0) {
+ DPRINTF(peer->p_sc, "Receiving keepalive packet from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto done;
+ }
+
+ version = mtod(m, struct ip *)->ip_v;
+ if (!((version == 4 && m->m_pkthdr.len >= sizeof(struct ip)) ||
+ (version == 6 && m->m_pkthdr.len >= sizeof(struct ip6_hdr)))) {
+ DPRINTF(peer->p_sc, "Packet is neither ipv4 nor ipv6 from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto error;
+ }
+
+ routed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN);
+ if (routed_peer != peer) {
+ DPRINTF(peer->p_sc, "Packet has unallowed src IP from peer "
+ "%llu\n", (unsigned long long)peer->p_id);
+ goto error;
+ }
+
+done:
+ t->t_mbuf = m;
+error:
+ t->t_done = 1;
+ GROUPTASK_ENQUEUE(&peer->p_recv);
+}
+
+static void
+wg_softc_decrypt(struct wg_softc *sc)
+{
+ struct epoch_tracker et;
+ struct mbuf *m;
+
+ NET_EPOCH_ENTER(et);
+ while ((m = buf_ring_dequeue_mc(sc->sc_decap_ring)) != NULL)
+ wg_decap(sc, m);
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_softc_encrypt(struct wg_softc *sc)
+{
+ struct mbuf *m;
+ struct epoch_tracker et;
+
+ NET_EPOCH_ENTER(et);
+ while ((m = buf_ring_dequeue_mc(sc->sc_encap_ring)) != NULL)
+ wg_encap(sc, m);
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_encrypt_dispatch(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
+ continue;
+ GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]);
+ }
+}
+
+static void
+wg_decrypt_dispatch(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED)
+ continue;
+ GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]);
+ }
+}
+
+static void
+wg_deliver_out(struct wg_peer *peer)
+{
+ struct epoch_tracker et;
+ struct wg_tag *t;
+ struct mbuf *m;
+ struct wg_endpoint endpoint;
+ size_t len;
+ int ret;
+
+ NET_EPOCH_ENTER(et);
+ if (peer->p_sc->sc_ifp->if_link_state == LINK_STATE_DOWN)
+ goto done;
+
+ wg_peer_get_endpoint(peer, &endpoint);
+
+ while ((m = wg_queue_dequeue(&peer->p_encap_queue, &t)) != NULL) {
+ /* t_mbuf will contain the encrypted packet */
+ if (t->t_mbuf == NULL) {
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OERRORS, 1);
+ m_freem(m);
+ continue;
+ }
+ len = t->t_mbuf->m_pkthdr.len;
+ ret = wg_send(peer->p_sc, &endpoint, t->t_mbuf);
+
+ if (ret == 0) {
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_sent(
+ &peer->p_timers);
+
+ if (m->m_pkthdr.len != 0)
+ wg_timers_event_data_sent(&peer->p_timers);
+ counter_u64_add(peer->p_tx_bytes, len);
+ } else if (ret == EADDRNOTAVAIL) {
+ wg_peer_clear_src(peer);
+ wg_peer_get_endpoint(peer, &endpoint);
+ }
+ m_freem(m);
+ }
+done:
+ NET_EPOCH_EXIT(et);
+}
+
+static void
+wg_deliver_in(struct wg_peer *peer)
+{
+ struct mbuf *m;
+ struct ifnet *ifp;
+ struct wg_softc *sc;
+ struct epoch_tracker et;
+ struct wg_tag *t;
+ uint32_t af;
+ int version;
+
+ NET_EPOCH_ENTER(et);
+ sc = peer->p_sc;
+ ifp = sc->sc_ifp;
+
+ while ((m = wg_queue_dequeue(&peer->p_decap_queue, &t)) != NULL) {
+		/* t_mbuf will contain the decrypted packet */
+ if (t->t_mbuf == NULL) {
+ if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
+ m_freem(m);
+ continue;
+ }
+ MPASS(m == t->t_mbuf);
+
+ wg_timers_event_any_authenticated_packet_received(
+ &peer->p_timers);
+ wg_timers_event_any_authenticated_packet_traversal(
+ &peer->p_timers);
+
+		counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
+		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+		if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+		if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
+		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+
+ if (m->m_pkthdr.len == 0) {
+ m_freem(m);
+ continue;
+ }
+
+ m->m_flags &= ~(M_MCAST | M_BCAST);
+ m->m_pkthdr.rcvif = ifp;
+ version = mtod(m, struct ip *)->ip_v;
+ if (version == IPVERSION) {
+ af = AF_INET;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+ CURVNET_SET(ifp->if_vnet);
+ ip_input(m);
+ CURVNET_RESTORE();
+ } else if (version == 6) {
+ af = AF_INET6;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+ CURVNET_SET(ifp->if_vnet);
+ ip6_input(m);
+ CURVNET_RESTORE();
+ } else
+ m_freem(m);
+
+ wg_timers_event_data_received(&peer->p_timers);
+ }
+ NET_EPOCH_EXIT(et);
+}
+
+static int
+wg_queue_in(struct wg_peer *peer, struct mbuf *m)
+{
+ struct buf_ring *parallel = peer->p_sc->sc_decap_ring;
+ struct wg_queue *serial = &peer->p_decap_queue;
+ struct wg_tag *t;
+ int rc;
+
+ MPASS(wg_tag_get(m) != NULL);
+
+ mtx_lock(&serial->q_mtx);
+ if ((rc = mbufq_enqueue(&serial->q, m)) == ENOBUFS) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ } else {
+ m->m_flags |= M_ENQUEUED;
+ rc = buf_ring_enqueue(parallel, m);
+ if (rc == ENOBUFS) {
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ }
+ }
+ mtx_unlock(&serial->q_mtx);
+ return (rc);
+}
+
+static void
+wg_queue_stage(struct wg_peer *peer, struct mbuf *m)
+{
+ struct wg_queue *q = &peer->p_stage_queue;
+ mtx_lock(&q->q_mtx);
+ STAILQ_INSERT_TAIL(&q->q.mq_head, m, m_stailqpkt);
+ q->q.mq_len++;
+ while (mbufq_full(&q->q)) {
+ m = mbufq_dequeue(&q->q);
+ if (m) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ }
+ }
+ mtx_unlock(&q->q_mtx);
+}
+
+static void
+wg_queue_out(struct wg_peer *peer)
+{
+ struct buf_ring *parallel = peer->p_sc->sc_encap_ring;
+ struct wg_queue *serial = &peer->p_encap_queue;
+ struct wg_tag *t;
+ struct mbufq staged;
+ struct mbuf *m;
+
+ if (noise_remote_ready(&peer->p_remote) != 0) {
+ if (wg_queue_len(&peer->p_stage_queue))
+ wg_timers_event_want_initiation(&peer->p_timers);
+ return;
+ }
+
+	/* We first "steal" the staged queue to a local queue, so that we can
+	 * do the remaining operations without having to hold the staged
+	 * queue mutex. */
+ STAILQ_INIT(&staged.mq_head);
+ mtx_lock(&peer->p_stage_queue.q_mtx);
+ STAILQ_SWAP(&staged.mq_head, &peer->p_stage_queue.q.mq_head, mbuf);
+ staged.mq_len = peer->p_stage_queue.q.mq_len;
+ peer->p_stage_queue.q.mq_len = 0;
+ staged.mq_maxlen = peer->p_stage_queue.q.mq_maxlen;
+ mtx_unlock(&peer->p_stage_queue.q_mtx);
+
+ while ((m = mbufq_dequeue(&staged)) != NULL) {
+ if ((t = wg_tag_get(m)) == NULL) {
+ m_freem(m);
+ continue;
+ }
+ t->t_peer = peer;
+ mtx_lock(&serial->q_mtx);
+ if (mbufq_enqueue(&serial->q, m) != 0) {
+ m_freem(m);
+ if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+ } else {
+ m->m_flags |= M_ENQUEUED;
+ if (buf_ring_enqueue(parallel, m)) {
+ t = wg_tag_get(m);
+ t->t_done = 1;
+ }
+ }
+ mtx_unlock(&serial->q_mtx);
+ }
+ wg_encrypt_dispatch(peer->p_sc);
+}
+
+static struct mbuf *
+wg_queue_dequeue(struct wg_queue *q, struct wg_tag **t)
+{
+ struct mbuf *m_, *m;
+
+ m = NULL;
+ mtx_lock(&q->q_mtx);
+ m_ = mbufq_first(&q->q);
+ if (m_ != NULL && (*t = wg_tag_get(m_))->t_done) {
+ m = mbufq_dequeue(&q->q);
+ m->m_flags &= ~M_ENQUEUED;
+ }
+ mtx_unlock(&q->q_mtx);
+ return (m);
+}
+
+static int
+wg_queue_len(struct wg_queue *q)
+{
+ /* This access races. We might consider adding locking here. */
+ return (mbufq_len(&q->q));
+}
+
+static void
+wg_queue_init(struct wg_queue *q, const char *name)
+{
+ mtx_init(&q->q_mtx, name, NULL, MTX_DEF);
+ mbufq_init(&q->q, MAX_QUEUED_PKT);
+}
+
+static void
+wg_queue_deinit(struct wg_queue *q)
+{
+ wg_queue_purge(q);
+ mtx_destroy(&q->q_mtx);
+}
+
+static void
+wg_queue_purge(struct wg_queue *q)
+{
+ mtx_lock(&q->q_mtx);
+ mbufq_drain(&q->q);
+ mtx_unlock(&q->q_mtx);
+}
+
+/* TODO Indexes */
+static struct noise_remote *
+wg_remote_get(struct wg_softc *sc, uint8_t public[NOISE_PUBLIC_KEY_LEN])
+{
+ struct wg_peer *peer;
+
+ if ((peer = wg_peer_lookup(sc, public)) == NULL)
+ return (NULL);
+ return (&peer->p_remote);
+}
+
+static uint32_t
+wg_index_set(struct wg_softc *sc, struct noise_remote *remote)
+{
+ struct wg_index *index, *iter;
+ struct wg_peer *peer;
+ uint32_t key;
+
+ /* We can modify this without a lock as wg_index_set, wg_index_drop are
+ * guaranteed to be serialised (per remote). */
+ peer = __containerof(remote, struct wg_peer, p_remote);
+ index = SLIST_FIRST(&peer->p_unused_index);
+ MPASS(index != NULL);
+ SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry);
+
+ index->i_value = remote;
+
+ rw_wlock(&sc->sc_index_lock);
+assign_id:
+ key = index->i_key = arc4random();
+ key &= sc->sc_index_mask;
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == index->i_key)
+ goto assign_id;
+
+ LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry);
+
+ rw_wunlock(&sc->sc_index_lock);
+
+ /* Likewise, no need to lock for index here. */
+ return index->i_key;
+}
+
+static struct noise_remote *
+wg_index_get(struct wg_softc *sc, uint32_t key0)
+{
+ struct wg_index *iter;
+ struct noise_remote *remote = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ rw_enter_read(&sc->sc_index_lock);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ remote = iter->i_value;
+ break;
+ }
+ rw_exit_read(&sc->sc_index_lock);
+ return remote;
+}
+
+static void
+wg_index_drop(struct wg_softc *sc, uint32_t key0)
+{
+ struct wg_index *iter;
+ struct wg_peer *peer = NULL;
+ uint32_t key = key0 & sc->sc_index_mask;
+
+ rw_enter_write(&sc->sc_index_lock);
+ LIST_FOREACH(iter, &sc->sc_index[key], i_entry)
+ if (iter->i_key == key0) {
+ LIST_REMOVE(iter, i_entry);
+ break;
+ }
+ rw_exit_write(&sc->sc_index_lock);
+
+ if (iter == NULL)
+ return;
+
+ /* We expect a peer */
+ peer = __containerof(iter->i_value, struct wg_peer, p_remote);
+ MPASS(peer != NULL);
+ SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry);
+}
+
+static int
+wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa,
+ struct ifnet *rcvif)
+{
+ const struct sockaddr_in *sa4;
+#ifdef INET6
+ const struct sockaddr_in6 *sa6;
+#endif
+ int ret = 0;
+
+ /*
+ * UDP passes a 2-element sockaddr array: first element is the
+ * source addr/port, second the destination addr/port.
+ */
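+	/*
+	 * Example (hypothetical addresses): for a datagram sent from
+	 * 203.0.113.5:51820 and received on local address 192.0.2.1,
+	 * sa4[0] holds 203.0.113.5:51820 and sa4[1].sin_addr holds
+	 * 192.0.2.1.
+	 */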
+ if (srcsa->sa_family == AF_INET) {
+ sa4 = (const struct sockaddr_in *)srcsa;
+ e->e_remote.r_sin = sa4[0];
+ e->e_local.l_in = sa4[1].sin_addr;
+#ifdef INET6
+ } else if (srcsa->sa_family == AF_INET6) {
+ sa6 = (const struct sockaddr_in6 *)srcsa;
+ e->e_remote.r_sin6 = sa6[0];
+ e->e_local.l_in6 = sa6[1].sin6_addr;
+#endif
+ } else {
+ ret = EAFNOSUPPORT;
+ }
+
+ return (ret);
+}
+
+static void
+wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb,
+ const struct sockaddr *srcsa, void *_sc)
+{
+ struct wg_pkt_data *pkt_data;
+ struct wg_endpoint *e;
+ struct wg_softc *sc = _sc;
+ struct mbuf *m;
+ int pktlen, pkttype;
+ struct noise_remote *remote;
+ struct wg_tag *t;
+ void *data;
+
+ /* Caller provided us with srcsa, no need for this header. */
+ m_adj(m0, offset + sizeof(struct udphdr));
+
+ /*
+ * Ensure mbuf has at least enough contiguous data to peel off our
+ * headers at the beginning.
+ */
+ if ((m = m_defrag(m0, M_NOWAIT)) == NULL) {
+ m_freem(m0);
+ return;
+ }
+ data = mtod(m, void *);
+ pkttype = *(uint32_t*)data;
+ t = wg_tag_get(m);
+ if (t == NULL) {
+ goto free;
+ }
+ e = wg_mbuf_endpoint_get(m);
+
+ if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) {
+ goto free;
+ }
+
+ pktlen = m->m_pkthdr.len;
+
+ if ((pktlen == sizeof(struct wg_pkt_initiation) &&
+ pkttype == WG_PKT_INITIATION) ||
+ (pktlen == sizeof(struct wg_pkt_response) &&
+ pkttype == WG_PKT_RESPONSE) ||
+ (pktlen == sizeof(struct wg_pkt_cookie) &&
+ pkttype == WG_PKT_COOKIE)) {
+ if (mbufq_enqueue(&sc->sc_handshake_queue, m) == 0) {
+ GROUPTASK_ENQUEUE(&sc->sc_handshake);
+ } else {
+ DPRINTF(sc, "Dropping handshake packet\n");
+ m_freem(m);
+ }
+ } else if (pktlen >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN
+ && pkttype == WG_PKT_DATA) {
+
+ pkt_data = data;
+ remote = wg_index_get(sc, pkt_data->r_idx);
+ if (remote == NULL) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
+ m_freem(m);
+ } else if (buf_ring_count(sc->sc_decap_ring) > MAX_QUEUED_PKT) {
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
+ m_freem(m);
+ } else {
+ t->t_peer = __containerof(remote, struct wg_peer,
+ p_remote);
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+
+ wg_queue_in(t->t_peer, m);
+ wg_decrypt_dispatch(sc);
+ }
+ } else {
+free:
+ m_freem(m);
+ }
+}
+
+static int
+wg_transmit(struct ifnet *ifp, struct mbuf *m)
+{
+ struct wg_softc *sc;
+ sa_family_t family;
+ struct epoch_tracker et;
+ struct wg_peer *peer;
+ struct wg_tag *t;
+ uint32_t af;
+ int rc;
+
+	/*
+	 * Work around a lifetime issue in the IPv6 MLD code.
+	 */
+ if (__predict_false(ifp->if_flags & IFF_DYING))
+ return (ENXIO);
+
+ rc = 0;
+ sc = ifp->if_softc;
+ if ((t = wg_tag_get(m)) == NULL) {
+ rc = ENOBUFS;
+ goto early_out;
+ }
+ af = m->m_pkthdr.ph_family;
+ BPF_MTAP2(ifp, &af, sizeof(af), m);
+
+ NET_EPOCH_ENTER(et);
+ peer = wg_aip_lookup(&sc->sc_aips, m, OUT);
+ if (__predict_false(peer == NULL)) {
+ rc = ENOKEY;
+ goto err;
+ }
+
+ family = peer->p_endpoint.e_remote.r_sa.sa_family;
+ if (__predict_false(family != AF_INET && family != AF_INET6)) {
+ DPRINTF(sc, "No valid endpoint has been configured or "
+ "discovered for peer %llu\n", (unsigned long long)peer->p_id);
+
+ rc = EHOSTUNREACH;
+ goto err;
+ }
+ t->t_peer = peer;
+ t->t_mbuf = NULL;
+ t->t_done = 0;
+ t->t_mtu = ifp->if_mtu;
+
+ wg_queue_stage(peer, m);
+ wg_queue_out(peer);
+ NET_EPOCH_EXIT(et);
+ return (rc);
+err:
+ NET_EPOCH_EXIT(et);
+early_out:
+ if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
+ /* TODO: send ICMP unreachable */
+ m_free(m);
+ return (rc);
+}
+
+static int
+wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt)
+{
+ m->m_pkthdr.ph_family = sa->sa_family;
+ return (wg_transmit(ifp, m));
+}
+
+static int
+wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
+{
+ uint8_t public[WG_KEY_SIZE];
+ const void *pub_key;
+ const struct sockaddr *endpoint;
+ int err;
+ size_t size;
+ struct wg_peer *peer = NULL;
+ bool need_insert = false;
+
+ sx_assert(&sc->sc_lock, SX_XLOCKED);
+
+ if (!nvlist_exists_binary(nvl, "public-key")) {
+ return (EINVAL);
+ }
+ pub_key = nvlist_get_binary(nvl, "public-key", &size);
+ if (size != WG_KEY_SIZE) {
+ return (EINVAL);
+ }
+ if (noise_local_keys(&sc->sc_local, public, NULL) == 0 &&
+ bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
+		return (0); /* Silently ignored; not actually a failure. */
+ }
+ peer = wg_peer_lookup(sc, pub_key);
+ if (nvlist_exists_bool(nvl, "remove") &&
+ nvlist_get_bool(nvl, "remove")) {
+ if (peer != NULL) {
+ wg_hashtable_peer_remove(&sc->sc_hashtable, peer);
+ wg_peer_destroy(peer);
+ }
+ return (0);
+ }
+ if (nvlist_exists_bool(nvl, "replace-allowedips") &&
+ nvlist_get_bool(nvl, "replace-allowedips") &&
+ peer != NULL) {
+
+ wg_aip_delete(&peer->p_sc->sc_aips, peer);
+ }
+ if (peer == NULL) {
+ if (sc->sc_peer_count >= MAX_PEERS_PER_IFACE)
+ return (E2BIG);
+ sc->sc_peer_count++;
+
+ need_insert = true;
+ peer = wg_peer_alloc(sc);
+ MPASS(peer != NULL);
+ noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local);
+ cookie_maker_init(&peer->p_cookie, pub_key);
+ }
+ if (nvlist_exists_binary(nvl, "endpoint")) {
+ endpoint = nvlist_get_binary(nvl, "endpoint", &size);
+ if (size > sizeof(peer->p_endpoint.e_remote)) {
+ err = EINVAL;
+ goto out;
+ }
+ memcpy(&peer->p_endpoint.e_remote, endpoint, size);
+ }
+ if (nvlist_exists_binary(nvl, "preshared-key")) {
+ const void *key;
+
+ key = nvlist_get_binary(nvl, "preshared-key", &size);
+ if (size != WG_KEY_SIZE) {
+ err = EINVAL;
+ goto out;
+ }
+ noise_remote_set_psk(&peer->p_remote, key);
+ }
+ if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
+ uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
+ if (pki > UINT16_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ wg_timers_set_persistent_keepalive(&peer->p_timers, pki);
+ }
+ if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
+ const void *binary;
+ uint64_t cidr;
+ const nvlist_t * const * aipl;
+ struct wg_allowedip aip;
+ size_t allowedip_count;
+
+ aipl = nvlist_get_nvlist_array(nvl, "allowed-ips",
+ &allowedip_count);
+ for (size_t idx = 0; idx < allowedip_count; idx++) {
+ if (!nvlist_exists_number(aipl[idx], "cidr"))
+ continue;
+ cidr = nvlist_get_number(aipl[idx], "cidr");
+ if (nvlist_exists_binary(aipl[idx], "ipv4")) {
+ binary = nvlist_get_binary(aipl[idx], "ipv4", &size);
+ if (binary == NULL || cidr > 32 || size != sizeof(aip.ip4)) {
+ err = EINVAL;
+ goto out;
+ }
+ aip.family = AF_INET;
+ memcpy(&aip.ip4, binary, sizeof(aip.ip4));
+ } else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
+ binary = nvlist_get_binary(aipl[idx], "ipv6", &size);
+ if (binary == NULL || cidr > 128 || size != sizeof(aip.ip6)) {
+ err = EINVAL;
+ goto out;
+ }
+ aip.family = AF_INET6;
+ memcpy(&aip.ip6, binary, sizeof(aip.ip6));
+ } else {
+ continue;
+ }
+ aip.cidr = cidr;
+
+ if ((err = wg_aip_add(&sc->sc_aips, peer, &aip)) != 0) {
+ goto out;
+ }
+ }
+ }
+ if (need_insert) {
+ wg_hashtable_peer_insert(&sc->sc_hashtable, peer);
+ if (sc->sc_ifp->if_link_state == LINK_STATE_UP)
+ wg_timers_enable(&peer->p_timers);
+ }
+ return (0);
+
+out:
+ if (need_insert) /* If we fail, only destroy if it was new. */
+ wg_peer_destroy(peer);
+ return (err);
+}
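+/*
+ * For reference, a sketch of the per-peer nvlist consumed above (field
+ * names as parsed; layout and values illustrative only):
+ *
+ *	public-key			binary, WG_KEY_SIZE bytes (required)
+ *	remove				bool
+ *	replace-allowedips		bool
+ *	endpoint			binary struct sockaddr_in/sockaddr_in6
+ *	preshared-key			binary, WG_KEY_SIZE bytes
+ *	persistent-keepalive-interval	number, 0..UINT16_MAX
+ *	allowed-ips			nvlist array of { "ipv4"/"ipv6": binary,
+ *					    "cidr": number }
+ */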
+
+static int
+wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
+{
+ uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
+ struct ifnet *ifp;
+ void *nvlpacked;
+ nvlist_t *nvl;
+ ssize_t size;
+ int err;
+
+ ifp = sc->sc_ifp;
+ if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
+ return (EFAULT);
+
+ sx_xlock(&sc->sc_lock);
+
+ nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK);
+ err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
+ if (err)
+ goto out;
+ nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
+ if (nvl == NULL) {
+ err = EBADMSG;
+ goto out;
+ }
+ if (nvlist_exists_bool(nvl, "replace-peers") &&
+ nvlist_get_bool(nvl, "replace-peers"))
+ wg_peer_remove_all(sc);
+ if (nvlist_exists_number(nvl, "listen-port")) {
+ uint64_t new_port = nvlist_get_number(nvl, "listen-port");
+ if (new_port > UINT16_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ if (new_port != sc->sc_socket.so_port) {
+ if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) {
+ if ((err = wg_socket_init(sc, new_port)) != 0)
+ goto out;
+ } else
+ sc->sc_socket.so_port = new_port;
+ }
+ }
+ if (nvlist_exists_binary(nvl, "private-key")) {
+ const void *key = nvlist_get_binary(nvl, "private-key", &size);
+ if (size != WG_KEY_SIZE) {
+ err = EINVAL;
+ goto out;
+ }
+
+ if (noise_local_keys(&sc->sc_local, NULL, private) != 0 ||
+ timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
+ struct noise_local *local;
+ struct wg_peer *peer;
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ bool has_identity;
+
+ if (curve25519_generate_public(public, key)) {
+ /* Peer conflict: remove conflicting peer. */
+ if ((peer = wg_peer_lookup(sc, public)) !=
+ NULL) {
+ wg_hashtable_peer_remove(ht, peer);
+ wg_peer_destroy(peer);
+ }
+ }
+
+ /*
+ * Set the private key and invalidate all existing
+ * handshakes.
+ */
+ local = &sc->sc_local;
+ noise_local_lock_identity(local);
+ /* Note: we might be removing the private key. */
+ has_identity = noise_local_set_private(local, key) == 0;
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ noise_remote_precompute(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(
+ &peer->p_timers);
+ noise_remote_expire_current(&peer->p_remote);
+ }
+ mtx_unlock(&ht->h_mtx);
+ cookie_checker_update(&sc->sc_cookie,
+ has_identity ? public : NULL);
+ noise_local_unlock_identity(local);
+ }
+ }
+ if (nvlist_exists_number(nvl, "user-cookie")) {
+ uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
+ if (user_cookie > UINT32_MAX) {
+ err = EINVAL;
+ goto out;
+ }
+ wg_socket_set_cookie(sc, user_cookie);
+ }
+ if (nvlist_exists_nvlist_array(nvl, "peers")) {
+ size_t peercount;
+ const nvlist_t * const*nvl_peers;
+
+ nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
+		for (size_t i = 0; i < peercount; i++) {
+ err = wg_peer_add(sc, nvl_peers[i]);
+ if (err != 0)
+ goto out;
+ }
+ }
+
+ nvlist_destroy(nvl);
+out:
+ free(nvlpacked, M_TEMP);
+ sx_xunlock(&sc->sc_lock);
+ return (err);
+}
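+/*
+ * A minimal userspace sketch (assuming libnv and an open socket descriptor
+ * s; "wg0" and the port are placeholders) of how a packed nvlist might be
+ * handed to the SIOCSWG ioctl consumed above:
+ *
+ *	nvlist_t *nvl = nvlist_create(0);
+ *	struct wg_data_io wgd = { .wgd_name = "wg0" };
+ *
+ *	nvlist_add_number(nvl, "listen-port", 51820);
+ *	wgd.wgd_data = nvlist_pack(nvl, &wgd.wgd_size);
+ *	ioctl(s, SIOCSWG, &wgd);
+ */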
+
+static unsigned int
+in_mask2len(struct in_addr *mask)
+{
+ unsigned int x, y;
+ uint8_t *p;
+
+ p = (uint8_t *)mask;
+ for (x = 0; x < sizeof(*mask); x++) {
+ if (p[x] != 0xff)
+ break;
+ }
+ y = 0;
+ if (x < sizeof(*mask)) {
+ for (y = 0; y < NBBY; y++) {
+ if ((p[x] & (0x80 >> y)) == 0)
+ break;
+ }
+ }
+ return x * NBBY + y;
+}
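+/* Example: a mask of 255.255.254.0 yields x = 2, y = 7, i.e. a /23. */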
+
+static int
+wg_peer_to_export(struct wg_peer *peer, struct wg_peer_export *exp)
+{
+ struct wg_endpoint *ep;
+ struct wg_aip *rt;
+ struct noise_remote *remote;
+ int i;
+
+ /* Non-sleepable context. */
+ NET_EPOCH_ASSERT();
+
+ bzero(&exp->endpoint, sizeof(exp->endpoint));
+ remote = &peer->p_remote;
+ ep = &peer->p_endpoint;
+ if (ep->e_remote.r_sa.sa_family != 0) {
+ exp->endpoint_sz = (ep->e_remote.r_sa.sa_family == AF_INET) ?
+ sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
+
+ memcpy(&exp->endpoint, &ep->e_remote, exp->endpoint_sz);
+ }
+
+	/* Always export the keys; wg_peer_export_to_nvl() withholds the
+	 * preshared key from unprivileged callers. */
+ (void)noise_remote_keys(remote, exp->public_key, exp->preshared_key);
+ exp->persistent_keepalive =
+ peer->p_timers.t_persistent_keepalive_interval;
+ wg_timers_get_last_handshake(&peer->p_timers, &exp->last_handshake);
+ exp->rx_bytes = counter_u64_fetch(peer->p_rx_bytes);
+ exp->tx_bytes = counter_u64_fetch(peer->p_tx_bytes);
+
+ exp->aip_count = 0;
+ CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+ exp->aip_count++;
+ }
+
+ /* Early success; no allowed-ips to copy out. */
+ if (exp->aip_count == 0)
+ return (0);
+
+ exp->aip = malloc(exp->aip_count * sizeof(*exp->aip), M_TEMP, M_NOWAIT);
+ if (exp->aip == NULL)
+ return (ENOMEM);
+
+ i = 0;
+ CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+ exp->aip[i].family = rt->r_addr.ss_family;
+ if (exp->aip[i].family == AF_INET) {
+ struct sockaddr_in *sin =
+ (struct sockaddr_in *)&rt->r_addr;
+
+ exp->aip[i].ip4 = sin->sin_addr;
+
+ sin = (struct sockaddr_in *)&rt->r_mask;
+ exp->aip[i].cidr = in_mask2len(&sin->sin_addr);
+ } else if (exp->aip[i].family == AF_INET6) {
+ struct sockaddr_in6 *sin6 =
+ (struct sockaddr_in6 *)&rt->r_addr;
+
+ exp->aip[i].ip6 = sin6->sin6_addr;
+
+ sin6 = (struct sockaddr_in6 *)&rt->r_mask;
+ exp->aip[i].cidr = in6_mask2len(&sin6->sin6_addr, NULL);
+ }
+ i++;
+ if (i == exp->aip_count)
+ break;
+ }
+
+	/* Again, AllowedIPs might have shrunk; update the count. */
+ exp->aip_count = i;
+
+ return (0);
+}
+
+static nvlist_t *
+wg_peer_export_to_nvl(struct wg_softc *sc, struct wg_peer_export *exp)
+{
+ struct wg_timespec64 ts64;
+ nvlist_t *nvl, **nvl_aips;
+ size_t i;
+ uint16_t family;
+
+ nvl_aips = NULL;
+ if ((nvl = nvlist_create(0)) == NULL)
+ return (NULL);
+
+ nvlist_add_binary(nvl, "public-key", exp->public_key,
+ sizeof(exp->public_key));
+ if (wgc_privileged(sc))
+ nvlist_add_binary(nvl, "preshared-key", exp->preshared_key,
+ sizeof(exp->preshared_key));
+ if (exp->endpoint_sz != 0)
+ nvlist_add_binary(nvl, "endpoint", &exp->endpoint,
+ exp->endpoint_sz);
+
+ if (exp->aip_count != 0) {
+ nvl_aips = mallocarray(exp->aip_count, sizeof(*nvl_aips),
+ M_WG, M_WAITOK | M_ZERO);
+ }
+
+ for (i = 0; i < exp->aip_count; i++) {
+ nvl_aips[i] = nvlist_create(0);
+ if (nvl_aips[i] == NULL)
+ goto err;
+ family = exp->aip[i].family;
+ nvlist_add_number(nvl_aips[i], "cidr", exp->aip[i].cidr);
+ if (family == AF_INET)
+ nvlist_add_binary(nvl_aips[i], "ipv4",
+ &exp->aip[i].ip4, sizeof(exp->aip[i].ip4));
+ else if (family == AF_INET6)
+ nvlist_add_binary(nvl_aips[i], "ipv6",
+ &exp->aip[i].ip6, sizeof(exp->aip[i].ip6));
+ }
+
+ if (i != 0) {
+ nvlist_add_nvlist_array(nvl, "allowed-ips",
+ (const nvlist_t *const *)nvl_aips, i);
+ }
+
+ for (i = 0; i < exp->aip_count; ++i)
+ nvlist_destroy(nvl_aips[i]);
+
+ free(nvl_aips, M_WG);
+ nvl_aips = NULL;
+
+ ts64.tv_sec = exp->last_handshake.tv_sec;
+ ts64.tv_nsec = exp->last_handshake.tv_nsec;
+ nvlist_add_binary(nvl, "last-handshake-time", &ts64, sizeof(ts64));
+
+ if (exp->persistent_keepalive != 0)
+ nvlist_add_number(nvl, "persistent-keepalive-interval",
+ exp->persistent_keepalive);
+
+ if (exp->rx_bytes != 0)
+ nvlist_add_number(nvl, "rx-bytes", exp->rx_bytes);
+ if (exp->tx_bytes != 0)
+ nvlist_add_number(nvl, "tx-bytes", exp->tx_bytes);
+
+ return (nvl);
+err:
+ for (i = 0; i < exp->aip_count && nvl_aips[i] != NULL; i++) {
+ nvlist_destroy(nvl_aips[i]);
+ }
+
+ free(nvl_aips, M_WG);
+ nvlist_destroy(nvl);
+ return (NULL);
+}
+
+static int
+wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp)
+{
+ struct wg_peer *peer;
+ int err, i, peer_count;
+ nvlist_t *nvl, **nvl_array;
+ struct epoch_tracker et;
+ struct wg_peer_export *wpe;
+
+ nvl = NULL;
+ nvl_array = NULL;
+ if (nvl_arrayp)
+ *nvl_arrayp = NULL;
+ if (nvlp)
+ *nvlp = NULL;
+ if (peer_countp)
+ *peer_countp = 0;
+ peer_count = sc->sc_hashtable.h_num_peers;
+ if (peer_count == 0) {
+ return (ENOENT);
+ }
+
+ if (nvlp && (nvl = nvlist_create(0)) == NULL)
+ return (ENOMEM);
+
+ err = i = 0;
+ nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK | M_ZERO);
+ wpe = malloc(peer_count*sizeof(*wpe), M_TEMP, M_WAITOK | M_ZERO);
+
+ NET_EPOCH_ENTER(et);
+ CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) {
+ if ((err = wg_peer_to_export(peer, &wpe[i])) != 0) {
+ break;
+ }
+
+ i++;
+ if (i == peer_count)
+ break;
+ }
+ NET_EPOCH_EXIT(et);
+
+ if (err != 0)
+ goto out;
+
+ /* Update the peer count, in case we found fewer entries. */
+ *peer_countp = peer_count = i;
+ if (peer_count == 0) {
+ err = ENOENT;
+ goto out;
+ }
+
+ for (i = 0; i < peer_count; i++) {
+ int idx;
+
+ /*
+ * Peers are added to the list in reverse order, effectively,
+ * because it's simpler/quicker to add at the head every time.
+ *
+ * Export them in reverse order. No worries if we fail mid-way
+ * through, the cleanup below will DTRT.
+ */
+ idx = peer_count - i - 1;
+ nvl_array[idx] = wg_peer_export_to_nvl(sc, &wpe[i]);
+ if (nvl_array[idx] == NULL) {
+ break;
+ }
+ }
+
+ if (i < peer_count) {
+ /* Error! */
+ *peer_countp = 0;
+ err = ENOMEM;
+ } else if (nvl) {
+ nvlist_add_nvlist_array(nvl, "peers",
+ (const nvlist_t * const *)nvl_array, peer_count);
+ if ((err = nvlist_error(nvl))) {
+ goto out;
+ }
+ *nvlp = nvl;
+ }
+ *nvl_arrayp = nvl_array;
+ out:
+ if (err != 0) {
+ /* Note that nvl_array is populated in reverse order. */
+ for (i = 0; i < peer_count; i++) {
+ nvlist_destroy(nvl_array[i]);
+ }
+
+ free(nvl_array, M_TEMP);
+ if (nvl != NULL)
+ nvlist_destroy(nvl);
+ }
+
+ for (i = 0; i < peer_count; i++)
+ free(wpe[i].aip, M_TEMP);
+ free(wpe, M_TEMP);
+ return (err);
+}
+
+static int
+wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
+{
+ nvlist_t *nvl, **nvl_array;
+ void *packed;
+ size_t size;
+ int peer_count, err;
+
+ nvl = nvlist_create(0);
+ if (nvl == NULL)
+ return (ENOMEM);
+
+ sx_slock(&sc->sc_lock);
+
+ err = 0;
+ packed = NULL;
+ if (sc->sc_socket.so_port != 0)
+ nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
+ if (sc->sc_socket.so_user_cookie != 0)
+ nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
+ if (sc->sc_local.l_has_identity) {
+ nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE);
+ if (wgc_privileged(sc))
+ nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE);
+ }
+ if (sc->sc_hashtable.h_num_peers > 0) {
+ err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count);
+ if (err)
+ goto out_nvl;
+		nvlist_add_nvlist_array(nvl, "peers",
+		    (const nvlist_t * const *)nvl_array, peer_count);
+		/* nvlist_add_nvlist_array() copies the nvlists; release our
+		 * references now that they have been marshalled. */
+		for (int i = 0; i < peer_count; i++)
+			nvlist_destroy(nvl_array[i]);
+		free(nvl_array, M_TEMP);
+	}
+ packed = nvlist_pack(nvl, &size);
+ if (packed == NULL) {
+ err = ENOMEM;
+ goto out_nvl;
+ }
+ if (wgd->wgd_size == 0) {
+ wgd->wgd_size = size;
+ goto out_packed;
+ }
+ if (wgd->wgd_size < size) {
+ err = ENOSPC;
+ goto out_packed;
+ }
+ if (wgd->wgd_data == NULL) {
+ err = EFAULT;
+ goto out_packed;
+ }
+ err = copyout(packed, wgd->wgd_data, size);
+ wgd->wgd_size = size;
+
+out_packed:
+ free(packed, M_NVLIST);
+out_nvl:
+ nvlist_destroy(nvl);
+ sx_sunlock(&sc->sc_lock);
+ return (err);
+}
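+/*
+ * Callers typically probe for the required buffer size first. A userspace
+ * sketch (error handling omitted; "wg0" and the socket s are placeholders):
+ *
+ *	struct wg_data_io wgd = { .wgd_name = "wg0" };
+ *
+ *	ioctl(s, SIOCGWG, &wgd);	with wgd_size == 0, the kernel
+ *					returns the required size
+ *	wgd.wgd_data = malloc(wgd.wgd_size);
+ *	ioctl(s, SIOCGWG, &wgd);	packed nvlist copied out
+ */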
+
+static int
+wg_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data)
+{
+ struct wg_data_io *wgd = (struct wg_data_io *)data;
+ struct ifreq *ifr = (struct ifreq *)data;
+ struct wg_softc *sc = ifp->if_softc;
+ int ret = 0;
+
+ switch (cmd) {
+ case SIOCSWG:
+ ret = priv_check(curthread, PRIV_NET_WG);
+ if (ret == 0)
+ ret = wgc_set(sc, wgd);
+ break;
+ case SIOCGWG:
+ ret = wgc_get(sc, wgd);
+ break;
+ /* Interface IOCTLs */
+ case SIOCSIFADDR:
+		/*
+		 * This differs from *BSD norms, but is more consistent with
+		 * how WireGuard behaves elsewhere.
+		 */
+ break;
+ case SIOCSIFFLAGS:
+ if ((ifp->if_flags & IFF_UP) != 0)
+ ret = wg_up(sc);
+ else
+ wg_down(sc);
+ break;
+ case SIOCSIFMTU:
+ if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
+ ret = EINVAL;
+ else
+ ifp->if_mtu = ifr->ifr_mtu;
+ break;
+ case SIOCADDMULTI:
+ case SIOCDELMULTI:
+ break;
+ default:
+ ret = ENOTTY;
+ }
+
+ return ret;
+}
+
+static int
+wg_up(struct wg_softc *sc)
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ struct ifnet *ifp = sc->sc_ifp;
+ struct wg_peer *peer;
+ int rc = EBUSY;
+
+ sx_xlock(&sc->sc_lock);
+ /* Jail's being removed, no more wg_up(). */
+ if ((sc->sc_flags & WGF_DYING) != 0)
+ goto out;
+
+ /* Silent success if we're already running. */
+ rc = 0;
+ if (ifp->if_drv_flags & IFF_DRV_RUNNING)
+ goto out;
+ ifp->if_drv_flags |= IFF_DRV_RUNNING;
+
+ rc = wg_socket_init(sc, sc->sc_socket.so_port);
+ if (rc == 0) {
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ wg_timers_enable(&peer->p_timers);
+ wg_queue_out(peer);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
+ } else {
+ ifp->if_drv_flags &= ~IFF_DRV_RUNNING;
+ }
+out:
+ sx_xunlock(&sc->sc_lock);
+ return (rc);
+}
+
+static void
+wg_down(struct wg_softc *sc)
+{
+ struct wg_hashtable *ht = &sc->sc_hashtable;
+ struct ifnet *ifp = sc->sc_ifp;
+ struct wg_peer *peer;
+
+ sx_xlock(&sc->sc_lock);
+ if (!(ifp->if_drv_flags & IFF_DRV_RUNNING)) {
+ sx_xunlock(&sc->sc_lock);
+ return;
+ }
+ ifp->if_drv_flags &= ~IFF_DRV_RUNNING;
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ wg_queue_purge(&peer->p_stage_queue);
+ wg_timers_disable(&peer->p_timers);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ mbufq_drain(&sc->sc_handshake_queue);
+
+ mtx_lock(&ht->h_mtx);
+ CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) {
+ noise_remote_clear(&peer->p_remote);
+ wg_timers_event_reset_handshake_last_sent(&peer->p_timers);
+ }
+ mtx_unlock(&ht->h_mtx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+ wg_socket_uninit(sc);
+
+ sx_xunlock(&sc->sc_lock);
+}
+
+static void
+crypto_taskq_setup(struct wg_softc *sc)
+{
+
+ sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
+ sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK);
+
+ for (int i = 0; i < mp_ncpus; i++) {
+ GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
+ (gtask_fn_t *)wg_softc_encrypt, sc);
+ taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
+ GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
+ (gtask_fn_t *)wg_softc_decrypt, sc);
+ taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
+ }
+}
+
+static void
+crypto_taskq_destroy(struct wg_softc *sc)
+{
+ for (int i = 0; i < mp_ncpus; i++) {
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]);
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]);
+ }
+ free(sc->sc_encrypt, M_WG);
+ free(sc->sc_decrypt, M_WG);
+}
+
+static int
+wg_clone_create(struct if_clone *ifc, int unit, caddr_t params)
+{
+ struct wg_softc *sc;
+ struct ifnet *ifp;
+ struct noise_upcall noise_upcall;
+
+ sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
+ sc->sc_ucred = crhold(curthread->td_ucred);
+ ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
+ ifp->if_softc = sc;
+ if_initname(ifp, wgname, unit);
+
+ noise_upcall.u_arg = sc;
+ noise_upcall.u_remote_get =
+ (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get;
+ noise_upcall.u_index_set =
+ (uint32_t (*)(void *, struct noise_remote *))wg_index_set;
+ noise_upcall.u_index_drop =
+ (void (*)(void *, uint32_t))wg_index_drop;
+ noise_local_init(&sc->sc_local, &noise_upcall);
+ cookie_checker_init(&sc->sc_cookie, ratelimit_zone);
+
+ sc->sc_socket.so_port = 0;
+
+ atomic_add_int(&clone_count, 1);
+ ifp->if_capabilities = ifp->if_capenable = WG_CAPS;
+
+ mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES);
+ sx_init(&sc->sc_lock, "wg softc lock");
+ rw_init(&sc->sc_index_lock, "wg index lock");
+ sc->sc_peer_count = 0;
+ sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL);
+ sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL);
+ GROUPTASK_INIT(&sc->sc_handshake, 0,
+ (gtask_fn_t *)wg_softc_handshake_receive, sc);
+ taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
+ crypto_taskq_setup(sc);
+
+ wg_hashtable_init(&sc->sc_hashtable);
+ sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
+ wg_aip_init(&sc->sc_aips);
+
+ if_setmtu(ifp, ETHERMTU - 80);
+ ifp->if_flags = IFF_BROADCAST | IFF_MULTICAST | IFF_NOARP;
+ ifp->if_init = wg_init;
+ ifp->if_reassign = wg_reassign;
+ ifp->if_qflush = wg_qflush;
+ ifp->if_transmit = wg_transmit;
+ ifp->if_output = wg_output;
+ ifp->if_ioctl = wg_ioctl;
+
+ if_attach(ifp);
+ bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
+
+ sx_xlock(&wg_sx);
+ LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
+ sx_xunlock(&wg_sx);
+
+ return 0;
+}
+
+static void
+wg_clone_destroy(struct ifnet *ifp)
+{
+ struct wg_softc *sc = ifp->if_softc;
+ struct ucred *cred;
+
+ sx_xlock(&wg_sx);
+ sx_xlock(&sc->sc_lock);
+ sc->sc_flags |= WGF_DYING;
+ cred = sc->sc_ucred;
+ sc->sc_ucred = NULL;
+ sx_xunlock(&sc->sc_lock);
+ LIST_REMOVE(sc, sc_entry);
+ sx_xunlock(&wg_sx);
+
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+
+ sx_xlock(&sc->sc_lock);
+ wg_socket_uninit(sc);
+ sx_xunlock(&sc->sc_lock);
+
+	/*
+	 * There is no guarantee that all traffic has passed until the epoch
+	 * has elapsed with the socket closed.
+	 */
+ NET_EPOCH_WAIT();
+
+ taskqgroup_drain_all(qgroup_if_io_tqg);
+ sx_xlock(&sc->sc_lock);
+ wg_peer_remove_all(sc);
+ epoch_drain_callbacks(net_epoch_preempt);
+ sx_xunlock(&sc->sc_lock);
+ sx_destroy(&sc->sc_lock);
+ rw_destroy(&sc->sc_index_lock);
+ taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
+ crypto_taskq_destroy(sc);
+ buf_ring_free(sc->sc_encap_ring, M_WG);
+ buf_ring_free(sc->sc_decap_ring, M_WG);
+
+ wg_aip_destroy(&sc->sc_aips);
+ wg_hashtable_destroy(&sc->sc_hashtable);
+
+ if (cred != NULL)
+ crfree(cred);
+ if_detach(sc->sc_ifp);
+ if_free(sc->sc_ifp);
+ /* Ensure any local/private keys are cleaned up */
+ explicit_bzero(sc, sizeof(*sc));
+ free(sc, M_WG);
+
+ atomic_add_int(&clone_count, -1);
+}
+
+static void
+wg_qflush(struct ifnet *ifp __unused)
+{
+}
+
+/*
+ * Privileged information (private-key, preshared-key) is only exported for
+ * root and jailed root by default.
+ */
+static bool
+wgc_privileged(struct wg_softc *sc)
+{
+ struct thread *td;
+
+ td = curthread;
+ return (priv_check(td, PRIV_NET_WG) == 0);
+}
+
+static void
+wg_reassign(struct ifnet *ifp, struct vnet *new_vnet __unused,
+ char *unused __unused)
+{
+ struct wg_softc *sc;
+
+ sc = ifp->if_softc;
+ wg_down(sc);
+}
+
+static void
+wg_init(void *xsc)
+{
+ struct wg_softc *sc;
+
+ sc = xsc;
+ wg_up(sc);
+}
+
+static void
+vnet_wg_init(const void *unused __unused)
+{
+
+ V_wg_cloner = if_clone_simple(wgname, wg_clone_create, wg_clone_destroy,
+ 0);
+}
+VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+ vnet_wg_init, NULL);
+
+static void
+vnet_wg_uninit(const void *unused __unused)
+{
+
+ if_clone_detach(V_wg_cloner);
+}
+VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+ vnet_wg_uninit, NULL);
+
+static int
+wg_prison_remove(void *obj, void *data __unused)
+{
+ const struct prison *pr = obj;
+ struct wg_softc *sc;
+ struct ucred *cred;
+ bool dying;
+
+ /*
+ * Do a pass through all if_wg interfaces and release creds on any from
+ * the jail that are supposed to be going away. This will, in turn, let
+ * the jail die so that we don't end up with Schrödinger's jail.
+ */
+ sx_slock(&wg_sx);
+ LIST_FOREACH(sc, &wg_list, sc_entry) {
+ cred = NULL;
+
+ sx_xlock(&sc->sc_lock);
+ dying = (sc->sc_flags & WGF_DYING) != 0;
+ if (!dying && sc->sc_ucred != NULL &&
+ sc->sc_ucred->cr_prison == pr) {
+ /* Home jail is going away. */
+ cred = sc->sc_ucred;
+ sc->sc_ucred = NULL;
+
+ sc->sc_flags |= WGF_DYING;
+ }
+
+ /*
+ * If this is our foreign vnet going away, we'll also down the
+ * link and kill the socket because traffic needs to stop. Any
+ * address will be revoked in the rehoming process.
+ */
+ if (cred != NULL || (!dying &&
+ sc->sc_ifp->if_vnet == pr->pr_vnet)) {
+ if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+ /* Have to kill the sockets, as they also hold refs. */
+ wg_socket_uninit(sc);
+ }
+
+ sx_xunlock(&sc->sc_lock);
+
+ if (cred != NULL) {
+ CURVNET_SET(sc->sc_ifp->if_vnet);
+ if_purgeaddrs(sc->sc_ifp);
+ CURVNET_RESTORE();
+ crfree(cred);
+ }
+ }
+ sx_sunlock(&wg_sx);
+
+ return (0);
+}
+
+static void
+wg_module_init(void)
+{
+ osd_method_t methods[PR_MAXMETHOD] = {
+ [PR_METHOD_REMOVE] = wg_prison_remove,
+ };
+
+ ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit),
+ NULL, NULL, NULL, NULL, 0, 0);
+ wg_osd_jail_slot = osd_jail_register(NULL, methods);
+}
+
+static void
+wg_module_deinit(void)
+{
+
+ uma_zdestroy(ratelimit_zone);
+ osd_jail_deregister(wg_osd_jail_slot);
+
+ MPASS(LIST_EMPTY(&wg_list));
+}
+
+static int
+wg_module_event_handler(module_t mod, int what, void *arg)
+{
+
+ switch (what) {
+ case MOD_LOAD:
+ wg_module_init();
+ break;
+ case MOD_UNLOAD:
+ if (atomic_load_int(&clone_count) == 0)
+ wg_module_deinit();
+ else
+ return (EBUSY);
+ break;
+ default:
+ return (EOPNOTSUPP);
+ }
+ return (0);
+}
+
+static moduledata_t wg_moduledata = {
+ "wg",
+ wg_module_event_handler,
+ NULL
+};
+
+DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
+MODULE_VERSION(wg, 1);
+MODULE_DEPEND(wg, crypto, 1, 1, 1);
diff --git a/src/if_wg.h b/src/if_wg.h
new file mode 100644
index 0000000..f137c93
--- /dev/null
+++ b/src/if_wg.h
@@ -0,0 +1,37 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (c) 2019 Matt Dunwoodie <ncon@noconroy.net>
+ *
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ *
+ * $FreeBSD$
+ */
+
+#ifndef __IF_WG_H__
+#define __IF_WG_H__
+
+#include <net/if.h>
+#include <netinet/in.h>
+
+struct wg_data_io {
+ char wgd_name[IFNAMSIZ];
+ void *wgd_data;
+ size_t wgd_size;
+};
+
+#define WG_KEY_SIZE 32
+
+#define SIOCSWG _IOWR('i', 210, struct wg_data_io)
+#define SIOCGWG _IOWR('i', 211, struct wg_data_io)
+
+#endif /* __IF_WG_H__ */
diff --git a/src/support.h b/src/support.h
new file mode 100644
index 0000000..f613f40
--- /dev/null
+++ b/src/support.h
@@ -0,0 +1,56 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2021 Matt Dunwoodie <ncon@noconroy.net>
+ */
+
+#ifndef _WG_SUPPORT
+#define _WG_SUPPORT
+
+#include <sys/types.h>
+#include <sys/limits.h>
+#include <sys/endian.h>
+#include <sys/libkern.h>
+#include <sys/malloc.h>
+#include <sys/proc.h>
+#include <sys/lock.h>
+#include <vm/uma.h>
+
+/* TODO: the following are OpenBSD compat defines that let us copy the wg_*
+ * files from OpenBSD (almost) verbatim, which greatly eases maintenance
+ * across the platforms. They should be moved to their own file. The only
+ * thing we're missing is struct pool (FreeBSD: uma_zone_t), which isn't a
+ * show stopper, but is worth considering in the future.
+ * - md */
+
+#define rw_assert_wrlock(x) rw_assert(x, RA_WLOCKED)
+#define rw_enter_write rw_wlock
+#define rw_exit_write rw_wunlock
+#define rw_enter_read rw_rlock
+#define rw_exit_read rw_runlock
+#define rw_exit rw_unlock
+
+#define RW_DOWNGRADE 1
+#define rw_enter(x, y) do { \
+ CTASSERT(y == RW_DOWNGRADE); \
+ rw_downgrade(x); \
+} while (0)
+
+MALLOC_DECLARE(M_WG);
+
+#include <crypto/siphash/siphash.h>
+typedef struct {
+ uint64_t k0;
+ uint64_t k1;
+} SIPHASH_KEY;
+
+static inline uint64_t
+siphash24(const SIPHASH_KEY *key, const void *src, size_t len)
+{
+ SIPHASH_CTX ctx;
+
+ return (SipHashX(&ctx, 2, 4, (const uint8_t *)key, src, len));
+}
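+/*
+ * Example use (a sketch mirroring the cookie ratelimiter): hash an IPv4
+ * address under a random key; "sin" is a hypothetical sockaddr_in pointer.
+ *
+ *	SIPHASH_KEY key;
+ *	uint64_t h;
+ *
+ *	arc4random_buf(&key, sizeof(key));
+ *	h = siphash24(&key, &sin->sin_addr, sizeof(sin->sin_addr));
+ */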
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
+
+#endif
diff --git a/src/wg_cookie.c b/src/wg_cookie.c
new file mode 100644
index 0000000..bf0ce37
--- /dev/null
+++ b/src/wg_cookie.c
@@ -0,0 +1,427 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ */
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/param.h>
+#include <sys/rwlock.h>
+#include <sys/malloc.h> /* Because systm doesn't include M_NOWAIT, M_DEVBUF */
+#include <sys/socket.h>
+
+#include "support.h"
+#include "wg_cookie.h"
+
+static void cookie_precompute_key(uint8_t *,
+ const uint8_t[COOKIE_INPUT_SIZE], const char *);
+static void cookie_macs_mac1(struct cookie_macs *, const void *, size_t,
+ const uint8_t[COOKIE_KEY_SIZE]);
+static void cookie_macs_mac2(struct cookie_macs *, const void *, size_t,
+ const uint8_t[COOKIE_COOKIE_SIZE]);
+static int cookie_timer_expired(struct timespec *, time_t, long);
+static void cookie_checker_make_cookie(struct cookie_checker *,
+ uint8_t[COOKIE_COOKIE_SIZE], struct sockaddr *);
+static int ratelimit_init(struct ratelimit *, uma_zone_t);
+static void ratelimit_deinit(struct ratelimit *);
+static void ratelimit_gc(struct ratelimit *, int);
+static int ratelimit_allow(struct ratelimit *, struct sockaddr *);
+
+/* Public Functions */
+void
+cookie_maker_init(struct cookie_maker *cp, const uint8_t key[COOKIE_INPUT_SIZE])
+{
+ bzero(cp, sizeof(*cp));
+ cookie_precompute_key(cp->cp_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+ cookie_precompute_key(cp->cp_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+ rw_init(&cp->cp_lock, "cookie_maker");
+}
+
+int
+cookie_checker_init(struct cookie_checker *cc, uma_zone_t zone)
+{
+ int res;
+ bzero(cc, sizeof(*cc));
+
+ rw_init(&cc->cc_key_lock, "cookie_checker_key");
+ rw_init(&cc->cc_secret_lock, "cookie_checker_secret");
+
+ if ((res = ratelimit_init(&cc->cc_ratelimit_v4, zone)) != 0)
+ return res;
+#ifdef INET6
+ if ((res = ratelimit_init(&cc->cc_ratelimit_v6, zone)) != 0) {
+ ratelimit_deinit(&cc->cc_ratelimit_v4);
+ return res;
+ }
+#endif
+ return 0;
+}
+
+void
+cookie_checker_update(struct cookie_checker *cc,
+ const uint8_t key[COOKIE_INPUT_SIZE])
+{
+ rw_enter_write(&cc->cc_key_lock);
+ if (key) {
+ cookie_precompute_key(cc->cc_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+ cookie_precompute_key(cc->cc_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+ } else {
+ bzero(cc->cc_mac1_key, sizeof(cc->cc_mac1_key));
+ bzero(cc->cc_cookie_key, sizeof(cc->cc_cookie_key));
+ }
+ rw_exit_write(&cc->cc_key_lock);
+}
+
+void
+cookie_checker_deinit(struct cookie_checker *cc)
+{
+ ratelimit_deinit(&cc->cc_ratelimit_v4);
+#ifdef INET6
+ ratelimit_deinit(&cc->cc_ratelimit_v6);
+#endif
+}
+
+void
+cookie_checker_create_payload(struct cookie_checker *cc,
+ struct cookie_macs *cm, uint8_t nonce[COOKIE_NONCE_SIZE],
+ uint8_t ecookie[COOKIE_ENCRYPTED_SIZE], struct sockaddr *sa)
+{
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ cookie_checker_make_cookie(cc, cookie, sa);
+ arc4random_buf(nonce, COOKIE_NONCE_SIZE);
+
+ rw_enter_read(&cc->cc_key_lock);
+ xchacha20poly1305_encrypt(ecookie, cookie, COOKIE_COOKIE_SIZE,
+ cm->mac1, COOKIE_MAC_SIZE, nonce, cc->cc_cookie_key);
+ rw_exit_read(&cc->cc_key_lock);
+
+ explicit_bzero(cookie, sizeof(cookie));
+}
+
+int
+cookie_maker_consume_payload(struct cookie_maker *cp,
+ uint8_t nonce[COOKIE_NONCE_SIZE], uint8_t ecookie[COOKIE_ENCRYPTED_SIZE])
+{
+ int ret = 0;
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ rw_enter_write(&cp->cp_lock);
+
+ if (cp->cp_mac1_valid == 0) {
+ ret = ETIMEDOUT;
+ goto error;
+ }
+
+ if (xchacha20poly1305_decrypt(cookie, ecookie, COOKIE_ENCRYPTED_SIZE,
+ cp->cp_mac1_last, COOKIE_MAC_SIZE, nonce, cp->cp_cookie_key) == 0) {
+ ret = EINVAL;
+ goto error;
+ }
+
+ memcpy(cp->cp_cookie, cookie, COOKIE_COOKIE_SIZE);
+ getnanouptime(&cp->cp_birthdate);
+ cp->cp_mac1_valid = 0;
+
+error:
+ rw_exit_write(&cp->cp_lock);
+ return ret;
+}
+
+void
+cookie_maker_mac(struct cookie_maker *cp, struct cookie_macs *cm, void *buf,
+ size_t len)
+{
+ rw_enter_read(&cp->cp_lock);
+
+ cookie_macs_mac1(cm, buf, len, cp->cp_mac1_key);
+
+ memcpy(cp->cp_mac1_last, cm->mac1, COOKIE_MAC_SIZE);
+ cp->cp_mac1_valid = 1;
+
+ if (!cookie_timer_expired(&cp->cp_birthdate,
+ COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY, 0))
+ cookie_macs_mac2(cm, buf, len, cp->cp_cookie);
+ else
+ bzero(cm->mac2, COOKIE_MAC_SIZE);
+
+ rw_exit_read(&cp->cp_lock);
+}
+
+int
+cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *cm,
+ void *buf, size_t len, int busy, struct sockaddr *sa)
+{
+ struct cookie_macs our_cm;
+ uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+ /* Validate incoming MACs */
+ rw_enter_read(&cc->cc_key_lock);
+ cookie_macs_mac1(&our_cm, buf, len, cc->cc_mac1_key);
+ rw_exit_read(&cc->cc_key_lock);
+
+	/* If mac1 is invalid, we want to drop the packet */
+ if (timingsafe_bcmp(our_cm.mac1, cm->mac1, COOKIE_MAC_SIZE) != 0)
+ return EINVAL;
+
+ if (busy != 0) {
+ cookie_checker_make_cookie(cc, cookie, sa);
+ cookie_macs_mac2(&our_cm, buf, len, cookie);
+
+ /* If the mac2 is invalid, we want to send a cookie response */
+ if (timingsafe_bcmp(our_cm.mac2, cm->mac2, COOKIE_MAC_SIZE) != 0)
+ return EAGAIN;
+
+		/* If the mac2 is valid, we may want to rate limit the peer.
+		 * ratelimit_allow returns either 0 (no ratelimiting) or
+		 * ECONNREFUSED (ratelimit, that is refuse, the initiation). */
+ if (sa->sa_family == AF_INET)
+ return ratelimit_allow(&cc->cc_ratelimit_v4, sa);
+#ifdef INET6
+ else if (sa->sa_family == AF_INET6)
+ return ratelimit_allow(&cc->cc_ratelimit_v6, sa);
+#endif
+ else
+ return EAFNOSUPPORT;
+ }
+ return 0;
+}
+
+/* Private functions */
+static void
+cookie_precompute_key(uint8_t *key, const uint8_t input[COOKIE_INPUT_SIZE],
+ const char *label)
+{
+ struct blake2s_state blake;
+
+ blake2s_init(&blake, COOKIE_KEY_SIZE);
+ blake2s_update(&blake, label, strlen(label));
+ blake2s_update(&blake, input, COOKIE_INPUT_SIZE);
+	/* TODO: we shouldn't need to provide outlen to _final; we can align
+	 * this with OpenBSD after fixing the blake library. */
+ blake2s_final(&blake, key);
+}
+
+static void
+cookie_macs_mac1(struct cookie_macs *cm, const void *buf, size_t len,
+ const uint8_t key[COOKIE_KEY_SIZE])
+{
+ struct blake2s_state state;
+ blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_KEY_SIZE);
+ blake2s_update(&state, buf, len);
+ blake2s_final(&state, cm->mac1);
+}
+
+static void
+cookie_macs_mac2(struct cookie_macs *cm, const void *buf, size_t len,
+ const uint8_t key[COOKIE_COOKIE_SIZE])
+{
+ struct blake2s_state state;
+ blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_COOKIE_SIZE);
+ blake2s_update(&state, buf, len);
+ blake2s_update(&state, cm->mac1, COOKIE_MAC_SIZE);
+ blake2s_final(&state, cm->mac2);
+}
+
+static int
+cookie_timer_expired(struct timespec *birthdate, time_t sec, long nsec)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec };
+
+ if (birthdate->tv_sec == 0 && birthdate->tv_nsec == 0)
+ return ETIMEDOUT;
+
+ getnanouptime(&uptime);
+ timespecadd(birthdate, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
+
+static void
+cookie_checker_make_cookie(struct cookie_checker *cc,
+ uint8_t cookie[COOKIE_COOKIE_SIZE], struct sockaddr *sa)
+{
+ struct blake2s_state state;
+
+ rw_enter_write(&cc->cc_secret_lock);
+ if (cookie_timer_expired(&cc->cc_secret_birthdate,
+ COOKIE_SECRET_MAX_AGE, 0)) {
+ arc4random_buf(cc->cc_secret, COOKIE_SECRET_SIZE);
+ getnanouptime(&cc->cc_secret_birthdate);
+ }
+ blake2s_init_key(&state, COOKIE_COOKIE_SIZE, cc->cc_secret,
+ COOKIE_SECRET_SIZE);
+ rw_exit_write(&cc->cc_secret_lock);
+
+ if (sa->sa_family == AF_INET) {
+ blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_addr,
+ sizeof(struct in_addr));
+ blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_port,
+ sizeof(in_port_t));
+ blake2s_final(&state, cookie);
+#ifdef INET6
+ } else if (sa->sa_family == AF_INET6) {
+ blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_addr,
+ sizeof(struct in6_addr));
+ blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_port,
+ sizeof(in_port_t));
+ blake2s_final(&state, cookie);
+#endif
+ } else {
+ arc4random_buf(cookie, COOKIE_COOKIE_SIZE);
+ }
+}
+
+static int
+ratelimit_init(struct ratelimit *rl, uma_zone_t zone)
+{
+ rw_init(&rl->rl_lock, "ratelimit_lock");
+ arc4random_buf(&rl->rl_secret, sizeof(rl->rl_secret));
+ rl->rl_table = hashinit_flags(RATELIMIT_SIZE, M_DEVBUF,
+ &rl->rl_table_mask, M_NOWAIT);
+ rl->rl_zone = zone;
+ rl->rl_table_num = 0;
+ return rl->rl_table == NULL ? ENOBUFS : 0;
+}
+
+static void
+ratelimit_deinit(struct ratelimit *rl)
+{
+ rw_enter_write(&rl->rl_lock);
+ ratelimit_gc(rl, 1);
+ hashdestroy(rl->rl_table, M_DEVBUF, rl->rl_table_mask);
+ rw_exit_write(&rl->rl_lock);
+}
+
+static void
+ratelimit_gc(struct ratelimit *rl, int force)
+{
+ size_t i;
+ struct ratelimit_entry *r, *tr;
+ struct timespec expiry;
+
+ rw_assert_wrlock(&rl->rl_lock);
+
+ if (force) {
+ for (i = 0; i < RATELIMIT_SIZE; i++) {
+ LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+ rl->rl_table_num--;
+ LIST_REMOVE(r, r_entry);
+ uma_zfree(rl->rl_zone, r);
+ }
+ }
+ return;
+ }
+
+ if ((cookie_timer_expired(&rl->rl_last_gc, ELEMENT_TIMEOUT, 0) &&
+ rl->rl_table_num > 0)) {
+ getnanouptime(&rl->rl_last_gc);
+ getnanouptime(&expiry);
+ expiry.tv_sec -= ELEMENT_TIMEOUT;
+
+ for (i = 0; i < RATELIMIT_SIZE; i++) {
+ LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+ if (timespeccmp(&r->r_last_time, &expiry, <)) {
+ rl->rl_table_num--;
+ LIST_REMOVE(r, r_entry);
+ uma_zfree(rl->rl_zone, r);
+ }
+ }
+ }
+ }
+}
+
+static int
+ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa)
+{
+ uint64_t key, tokens;
+ struct timespec diff;
+ struct ratelimit_entry *r;
+ int ret = ECONNREFUSED;
+
+ if (sa->sa_family == AF_INET)
+ /* TODO siphash24 is the FreeBSD siphash, OK? */
+ key = siphash24(&rl->rl_secret, &satosin(sa)->sin_addr,
+ IPV4_MASK_SIZE);
+#ifdef INET6
+ else if (sa->sa_family == AF_INET6)
+ key = siphash24(&rl->rl_secret, &satosin6(sa)->sin6_addr,
+ IPV6_MASK_SIZE);
+#endif
+ else
+ return ret;
+
+ rw_enter_write(&rl->rl_lock);
+
+ LIST_FOREACH(r, &rl->rl_table[key & rl->rl_table_mask], r_entry) {
+ if (r->r_af != sa->sa_family)
+ continue;
+
+ if (r->r_af == AF_INET && bcmp(&r->r_in,
+ &satosin(sa)->sin_addr, IPV4_MASK_SIZE) != 0)
+ continue;
+
+#ifdef INET6
+ if (r->r_af == AF_INET6 && bcmp(&r->r_in6,
+ &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE) != 0)
+ continue;
+#endif
+
+		/* If we get to here, we've found an entry for the endpoint.
+		 * We apply a standard token bucket: calculate the time
+		 * elapsed since last_time, add that many nanosecond tokens,
+		 * and cap the total at TOKEN_MAX. If the endpoint has no
+		 * tokens left (that is, tokens < INITIATION_COST) then we
+		 * block the request, otherwise we subtract INITIATION_COST
+		 * and return OK. */
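+		/* For instance, a peer idle for 100 ms accrues 1e8 nanosecond
+		 * tokens, enough for two more initiations at
+		 * INITIATION_COST = 5e7 each. */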
+ diff = r->r_last_time;
+ getnanouptime(&r->r_last_time);
+ timespecsub(&r->r_last_time, &diff, &diff);
+
+ tokens = r->r_tokens + diff.tv_sec * NSEC_PER_SEC + diff.tv_nsec;
+
+ if (tokens > TOKEN_MAX)
+ tokens = TOKEN_MAX;
+
+ if (tokens >= INITIATION_COST) {
+ r->r_tokens = tokens - INITIATION_COST;
+ goto ok;
+ } else {
+ r->r_tokens = tokens;
+ goto error;
+ }
+ }
+
+ /* If we get to here, we didn't have an entry for the endpoint. */
+ ratelimit_gc(rl, 0);
+
+ /* Hard limit on number of entries */
+ if (rl->rl_table_num >= RATELIMIT_SIZE_MAX)
+ goto error;
+
+ /* Goto error if out of memory */
+ if ((r = uma_zalloc(rl->rl_zone, M_NOWAIT)) == NULL)
+ goto error;
+
+ rl->rl_table_num++;
+
+ /* Insert entry into the hashtable and ensure it's initialised */
+ LIST_INSERT_HEAD(&rl->rl_table[key & rl->rl_table_mask], r, r_entry);
+ r->r_af = sa->sa_family;
+ if (r->r_af == AF_INET)
+ memcpy(&r->r_in, &satosin(sa)->sin_addr, IPV4_MASK_SIZE);
+#ifdef INET6
+ else if (r->r_af == AF_INET6)
+ memcpy(&r->r_in6, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE);
+#endif
+
+ getnanouptime(&r->r_last_time);
+ r->r_tokens = TOKEN_MAX - INITIATION_COST;
+ok:
+ ret = 0;
+error:
+ rw_exit_write(&rl->rl_lock);
+ return ret;
+}
diff --git a/src/wg_cookie.h b/src/wg_cookie.h
new file mode 100644
index 0000000..c7338d8
--- /dev/null
+++ b/src/wg_cookie.h
@@ -0,0 +1,114 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ */
+
+#ifndef __COOKIE_H__
+#define __COOKIE_H__
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/rwlock.h>
+#include <sys/queue.h>
+
+#include <netinet/in.h>
+
+#include "crypto.h"
+
+#define COOKIE_MAC_SIZE 16
+#define COOKIE_KEY_SIZE 32
+#define COOKIE_NONCE_SIZE XCHACHA20POLY1305_NONCE_SIZE
+#define COOKIE_COOKIE_SIZE 16
+#define COOKIE_SECRET_SIZE 32
+#define COOKIE_INPUT_SIZE 32
+#define COOKIE_ENCRYPTED_SIZE (COOKIE_COOKIE_SIZE + COOKIE_MAC_SIZE)
+
+#define COOKIE_MAC1_KEY_LABEL "mac1----"
+#define COOKIE_COOKIE_KEY_LABEL "cookie--"
+#define COOKIE_SECRET_MAX_AGE 120
+#define COOKIE_SECRET_LATENCY 5
+
+/* Constants for initiation rate limiting */
+#define RATELIMIT_SIZE (1 << 13)
+#define RATELIMIT_SIZE_MAX (RATELIMIT_SIZE * 8)
+#define NSEC_PER_SEC 1000000000LL
+#define INITIATIONS_PER_SECOND 20
+#define INITIATIONS_BURSTABLE 5
+#define INITIATION_COST (NSEC_PER_SEC / INITIATIONS_PER_SECOND)
+#define TOKEN_MAX (INITIATION_COST * INITIATIONS_BURSTABLE)
+#define ELEMENT_TIMEOUT 1
+#define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */
+#define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */
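+/*
+ * Worked example: INITIATION_COST = 1e9 / 20 = 5e7 nanosecond tokens, so
+ * TOKEN_MAX = 5 * 5e7 = 2.5e8. An empty bucket refills in 250 ms, allowing
+ * a burst of 5 initiations followed by a sustained 20 per second.
+ */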
+
+struct cookie_macs {
+ uint8_t mac1[COOKIE_MAC_SIZE];
+ uint8_t mac2[COOKIE_MAC_SIZE];
+};
+
+struct ratelimit_entry {
+ LIST_ENTRY(ratelimit_entry) r_entry;
+ sa_family_t r_af;
+ union {
+ struct in_addr r_in;
+#ifdef INET6
+ struct in6_addr r_in6;
+#endif
+ };
+ struct timespec r_last_time; /* nanouptime */
+ uint64_t r_tokens;
+};
+
+struct ratelimit {
+ SIPHASH_KEY rl_secret;
+ uma_zone_t rl_zone;
+
+ struct rwlock rl_lock;
+ LIST_HEAD(, ratelimit_entry) *rl_table;
+ u_long rl_table_mask;
+ size_t rl_table_num;
+ struct timespec rl_last_gc; /* nanouptime */
+};
+
+struct cookie_maker {
+ uint8_t cp_mac1_key[COOKIE_KEY_SIZE];
+ uint8_t cp_cookie_key[COOKIE_KEY_SIZE];
+
+ struct rwlock cp_lock;
+ uint8_t cp_cookie[COOKIE_COOKIE_SIZE];
+ struct timespec cp_birthdate; /* nanouptime */
+ int cp_mac1_valid;
+ uint8_t cp_mac1_last[COOKIE_MAC_SIZE];
+};
+
+struct cookie_checker {
+ struct ratelimit cc_ratelimit_v4;
+#ifdef INET6
+ struct ratelimit cc_ratelimit_v6;
+#endif
+
+ struct rwlock cc_key_lock;
+ uint8_t cc_mac1_key[COOKIE_KEY_SIZE];
+ uint8_t cc_cookie_key[COOKIE_KEY_SIZE];
+
+ struct rwlock cc_secret_lock;
+ struct timespec cc_secret_birthdate; /* nanouptime */
+ uint8_t cc_secret[COOKIE_SECRET_SIZE];
+};
+
+void cookie_maker_init(struct cookie_maker *, const uint8_t[COOKIE_INPUT_SIZE]);
+int cookie_checker_init(struct cookie_checker *, uma_zone_t);
+void cookie_checker_update(struct cookie_checker *,
+ const uint8_t[COOKIE_INPUT_SIZE]);
+void cookie_checker_deinit(struct cookie_checker *);
+void cookie_checker_create_payload(struct cookie_checker *,
+ struct cookie_macs *cm, uint8_t[COOKIE_NONCE_SIZE],
+ uint8_t [COOKIE_ENCRYPTED_SIZE], struct sockaddr *);
+int cookie_maker_consume_payload(struct cookie_maker *,
+ uint8_t[COOKIE_NONCE_SIZE], uint8_t[COOKIE_ENCRYPTED_SIZE]);
+void cookie_maker_mac(struct cookie_maker *, struct cookie_macs *,
+ void *, size_t);
+int cookie_checker_validate_macs(struct cookie_checker *,
+ struct cookie_macs *, void *, size_t, int, struct sockaddr *);
+
+#endif /* __COOKIE_H__ */
diff --git a/src/wg_noise.c b/src/wg_noise.c
new file mode 100644
index 0000000..42dcc87
--- /dev/null
+++ b/src/wg_noise.c
@@ -0,0 +1,952 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ */
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/param.h>
+#include <sys/rwlock.h>
+
+#include "support.h"
+#include "wg_noise.h"
+
+/* Private functions */
+static struct noise_keypair *
+ noise_remote_keypair_allocate(struct noise_remote *);
+static void
+ noise_remote_keypair_free(struct noise_remote *,
+ struct noise_keypair *);
+static uint32_t noise_remote_handshake_index_get(struct noise_remote *);
+static void noise_remote_handshake_index_drop(struct noise_remote *);
+
+static uint64_t noise_counter_send(struct noise_counter *);
+static int noise_counter_recv(struct noise_counter *, uint64_t);
+
+static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *,
+ size_t, size_t, size_t, size_t,
+ const uint8_t [NOISE_HASH_LEN]);
+static int noise_mix_dh(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN]);
+static int noise_mix_ss(
+ uint8_t ck[NOISE_HASH_LEN],
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t ss[NOISE_PUBLIC_KEY_LEN]);
+static void noise_mix_hash(
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t *,
+ size_t);
+static void noise_mix_psk(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t [NOISE_SYMMETRIC_KEY_LEN]);
+static void noise_param_init(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t [NOISE_PUBLIC_KEY_LEN]);
+
+static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t,
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ uint8_t [NOISE_HASH_LEN]);
+static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t,
+ uint8_t [NOISE_SYMMETRIC_KEY_LEN],
+ uint8_t [NOISE_HASH_LEN]);
+static void noise_msg_ephemeral(
+ uint8_t [NOISE_HASH_LEN],
+ uint8_t [NOISE_HASH_LEN],
+ const uint8_t src[NOISE_PUBLIC_KEY_LEN]);
+
+static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]);
+static int noise_timer_expired(struct timespec *, time_t, long);
+
+/* Set/Get noise parameters */
+void
+noise_local_init(struct noise_local *l, struct noise_upcall *upcall)
+{
+ bzero(l, sizeof(*l));
+ rw_init(&l->l_identity_lock, "noise_local_identity");
+ l->l_upcall = *upcall;
+}
+
+void
+noise_local_lock_identity(struct noise_local *l)
+{
+ rw_enter_write(&l->l_identity_lock);
+}
+
+void
+noise_local_unlock_identity(struct noise_local *l)
+{
+ rw_exit_write(&l->l_identity_lock);
+}
+
+int
+noise_local_set_private(struct noise_local *l,
+ const uint8_t private[NOISE_PUBLIC_KEY_LEN])
+{
+ rw_assert_wrlock(&l->l_identity_lock);
+
+ memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN);
+ curve25519_clamp_secret(l->l_private);
+ l->l_has_identity = curve25519_generate_public(l->l_public, private);
+
+ return l->l_has_identity ? 0 : ENXIO;
+}
+
+int
+noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN],
+ uint8_t private[NOISE_PUBLIC_KEY_LEN])
+{
+ int ret = 0;
+ rw_enter_read(&l->l_identity_lock);
+ if (l->l_has_identity) {
+ if (public != NULL)
+ memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN);
+ if (private != NULL)
+ memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN);
+ } else {
+ ret = ENXIO;
+ }
+ rw_exit_read(&l->l_identity_lock);
+ return ret;
+}
+
+void
+noise_remote_init(struct noise_remote *r,
+ const uint8_t public[NOISE_PUBLIC_KEY_LEN], struct noise_local *l)
+{
+ bzero(r, sizeof(*r));
+ memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN);
+ rw_init(&r->r_handshake_lock, "noise_handshake");
+ rw_init(&r->r_keypair_lock, "noise_keypair");
+
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[0], kp_entry);
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[1], kp_entry);
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[2], kp_entry);
+
+ KASSERT(l != NULL, ("must provide local"));
+ r->r_local = l;
+
+ rw_enter_write(&l->l_identity_lock);
+ noise_remote_precompute(r);
+ rw_exit_write(&l->l_identity_lock);
+}
+
+int
+noise_remote_set_psk(struct noise_remote *r,
+ const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ int same;
+ rw_enter_write(&r->r_handshake_lock);
+ same = !timingsafe_bcmp(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
+ if (!same) {
+ memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN);
+ }
+ rw_exit_write(&r->r_handshake_lock);
+ return same ? EEXIST : 0;
+}
+
+int
+noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN],
+ uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN];
+ int ret;
+
+ if (public != NULL)
+ memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN);
+
+ rw_enter_read(&r->r_handshake_lock);
+ if (psk != NULL)
+ memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
+ ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN);
+ rw_exit_read(&r->r_handshake_lock);
+
+ /* If r_psk != null_psk, return 0; otherwise ENOENT (no PSK set). */
+ return ret ? 0 : ENOENT;
+}
+
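+/* r_ss caches the static-static term DH(l_private, r_public) used by the
+ * "ss" token; it changes only when either static key changes, so it is
+ * computed once here rather than on every handshake. */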
+void
+noise_remote_precompute(struct noise_remote *r)
+{
+ struct noise_local *l = r->r_local;
+ rw_assert_wrlock(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
+ else if (!curve25519(r->r_ss, l->l_private, r->r_public))
+ bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN);
+
+ rw_enter_write(&r->r_handshake_lock);
+ noise_remote_handshake_index_drop(r);
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+}
+
+/* Handshake functions */
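+/* The single-letter comments in the functions below (e, es, s, ss, ee, se,
+ * psk, {t}, {}) are the Noise_IKpsk2 message-pattern tokens of the
+ * handshake named by NOISE_HANDSHAKE_NAME; each step mixes the
+ * corresponding value into the chaining key and/or transcript hash. */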
+int
+noise_create_initiation(struct noise_remote *r, uint32_t *s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_local *l = r->r_local;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ rw_enter_write(&r->r_handshake_lock);
+ if (!l->l_has_identity)
+ goto error;
+ noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public);
+
+ /* e */
+ curve25519_generate_secret(hs->hs_e);
+ if (curve25519_generate_public(ue, hs->hs_e) == 0)
+ goto error;
+ noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
+
+ /* es */
+ if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0)
+ goto error;
+
+ /* s */
+ noise_msg_encrypt(es, l->l_public,
+ NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash);
+
+ /* ss */
+ if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0)
+ goto error;
+
+ /* {t} */
+ noise_tai64n_now(ets);
+ noise_msg_encrypt(ets, ets,
+ NOISE_TIMESTAMP_LEN, key, hs->hs_hash);
+
+ noise_remote_handshake_index_drop(r);
+ hs->hs_state = CREATED_INITIATION;
+ hs->hs_local_index = noise_remote_handshake_index_get(r);
+ *s_idx = hs->hs_local_index;
+ ret = 0;
+error:
+ rw_exit_write(&r->r_handshake_lock);
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ return ret;
+}
+
+int
+noise_consume_initiation(struct noise_local *l, struct noise_remote **rp,
+ uint32_t s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN])
+{
+ struct noise_remote *r;
+ struct noise_handshake hs;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
+ uint8_t timestamp[NOISE_TIMESTAMP_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ goto error;
+ noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public);
+
+ /* e */
+ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
+
+ /* es */
+ if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0)
+ goto error;
+
+ /* s */
+ if (noise_msg_decrypt(r_public, es,
+ NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ /* Lookup the remote we received from */
+ if ((r = l->l_upcall.u_remote_get(l->l_upcall.u_arg, r_public)) == NULL)
+ goto error;
+
+ /* ss */
+ if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0)
+ goto error;
+
+ /* {t} */
+ if (noise_msg_decrypt(timestamp, ets,
+ NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ hs.hs_state = CONSUMED_INITIATION;
+ hs.hs_local_index = 0;
+ hs.hs_remote_index = s_idx;
+ memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN);
+
+ /* We have successfully computed the same results; now ensure that
+ * this is not an initiation replay or a flood attack. */
+ rw_enter_write(&r->r_handshake_lock);
+
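+ /* The timestamp is TAI64N and therefore big-endian, so plain memcmp
+ * ordering matches chronological ordering: anything not strictly
+ * newer than r_timestamp is rejected as a replay. */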
+ /* Replay */
+ if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0)
+ memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
+ else
+ goto error_set;
+ /* Flood attack */
+ if (noise_timer_expired(&r->r_last_init, 0, REJECT_INTERVAL))
+ getnanouptime(&r->r_last_init);
+ else
+ goto error_set;
+
+ /* Ok, we're happy to accept this initiation now */
+ noise_remote_handshake_index_drop(r);
+ r->r_handshake = hs;
+ *rp = r;
+ ret = 0;
+error_set:
+ rw_exit_write(&r->r_handshake_lock);
+error:
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(&hs, sizeof(hs));
+ return ret;
+}
+
+int
+noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t e[NOISE_PUBLIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&r->r_local->l_identity_lock);
+ rw_enter_write(&r->r_handshake_lock);
+
+ if (hs->hs_state != CONSUMED_INITIATION)
+ goto error;
+
+ /* e */
+ curve25519_generate_secret(e);
+ if (curve25519_generate_public(ue, e) == 0)
+ goto error;
+ noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
+
+ /* ee */
+ if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0)
+ goto error;
+
+ /* se */
+ if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0)
+ goto error;
+
+ /* psk */
+ noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk);
+
+ /* {} */
+ noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash);
+
+ hs->hs_state = CREATED_RESPONSE;
+ hs->hs_local_index = noise_remote_handshake_index_get(r);
+ *r_idx = hs->hs_remote_index;
+ *s_idx = hs->hs_local_index;
+ ret = 0;
+error:
+ rw_exit_write(&r->r_handshake_lock);
+ rw_exit_read(&r->r_local->l_identity_lock);
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(e, NOISE_PUBLIC_KEY_LEN);
+ return ret;
+}
+
+int
+noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+ struct noise_local *l = r->r_local;
+ struct noise_handshake hs;
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
+ int ret = EINVAL;
+
+ rw_enter_read(&l->l_identity_lock);
+ if (!l->l_has_identity)
+ goto error;
+
+ rw_enter_read(&r->r_handshake_lock);
+ hs = r->r_handshake;
+ memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
+ rw_exit_read(&r->r_handshake_lock);
+
+ if (hs.hs_state != CREATED_INITIATION ||
+ hs.hs_local_index != r_idx)
+ goto error;
+
+ /* e */
+ noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
+
+ /* ee */
+ if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0)
+ goto error;
+
+ /* se */
+ if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0)
+ goto error;
+
+ /* psk */
+ noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key);
+
+ /* {} */
+ if (noise_msg_decrypt(NULL, en,
+ 0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+ goto error;
+
+ hs.hs_remote_index = s_idx;
+
+ rw_enter_write(&r->r_handshake_lock);
+ if (r->r_handshake.hs_state == hs.hs_state &&
+ r->r_handshake.hs_local_index == hs.hs_local_index) {
+ r->r_handshake = hs;
+ r->r_handshake.hs_state = CONSUMED_RESPONSE;
+ ret = 0;
+ }
+ rw_exit_write(&r->r_handshake_lock);
+error:
+ rw_exit_read(&l->l_identity_lock);
+ explicit_bzero(&hs, sizeof(hs));
+ explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+ explicit_bzero(preshared_key, NOISE_SYMMETRIC_KEY_LEN);
+ return ret;
+}
+
+int
+noise_remote_begin_session(struct noise_remote *r)
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_keypair kp, *next, *current, *previous;
+
+ rw_enter_write(&r->r_handshake_lock);
+
+ /* We now derive the keypair from the handshake */
+ if (hs->hs_state == CONSUMED_RESPONSE) {
+ kp.kp_is_initiator = 1;
+ noise_kdf(kp.kp_send, kp.kp_recv, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ hs->hs_ck);
+ } else if (hs->hs_state == CREATED_RESPONSE) {
+ kp.kp_is_initiator = 0;
+ noise_kdf(kp.kp_recv, kp.kp_send, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ hs->hs_ck);
+ } else {
+ rw_exit_write(&r->r_handshake_lock);
+ return EINVAL;
+ }
+
+ kp.kp_valid = 1;
+ kp.kp_local_index = hs->hs_local_index;
+ kp.kp_remote_index = hs->hs_remote_index;
+ getnanouptime(&kp.kp_birthdate);
+ bzero(&kp.kp_ctr, sizeof(kp.kp_ctr));
+ rw_init(&kp.kp_ctr.c_lock, "noise_counter");
+
+ /* Now we need to add_new_keypair */
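+ /* Rotation rules, summarizing the branches below: the initiator's
+ * fresh keypair becomes current immediately, as the initiator may
+ * send data right away; the responder parks its keypair in next
+ * until a data packet from the initiator confirms the session, at
+ * which point noise_remote_decrypt() promotes next to current. */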
+ rw_enter_write(&r->r_keypair_lock);
+ next = r->r_next;
+ current = r->r_current;
+ previous = r->r_previous;
+
+ if (kp.kp_is_initiator) {
+ if (next != NULL) {
+ r->r_next = NULL;
+ r->r_previous = next;
+ noise_remote_keypair_free(r, current);
+ } else {
+ r->r_previous = current;
+ }
+
+ noise_remote_keypair_free(r, previous);
+
+ r->r_current = noise_remote_keypair_allocate(r);
+ *r->r_current = kp;
+ } else {
+ noise_remote_keypair_free(r, next);
+ r->r_previous = NULL;
+ noise_remote_keypair_free(r, previous);
+
+ r->r_next = noise_remote_keypair_allocate(r);
+ *r->r_next = kp;
+ }
+ rw_exit_write(&r->r_keypair_lock);
+
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+
+ explicit_bzero(&kp, sizeof(kp));
+ return 0;
+}
+
+void
+noise_remote_clear(struct noise_remote *r)
+{
+ rw_enter_write(&r->r_handshake_lock);
+ noise_remote_handshake_index_drop(r);
+ explicit_bzero(&r->r_handshake, sizeof(r->r_handshake));
+ rw_exit_write(&r->r_handshake_lock);
+
+ rw_enter_write(&r->r_keypair_lock);
+ noise_remote_keypair_free(r, r->r_next);
+ noise_remote_keypair_free(r, r->r_current);
+ noise_remote_keypair_free(r, r->r_previous);
+ r->r_next = NULL;
+ r->r_current = NULL;
+ r->r_previous = NULL;
+ rw_exit_write(&r->r_keypair_lock);
+}
+
+void
+noise_remote_expire_current(struct noise_remote *r)
+{
+ rw_enter_write(&r->r_keypair_lock);
+ if (r->r_next != NULL)
+ r->r_next->kp_valid = 0;
+ if (r->r_current != NULL)
+ r->r_current->kp_valid = 0;
+ rw_exit_write(&r->r_keypair_lock);
+}
+
+int
+noise_remote_ready(struct noise_remote *r)
+{
+ struct noise_keypair *kp;
+ int ret;
+
+ rw_enter_read(&r->r_keypair_lock);
+ /* kp_ctr isn't locked here; we're happy to accept a racy read. */
+ if ((kp = r->r_current) == NULL ||
+ !kp->kp_valid ||
+ noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
+ kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES)
+ ret = EINVAL;
+ else
+ ret = 0;
+ rw_exit_read(&r->r_keypair_lock);
+ return ret;
+}
+
+int
+noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce,
+ uint8_t *buf, size_t buflen)
+{
+ struct noise_keypair *kp;
+ int ret = EINVAL;
+
+ rw_enter_read(&r->r_keypair_lock);
+ if ((kp = r->r_current) == NULL)
+ goto error;
+
+ /* We confirm that our values are within our tolerances. We want:
+ * - a valid keypair
+ * - our keypair to be less than REJECT_AFTER_TIME seconds old
+ * - our receive counter to be less than REJECT_AFTER_MESSAGES
+ * - our send counter to be less than REJECT_AFTER_MESSAGES
+ *
+ * kp_ctr isn't locked here; we're happy to accept a racy read. */
+ if (!kp->kp_valid ||
+ noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES ||
+ ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES))
+ goto error;
+
+ /* We encrypt into the same buffer, so the caller must ensure that buf
+ * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index
+ * are passed back out to the caller through the provided pointers. */
+ *r_idx = kp->kp_remote_index;
+ chacha20poly1305_encrypt(buf, buf, buflen,
+ NULL, 0, *nonce, kp->kp_send);
+
+ /* If our values are still within tolerances, but we are approaching
+ * the tolerances, we notify the caller with ESTALE that they should
+ * establish a new keypair. The current keypair can continue to be used
+ * until the tolerances are hit. We notify if:
+ * - our send counter is valid and not less than REKEY_AFTER_MESSAGES
+ * - we're the initiator and our keypair is older than
+ * REKEY_AFTER_TIME seconds */
+ ret = ESTALE;
+ if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) ||
+ (kp->kp_is_initiator &&
+ noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0)))
+ goto error;
+
+ ret = 0;
+error:
+ rw_exit_read(&r->r_keypair_lock);
+ return ret;
+}
+
+int
+noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce,
+ uint8_t *buf, size_t buflen)
+{
+ struct noise_keypair *kp;
+ int ret = EINVAL;
+
+ /* We retrieve the keypair corresponding to the provided index. We
+ * attempt the current keypair first as that is most likely. We also
+ * want to make sure that the keypair is valid as it would be
+ * catastrophic to decrypt against a zero'ed keypair. */
+ rw_enter_read(&r->r_keypair_lock);
+
+ if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) {
+ kp = r->r_current;
+ } else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) {
+ kp = r->r_previous;
+ } else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) {
+ kp = r->r_next;
+ } else {
+ goto error;
+ }
+
+ /* We confirm that our values are within our tolerances. These values
+ * are the same as the encrypt routine.
+ *
+ * kp_ctr isn't locked here; we're happy to accept a racy read. */
+ if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) ||
+ kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES)
+ goto error;
+
+ /* Decrypt, then validate the counter. We don't want to validate the
+ * counter before decrypting as we do not know the message is authentic
+ * prior to decryption. */
+ if (chacha20poly1305_decrypt(buf, buf, buflen,
+ NULL, 0, nonce, kp->kp_recv) == 0)
+ goto error;
+
+ if (noise_counter_recv(&kp->kp_ctr, nonce) != 0)
+ goto error;
+
+ /* If we've received the handshake confirming data packet then move the
+ * next keypair into current. If we do slide the next keypair in, then
+ * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a
+ * data packet can't confirm a session that we are an INITIATOR of. */
+ if (kp == r->r_next) {
+ rw_exit_read(&r->r_keypair_lock);
+ rw_enter_write(&r->r_keypair_lock);
+ if (kp == r->r_next && kp->kp_local_index == r_idx) {
+ noise_remote_keypair_free(r, r->r_previous);
+ r->r_previous = r->r_current;
+ r->r_current = r->r_next;
+ r->r_next = NULL;
+
+ ret = ECONNRESET;
+ goto error;
+ }
+ rw_enter(&r->r_keypair_lock, RW_DOWNGRADE);
+ }
+
+ /* Similar to when we encrypt, we want to notify the caller when we
+ * are approaching our tolerances. We notify if:
+ * - we're the initiator and the current keypair is older than
+ * REKEY_AFTER_TIME_RECV seconds. */
+ ret = ESTALE;
+ kp = r->r_current;
+ if (kp != NULL &&
+ kp->kp_valid &&
+ kp->kp_is_initiator &&
+ noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME_RECV, 0))
+ goto error;
+
+ ret = 0;
+
+error:
+ rw_exit(&r->r_keypair_lock);
+ return ret;
+}
+
+/* Private functions - these should not be called outside this file under any
+ * circumstances. */
+static struct noise_keypair *
+noise_remote_keypair_allocate(struct noise_remote *r)
+{
+ struct noise_keypair *kp;
+ kp = SLIST_FIRST(&r->r_unused_keypairs);
+ SLIST_REMOVE_HEAD(&r->r_unused_keypairs, kp_entry);
+ return kp;
+}
+
+static void
+noise_remote_keypair_free(struct noise_remote *r, struct noise_keypair *kp)
+{
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ if (kp != NULL) {
+ SLIST_INSERT_HEAD(&r->r_unused_keypairs, kp, kp_entry);
+ u->u_index_drop(u->u_arg, kp->kp_local_index);
+ bzero(kp->kp_send, sizeof(kp->kp_send));
+ bzero(kp->kp_recv, sizeof(kp->kp_recv));
+ }
+}
+
+static uint32_t
+noise_remote_handshake_index_get(struct noise_remote *r)
+{
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ return u->u_index_set(u->u_arg, r);
+}
+
+static void
+noise_remote_handshake_index_drop(struct noise_remote *r)
+{
+ struct noise_handshake *hs = &r->r_handshake;
+ struct noise_upcall *u = &r->r_local->l_upcall;
+ rw_assert_wrlock(&r->r_handshake_lock);
+ if (hs->hs_state != HS_ZEROED)
+ u->u_index_drop(u->u_arg, hs->hs_local_index);
+}
+
+static uint64_t
+noise_counter_send(struct noise_counter *ctr)
+{
+ uint64_t ret;
+ rw_enter_write(&ctr->c_lock);
+ ret = ctr->c_send++;
+ rw_exit_write(&ctr->c_lock);
+ return ret;
+}
+
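+/* Worked example of the sliding-window logic below, assuming a 64-bit
+ * unsigned long (COUNTER_BITS == 64, COUNTER_NUM == 128): with
+ * c_recv == 70 and recv == 130, index_ctr == 1 and index_recv == 2, so
+ * slot (1 + 1) & 127 == 2 is zeroed and c_recv becomes 130; then bit
+ * 1ul << (130 % 64) == 1 << 2 is set in slot 2. A duplicate recv == 130
+ * later finds that bit already set and is rejected with EEXIST. */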
+static int
+noise_counter_recv(struct noise_counter *ctr, uint64_t recv)
+{
+ uint64_t i, top, index_recv, index_ctr;
+ unsigned long bit;
+ int ret = EEXIST;
+
+ rw_enter_write(&ctr->c_lock);
+
+ /* Check that the recv counter is valid */
+ if (ctr->c_recv >= REJECT_AFTER_MESSAGES ||
+ recv >= REJECT_AFTER_MESSAGES)
+ goto error;
+
+ /* If the packet is out of the window, invalid */
+ if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv)
+ goto error;
+
+ /* If the new counter is ahead of the current counter, we'll need to
+ * zero out the bitmap that has previously been used */
+ index_recv = recv / COUNTER_BITS;
+ index_ctr = ctr->c_recv / COUNTER_BITS;
+
+ if (recv > ctr->c_recv) {
+ top = MIN(index_recv - index_ctr, COUNTER_NUM);
+ for (i = 1; i <= top; i++)
+ ctr->c_backtrack[
+ (i + index_ctr) & (COUNTER_NUM - 1)] = 0;
+ ctr->c_recv = recv;
+ }
+
+ index_recv %= COUNTER_NUM;
+ bit = 1ul << (recv % COUNTER_BITS);
+
+ if (ctr->c_backtrack[index_recv] & bit)
+ goto error;
+
+ ctr->c_backtrack[index_recv] |= bit;
+
+ ret = 0;
+error:
+ rw_exit_write(&ctr->c_lock);
+ return ret;
+}
+
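+/* This is HKDF (RFC 5869) instantiated with HMAC-BLAKE2s: extract
+ * sec = HMAC(ck, x), then expand up to three keys as a = HMAC(sec, 0x1),
+ * b = HMAC(sec, a || 0x2), c = HMAC(sec, b || 0x3). */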
+static void
+noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x,
+ size_t a_len, size_t b_len, size_t c_len, size_t x_len,
+ const uint8_t ck[NOISE_HASH_LEN])
+{
+ uint8_t out[BLAKE2S_HASH_SIZE + 1];
+ uint8_t sec[BLAKE2S_HASH_SIZE];
+
+#ifdef DIAGNOSTIC
+ MPASS(a_len <= BLAKE2S_HASH_SIZE && b_len <= BLAKE2S_HASH_SIZE &&
+ c_len <= BLAKE2S_HASH_SIZE);
+ MPASS(!(b || b_len || c || c_len) || (a && a_len));
+ MPASS(!(c || c_len) || (b && b_len));
+#endif
+
+ /* Extract entropy from "x" into sec */
+ blake2s_hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN);
+
+ if (a == NULL || a_len == 0)
+ goto out;
+
+ /* Expand first key: key = sec, data = 0x1 */
+ out[0] = 1;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE);
+ memcpy(a, out, a_len);
+
+ if (b == NULL || b_len == 0)
+ goto out;
+
+ /* Expand second key: key = sec, data = "a" || 0x2 */
+ out[BLAKE2S_HASH_SIZE] = 2;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1,
+ BLAKE2S_HASH_SIZE);
+ memcpy(b, out, b_len);
+
+ if (c == NULL || c_len == 0)
+ goto out;
+
+ /* Expand third key: key = sec, data = "b" || 0x3 */
+ out[BLAKE2S_HASH_SIZE] = 3;
+ blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1,
+ BLAKE2S_HASH_SIZE);
+ memcpy(c, out, c_len);
+
+out:
+ /* Clear sensitive data from stack */
+ explicit_bzero(sec, BLAKE2S_HASH_SIZE);
+ explicit_bzero(out, BLAKE2S_HASH_SIZE + 1);
+}
+
+static int
+noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t private[NOISE_PUBLIC_KEY_LEN],
+ const uint8_t public[NOISE_PUBLIC_KEY_LEN])
+{
+ uint8_t dh[NOISE_PUBLIC_KEY_LEN];
+
+ if (!curve25519(dh, private, public))
+ return EINVAL;
+ noise_kdf(ck, key, NULL, dh,
+ NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
+ explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN);
+ return 0;
+}
+
+static int
+noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t ss[NOISE_PUBLIC_KEY_LEN])
+{
+ static uint8_t null_point[NOISE_PUBLIC_KEY_LEN];
+ if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0)
+ return ENOENT;
+ noise_kdf(ck, key, NULL, ss,
+ NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck);
+ return 0;
+}
+
+static void
+noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src,
+ size_t src_len)
+{
+ struct blake2s_state blake;
+
+ blake2s_init(&blake, NOISE_HASH_LEN);
+ blake2s_update(&blake, hash, NOISE_HASH_LEN);
+ blake2s_update(&blake, src, src_len);
+ blake2s_final(&blake, hash);
+}
+
+static void
+noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN],
+ const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN])
+{
+ uint8_t tmp[NOISE_HASH_LEN];
+
+ noise_kdf(ck, tmp, key, psk,
+ NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, ck);
+ noise_mix_hash(hash, tmp, NOISE_HASH_LEN);
+ explicit_bzero(tmp, NOISE_HASH_LEN);
+}
+
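+/* Noise symmetric-state initialization: ck = HASH(handshake name),
+ * hash = HASH(ck || identifier), then the responder's static public key is
+ * mixed into the hash per the IK pre-message pattern. */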
+static void
+noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ const uint8_t s[NOISE_PUBLIC_KEY_LEN])
+{
+ struct blake2s_state blake;
+
+ blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL,
+ NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0);
+ blake2s_init(&blake, NOISE_HASH_LEN);
+ blake2s_update(&blake, ck, NOISE_HASH_LEN);
+ blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME,
+ strlen(NOISE_IDENTIFIER_NAME));
+ blake2s_final(&blake, hash);
+
+ noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN);
+}
+
+static void
+noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
+{
+ /* Nonce always zero for Noise_IK */
+ chacha20poly1305_encrypt(dst, src, src_len,
+ hash, NOISE_HASH_LEN, 0, key);
+ noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN);
+}
+
+static int
+noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len,
+ uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN])
+{
+ /* Nonce always zero for Noise_IK */
+ if (!chacha20poly1305_decrypt(dst, src, src_len,
+ hash, NOISE_HASH_LEN, 0, key))
+ return EINVAL;
+ noise_mix_hash(hash, src, src_len);
+ return 0;
+}
+
+static void
+noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN],
+ const uint8_t src[NOISE_PUBLIC_KEY_LEN])
+{
+ noise_mix_hash(hash, src, NOISE_PUBLIC_KEY_LEN);
+ noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0,
+ NOISE_PUBLIC_KEY_LEN, ck);
+}
+
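+/* TAI64N external format: 12 bytes, an 8-byte big-endian second count
+ * labeled with 2^62 (0x4000000000000000), followed by a 4-byte big-endian
+ * nanosecond count; the extra 10 here presumably covers the 1970 TAI-UTC
+ * offset. */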
+static void
+noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN])
+{
+ struct timespec time;
+ uint64_t sec;
+ uint32_t nsec;
+
+ getnanotime(&time);
+
+ /* Round down the nsec counter to avoid leaking precise timing. */
+ time.tv_nsec &= REJECT_INTERVAL_MASK;
+
+ /* https://cr.yp.to/libtai/tai64.html */
+ sec = htobe64(0x400000000000000aULL + time.tv_sec);
+ nsec = htobe32(time.tv_nsec);
+
+ /* memcpy to output buffer, assuming output could be unaligned. */
+ memcpy(output, &sec, sizeof(sec));
+ memcpy(output + sizeof(sec), &nsec, sizeof(nsec));
+}
+
+static int
+noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec)
+{
+ struct timespec uptime;
+ struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec };
+
+ /* We don't really worry about a zeroed birthdate, to avoid the extra
+ * check on every encrypt/decrypt. This does mean the r_last_init
+ * check may fail if getnanouptime is < REJECT_INTERVAL from 0. */
+
+ getnanouptime(&uptime);
+ timespecadd(birthdate, &expire, &expire);
+ return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0;
+}
diff --git a/src/wg_noise.h b/src/wg_noise.h
new file mode 100644
index 0000000..198ddd1
--- /dev/null
+++ b/src/wg_noise.h
@@ -0,0 +1,180 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie <ncon@noconroy.net>
+ */
+
+#ifndef __NOISE_H__
+#define __NOISE_H__
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/rwlock.h>
+
+#include "crypto.h"
+
+#define NOISE_PUBLIC_KEY_LEN CURVE25519_KEY_SIZE
+#define NOISE_SYMMETRIC_KEY_LEN CHACHA20POLY1305_KEY_SIZE
+#define NOISE_TIMESTAMP_LEN (sizeof(uint64_t) + sizeof(uint32_t))
+#define NOISE_AUTHTAG_LEN CHACHA20POLY1305_AUTHTAG_SIZE
+#define NOISE_HASH_LEN BLAKE2S_HASH_SIZE
+
+/* Protocol string constants */
+#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
+#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com"
+
+/* Constants for the counter */
+#define COUNTER_BITS_TOTAL 8192
+#define COUNTER_BITS (sizeof(unsigned long) * 8)
+#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS)
+#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS)
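+/* With 64-bit longs this gives COUNTER_NUM == 128 bitmap slots and a
+ * replay window of COUNTER_WINDOW_SIZE == 8128 packets behind the highest
+ * counter received. */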
+
+/* Constants for the keypair */
+#define REKEY_AFTER_MESSAGES (1ull << 60)
+#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1)
+#define REKEY_AFTER_TIME 120
+#define REKEY_AFTER_TIME_RECV 165
+#define REJECT_AFTER_TIME 180
+#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */
+/* 24 = floor(log2(REJECT_INTERVAL)) */
+#define REJECT_INTERVAL_MASK (~((1ull<<24)-1))
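+/* Masking tv_nsec with this rounds it down to a multiple of 2^24 ns
+ * (~16.8 ms), so an initiation timestamp reveals timing no finer than the
+ * REJECT_INTERVAL pacing itself. */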
+
+enum noise_state_hs {
+ HS_ZEROED = 0,
+ CREATED_INITIATION,
+ CONSUMED_INITIATION,
+ CREATED_RESPONSE,
+ CONSUMED_RESPONSE,
+};
+
+struct noise_handshake {
+ enum noise_state_hs hs_state;
+ uint32_t hs_local_index;
+ uint32_t hs_remote_index;
+ uint8_t hs_e[NOISE_PUBLIC_KEY_LEN];
+ uint8_t hs_hash[NOISE_HASH_LEN];
+ uint8_t hs_ck[NOISE_HASH_LEN];
+};
+
+struct noise_counter {
+ struct rwlock c_lock;
+ uint64_t c_send;
+ uint64_t c_recv;
+ unsigned long c_backtrack[COUNTER_NUM];
+};
+
+struct noise_keypair {
+ SLIST_ENTRY(noise_keypair) kp_entry;
+ int kp_valid;
+ int kp_is_initiator;
+ uint32_t kp_local_index;
+ uint32_t kp_remote_index;
+ uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN];
+ struct timespec kp_birthdate; /* nanouptime */
+ struct noise_counter kp_ctr;
+};
+
+struct noise_remote {
+ uint8_t r_public[NOISE_PUBLIC_KEY_LEN];
+ struct noise_local *r_local;
+ uint8_t r_ss[NOISE_PUBLIC_KEY_LEN];
+
+ struct rwlock r_handshake_lock;
+ struct noise_handshake r_handshake;
+ uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN];
+ uint8_t r_timestamp[NOISE_TIMESTAMP_LEN];
+ struct timespec r_last_init; /* nanouptime */
+
+ struct rwlock r_keypair_lock;
+ SLIST_HEAD(,noise_keypair) r_unused_keypairs;
+ struct noise_keypair *r_next, *r_current, *r_previous;
+ struct noise_keypair r_keypair[3]; /* 3: next, current, previous. */
+};
+
+struct noise_local {
+ struct rwlock l_identity_lock;
+ int l_has_identity;
+ uint8_t l_public[NOISE_PUBLIC_KEY_LEN];
+ uint8_t l_private[NOISE_PUBLIC_KEY_LEN];
+
+ struct noise_upcall {
+ void *u_arg;
+ struct noise_remote *
+ (*u_remote_get)(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]);
+ uint32_t
+ (*u_index_set)(void *, struct noise_remote *);
+ void (*u_index_drop)(void *, uint32_t);
+ } l_upcall;
+};
+
+/* Set/Get noise parameters */
+void noise_local_init(struct noise_local *, struct noise_upcall *);
+void noise_local_lock_identity(struct noise_local *);
+void noise_local_unlock_identity(struct noise_local *);
+int noise_local_set_private(struct noise_local *,
+ const uint8_t[NOISE_PUBLIC_KEY_LEN]);
+int noise_local_keys(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN],
+ uint8_t[NOISE_PUBLIC_KEY_LEN]);
+
+void noise_remote_init(struct noise_remote *,
+ const uint8_t[NOISE_PUBLIC_KEY_LEN], struct noise_local *);
+int noise_remote_set_psk(struct noise_remote *,
+ const uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+int noise_remote_keys(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN],
+ uint8_t[NOISE_SYMMETRIC_KEY_LEN]);
+
+/* Should be called anytime noise_local_set_private is called */
+void noise_remote_precompute(struct noise_remote *);
+
+/* Cryptographic functions */
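+/* Typical flow, as a sketch (the if_wg.c glue drives this): the initiator
+ * calls noise_create_initiation(); the responder, noise_consume_initiation()
+ * then noise_create_response(); the initiator, noise_consume_response();
+ * then each side calls noise_remote_begin_session() to derive transport
+ * keys for noise_remote_encrypt()/noise_remote_decrypt(). */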
+int noise_create_initiation(
+ struct noise_remote *,
+ uint32_t *s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]);
+
+int noise_consume_initiation(
+ struct noise_local *,
+ struct noise_remote **,
+ uint32_t s_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN],
+ uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]);
+
+int noise_create_response(
+ struct noise_remote *,
+ uint32_t *s_idx,
+ uint32_t *r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t en[0 + NOISE_AUTHTAG_LEN]);
+
+int noise_consume_response(
+ struct noise_remote *,
+ uint32_t s_idx,
+ uint32_t r_idx,
+ uint8_t ue[NOISE_PUBLIC_KEY_LEN],
+ uint8_t en[0 + NOISE_AUTHTAG_LEN]);
+
+int noise_remote_begin_session(struct noise_remote *);
+void noise_remote_clear(struct noise_remote *);
+void noise_remote_expire_current(struct noise_remote *);
+
+int noise_remote_ready(struct noise_remote *);
+
+int noise_remote_encrypt(
+ struct noise_remote *,
+ uint32_t *r_idx,
+ uint64_t *nonce,
+ uint8_t *buf,
+ size_t buflen);
+int noise_remote_decrypt(
+ struct noise_remote *,
+ uint32_t r_idx,
+ uint64_t nonce,
+ uint8_t *buf,
+ size_t buflen);
+
+#endif /* __NOISE_H__ */