From 362884e65029464d97e50c9b660b5b90621e239e Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Wed, 17 Mar 2021 09:34:21 -0600
Subject: Initial import

There's still more to do with wiring this up properly.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
---
 COPYING             |   17 +
 README.md           |    9 +
 src/Makefile        |   11 +
 src/crypto.c        | 1694 +++++++++++++++++++++++++
 src/crypto.h        |  103 ++
 src/if_wg.c         | 3451 +++++++++++++++++++++++++++++++++++++++++++++++++++
 src/if_wg.h         |   37 +
 src/support.h       |   56 +
 src/wg_cookie.c     |  427 +++++++
 src/wg_cookie.h     |  114 ++
 src/wg_noise.c      |  952 ++++++++++++++
 src/wg_noise.h      |  180 +++
 tests/if_wg_test.sh |  164 +++
 tests/netns.sh      |  643 ++++++++++
 14 files changed, 7858 insertions(+)
 create mode 100644 COPYING
 create mode 100644 README.md
 create mode 100644 src/Makefile
 create mode 100644 src/crypto.c
 create mode 100644 src/crypto.h
 create mode 100644 src/if_wg.c
 create mode 100644 src/if_wg.h
 create mode 100644 src/support.h
 create mode 100644 src/wg_cookie.c
 create mode 100644 src/wg_cookie.h
 create mode 100644 src/wg_noise.c
 create mode 100644 src/wg_noise.h
 create mode 100755 tests/if_wg_test.sh
 create mode 100755 tests/netns.sh

diff --git a/COPYING b/COPYING
new file mode 100644
index 0000000..f85e365
--- /dev/null
+++ b/COPYING
@@ -0,0 +1,17 @@
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..82860e6
--- /dev/null
+++ b/README.md
@@ -0,0 +1,9 @@
+# WireGuard for FreeBSD
+
+This is a kernel module for FreeBSD to support [WireGuard](https://www.wireguard.com/). It is being developed here before its eventual submission to FreeBSD 13.1 or 14.
+
+### Installation instructions
+
+```
+TODO
+```
diff --git a/src/Makefile b/src/Makefile
new file mode 100644
index 0000000..d65cf80
--- /dev/null
+++ b/src/Makefile
@@ -0,0 +1,11 @@
+# $FreeBSD$
+
+KMOD= if_wg
+
+.PATH: ${SRCTOP}/sys/dev/if_wg
+
+SRCS= opt_inet.h opt_inet6.h device_if.h bus_if.h ifdi_if.h
+
+SRCS+= if_wg.c wg_noise.c wg_cookie.c crypto.c
+
+.include <bsd.kmod.mk>
diff --git a/src/crypto.c b/src/crypto.c
new file mode 100644
index 0000000..97bef62
--- /dev/null
+++ b/src/crypto.c
@@ -0,0 +1,1694 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ */
+
+#include <sys/types.h>
+#include <sys/systm.h>
+#include <sys/endian.h>
+
+#include "crypto.h"
+
+#ifndef ARRAY_SIZE
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
+#endif
+#ifndef noinline
+#define noinline __attribute__((noinline))
+#endif
+#ifndef __aligned
+#define __aligned(x) __attribute__((aligned(x)))
+#endif
+#ifndef DIV_ROUND_UP
+#define DIV_ROUND_UP(n,d) (((n) + (d) - 1) / (d))
+#endif
+
+#define le32_to_cpup(a) le32toh(*(a))
+#define le64_to_cpup(a) le64toh(*(a))
+#define cpu_to_le32(a) htole32(a)
+#define cpu_to_le64(a) htole64(a)
+
+static inline uint32_t get_unaligned_le32(const uint8_t *a)
+{
+	uint32_t l;
+	__builtin_memcpy(&l, a, sizeof(l));
+	return le32_to_cpup(&l);
+}
+static inline uint64_t get_unaligned_le64(const uint8_t *a)
+{
+	uint64_t l;
+	__builtin_memcpy(&l, a, sizeof(l));
+	return le64_to_cpup(&l);
+}
+static inline void put_unaligned_le32(uint32_t s, uint8_t *d)
+{
+	uint32_t l = cpu_to_le32(s);
+	__builtin_memcpy(d, &l, sizeof(l));
+}
+static inline void cpu_to_le32_array(uint32_t *buf, unsigned int words)
+{
+	while (words--) {
+		*buf = cpu_to_le32(*buf);
+		++buf;
+	}
+}
+static inline void le32_to_cpu_array(uint32_t *buf, unsigned int words)
+{
+	while (words--) {
+		*buf = le32_to_cpup(buf);
+		++buf;
+	}
+}
+
+static inline uint32_t rol32(uint32_t word, unsigned int shift)
+{
+	return (word << (shift & 31)) | (word >> ((-shift) & 31));
+}
+static inline uint32_t ror32(uint32_t word, unsigned int shift)
+{
+	return (word >> (shift & 31)) | (word << ((-shift) & 31));
+}
+
+static void xor_cpy(uint8_t *dst, const uint8_t *src1, const uint8_t *src2,
+		    size_t len)
+{
+	size_t i;
+
+	for (i = 0; i < len; ++i)
+		dst[i] = src1[i] ^ src2[i];
+}
+
+#define QUARTER_ROUND(x, a, b, c, d) ( \
+	x[a] += x[b], \
+	x[d] = rol32((x[d] ^ x[a]), 16), \
+	x[c] += x[d], \
+	x[b] = rol32((x[b] ^ x[c]), 12), \
+	x[a] += x[b], \
+	x[d] = rol32((x[d] ^ x[a]), 8), \
+	x[c] += x[d], \
+	x[b] = rol32((x[b] ^ x[c]), 7) \
+)
+
+#define C(i, j) (i * 4 + j)
+
+#define DOUBLE_ROUND(x) ( \
+	/* Column Round */ \
+	QUARTER_ROUND(x, C(0, 0), C(1, 0), C(2, 0), C(3, 0)), \
+	QUARTER_ROUND(x, C(0, 1), C(1, 1), C(2, 1), C(3, 1)), \
+	QUARTER_ROUND(x, C(0, 2), C(1, 2), C(2, 2), C(3, 2)), \
+	QUARTER_ROUND(x, C(0, 3), C(1, 3), C(2, 3), C(3, 3)), \
+	/* Diagonal Round */ \
+	QUARTER_ROUND(x, C(0, 0), C(1, 1), C(2, 2), C(3, 3)), \
+	QUARTER_ROUND(x, C(0, 1), C(1, 2), C(2, 3), C(3, 0)), \
+	QUARTER_ROUND(x, C(0, 2), C(1, 3), C(2, 0), C(3, 1)), \
+	QUARTER_ROUND(x, C(0, 3), C(1, 0), C(2, 1), C(3, 2)) \
+)
+
+#define TWENTY_ROUNDS(x) ( \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x), \
+	DOUBLE_ROUND(x) \
+)
+
+enum chacha20_lengths {
+	CHACHA20_NONCE_SIZE = 16,
+	CHACHA20_KEY_SIZE = 32,
+	CHACHA20_KEY_WORDS = CHACHA20_KEY_SIZE / sizeof(uint32_t),
+	CHACHA20_BLOCK_SIZE = 64,
+	CHACHA20_BLOCK_WORDS = CHACHA20_BLOCK_SIZE / sizeof(uint32_t),
+	HCHACHA20_NONCE_SIZE = CHACHA20_NONCE_SIZE,
+	HCHACHA20_KEY_SIZE = CHACHA20_KEY_SIZE
+};
+
+enum chacha20_constants { /* expand 32-byte k */
+	CHACHA20_CONSTANT_EXPA = 0x61707865U,
+	CHACHA20_CONSTANT_ND_3 = 0x3320646eU,
+	CHACHA20_CONSTANT_2_BY = 0x79622d32U,
+	CHACHA20_CONSTANT_TE_K = 0x6b206574U
+};
+
+struct chacha20_ctx {
+	union {
+		uint32_t state[16];
+		struct {
+			uint32_t constant[4];
+			uint32_t key[8];
+			uint32_t counter[4];
+		};
+	};
+};
+
+static void chacha20_init(struct chacha20_ctx *ctx,
+			  const uint8_t key[CHACHA20_KEY_SIZE],
+			  const uint64_t
nonce) +{ + ctx->constant[0] = CHACHA20_CONSTANT_EXPA; + ctx->constant[1] = CHACHA20_CONSTANT_ND_3; + ctx->constant[2] = CHACHA20_CONSTANT_2_BY; + ctx->constant[3] = CHACHA20_CONSTANT_TE_K; + ctx->key[0] = get_unaligned_le32(key + 0); + ctx->key[1] = get_unaligned_le32(key + 4); + ctx->key[2] = get_unaligned_le32(key + 8); + ctx->key[3] = get_unaligned_le32(key + 12); + ctx->key[4] = get_unaligned_le32(key + 16); + ctx->key[5] = get_unaligned_le32(key + 20); + ctx->key[6] = get_unaligned_le32(key + 24); + ctx->key[7] = get_unaligned_le32(key + 28); + ctx->counter[0] = 0; + ctx->counter[1] = 0; + ctx->counter[2] = nonce & 0xffffffffU; + ctx->counter[3] = nonce >> 32; +} + +static void chacha20_block(struct chacha20_ctx *ctx, uint32_t *stream) +{ + uint32_t x[CHACHA20_BLOCK_WORDS]; + int i; + + for (i = 0; i < ARRAY_SIZE(x); ++i) + x[i] = ctx->state[i]; + + TWENTY_ROUNDS(x); + + for (i = 0; i < ARRAY_SIZE(x); ++i) + stream[i] = cpu_to_le32(x[i] + ctx->state[i]); + + ctx->counter[0] += 1; +} + +static void chacha20(struct chacha20_ctx *ctx, uint8_t *out, const uint8_t *in, + uint32_t len) +{ + uint32_t buf[CHACHA20_BLOCK_WORDS]; + + while (len >= CHACHA20_BLOCK_SIZE) { + chacha20_block(ctx, buf); + xor_cpy(out, in, (uint8_t *)buf, CHACHA20_BLOCK_SIZE); + len -= CHACHA20_BLOCK_SIZE; + out += CHACHA20_BLOCK_SIZE; + in += CHACHA20_BLOCK_SIZE; + } + if (len) { + chacha20_block(ctx, buf); + xor_cpy(out, in, (uint8_t *)buf, len); + } +} + +static void hchacha20(uint32_t derived_key[CHACHA20_KEY_WORDS], + const uint8_t nonce[HCHACHA20_NONCE_SIZE], + const uint8_t key[HCHACHA20_KEY_SIZE]) +{ + uint32_t x[] = { CHACHA20_CONSTANT_EXPA, + CHACHA20_CONSTANT_ND_3, + CHACHA20_CONSTANT_2_BY, + CHACHA20_CONSTANT_TE_K, + get_unaligned_le32(key + 0), + get_unaligned_le32(key + 4), + get_unaligned_le32(key + 8), + get_unaligned_le32(key + 12), + get_unaligned_le32(key + 16), + get_unaligned_le32(key + 20), + get_unaligned_le32(key + 24), + get_unaligned_le32(key + 28), + get_unaligned_le32(nonce + 0), + get_unaligned_le32(nonce + 4), + get_unaligned_le32(nonce + 8), + get_unaligned_le32(nonce + 12) + }; + + TWENTY_ROUNDS(x); + + memcpy(derived_key + 0, x + 0, sizeof(uint32_t) * 4); + memcpy(derived_key + 4, x + 12, sizeof(uint32_t) * 4); +} + +enum poly1305_lengths { + POLY1305_BLOCK_SIZE = 16, + POLY1305_KEY_SIZE = 32, + POLY1305_MAC_SIZE = 16 +}; + +struct poly1305_internal { + uint32_t h[5]; + uint32_t r[5]; + uint32_t s[4]; +}; + +struct poly1305_ctx { + struct poly1305_internal state; + uint32_t nonce[4]; + uint8_t data[POLY1305_BLOCK_SIZE]; + size_t num; +}; + +static void poly1305_init_core(struct poly1305_internal *st, + const uint8_t key[16]) +{ + /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */ + st->r[0] = (get_unaligned_le32(&key[0])) & 0x3ffffff; + st->r[1] = (get_unaligned_le32(&key[3]) >> 2) & 0x3ffff03; + st->r[2] = (get_unaligned_le32(&key[6]) >> 4) & 0x3ffc0ff; + st->r[3] = (get_unaligned_le32(&key[9]) >> 6) & 0x3f03fff; + st->r[4] = (get_unaligned_le32(&key[12]) >> 8) & 0x00fffff; + + /* s = 5*r */ + st->s[0] = st->r[1] * 5; + st->s[1] = st->r[2] * 5; + st->s[2] = st->r[3] * 5; + st->s[3] = st->r[4] * 5; + + /* h = 0 */ + st->h[0] = 0; + st->h[1] = 0; + st->h[2] = 0; + st->h[3] = 0; + st->h[4] = 0; +} + +static void poly1305_blocks_core(struct poly1305_internal *st, + const uint8_t *input, size_t len, + const uint32_t padbit) +{ + const uint32_t hibit = padbit << 24; + uint32_t r0, r1, r2, r3, r4; + uint32_t s1, s2, s3, s4; + uint32_t h0, h1, h2, h3, h4; + uint64_t d0, d1, d2, d3, d4; + uint32_t 
c; + + r0 = st->r[0]; + r1 = st->r[1]; + r2 = st->r[2]; + r3 = st->r[3]; + r4 = st->r[4]; + + s1 = st->s[0]; + s2 = st->s[1]; + s3 = st->s[2]; + s4 = st->s[3]; + + h0 = st->h[0]; + h1 = st->h[1]; + h2 = st->h[2]; + h3 = st->h[3]; + h4 = st->h[4]; + + while (len >= POLY1305_BLOCK_SIZE) { + /* h += m[i] */ + h0 += (get_unaligned_le32(&input[0])) & 0x3ffffff; + h1 += (get_unaligned_le32(&input[3]) >> 2) & 0x3ffffff; + h2 += (get_unaligned_le32(&input[6]) >> 4) & 0x3ffffff; + h3 += (get_unaligned_le32(&input[9]) >> 6) & 0x3ffffff; + h4 += (get_unaligned_le32(&input[12]) >> 8) | hibit; + + /* h *= r */ + d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + + ((uint64_t)h4 * s1); + d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + + ((uint64_t)h4 * s2); + d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + + ((uint64_t)h4 * s3); + d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + + ((uint64_t)h4 * s4); + d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + + ((uint64_t)h4 * r0); + + /* (partial) h %= p */ + c = (uint32_t)(d0 >> 26); + h0 = (uint32_t)d0 & 0x3ffffff; + d1 += c; + c = (uint32_t)(d1 >> 26); + h1 = (uint32_t)d1 & 0x3ffffff; + d2 += c; + c = (uint32_t)(d2 >> 26); + h2 = (uint32_t)d2 & 0x3ffffff; + d3 += c; + c = (uint32_t)(d3 >> 26); + h3 = (uint32_t)d3 & 0x3ffffff; + d4 += c; + c = (uint32_t)(d4 >> 26); + h4 = (uint32_t)d4 & 0x3ffffff; + h0 += c * 5; + c = (h0 >> 26); + h0 = h0 & 0x3ffffff; + h1 += c; + + input += POLY1305_BLOCK_SIZE; + len -= POLY1305_BLOCK_SIZE; + } + + st->h[0] = h0; + st->h[1] = h1; + st->h[2] = h2; + st->h[3] = h3; + st->h[4] = h4; +} + +static void poly1305_emit_core(struct poly1305_internal *st, uint8_t mac[16], + const uint32_t nonce[4]) +{ + uint32_t h0, h1, h2, h3, h4, c; + uint32_t g0, g1, g2, g3, g4; + uint64_t f; + uint32_t mask; + + /* fully carry h */ + h0 = st->h[0]; + h1 = st->h[1]; + h2 = st->h[2]; + h3 = st->h[3]; + h4 = st->h[4]; + + c = h1 >> 26; + h1 = h1 & 0x3ffffff; + h2 += c; + c = h2 >> 26; + h2 = h2 & 0x3ffffff; + h3 += c; + c = h3 >> 26; + h3 = h3 & 0x3ffffff; + h4 += c; + c = h4 >> 26; + h4 = h4 & 0x3ffffff; + h0 += c * 5; + c = h0 >> 26; + h0 = h0 & 0x3ffffff; + h1 += c; + + /* compute h + -p */ + g0 = h0 + 5; + c = g0 >> 26; + g0 &= 0x3ffffff; + g1 = h1 + c; + c = g1 >> 26; + g1 &= 0x3ffffff; + g2 = h2 + c; + c = g2 >> 26; + g2 &= 0x3ffffff; + g3 = h3 + c; + c = g3 >> 26; + g3 &= 0x3ffffff; + g4 = h4 + c - (1UL << 26); + + /* select h if h < p, or h + -p if h >= p */ + mask = (g4 >> ((sizeof(uint32_t) * 8) - 1)) - 1; + g0 &= mask; + g1 &= mask; + g2 &= mask; + g3 &= mask; + g4 &= mask; + mask = ~mask; + + h0 = (h0 & mask) | g0; + h1 = (h1 & mask) | g1; + h2 = (h2 & mask) | g2; + h3 = (h3 & mask) | g3; + h4 = (h4 & mask) | g4; + + /* h = h % (2^128) */ + h0 = ((h0) | (h1 << 26)) & 0xffffffff; + h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff; + h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff; + h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff; + + /* mac = (h + nonce) % (2^128) */ + f = (uint64_t)h0 + nonce[0]; + h0 = (uint32_t)f; + f = (uint64_t)h1 + nonce[1] + (f >> 32); + h1 = (uint32_t)f; + f = (uint64_t)h2 + nonce[2] + (f >> 32); + h2 = (uint32_t)f; + f = (uint64_t)h3 + nonce[3] + (f >> 32); + h3 = (uint32_t)f; + + put_unaligned_le32(h0, &mac[0]); + put_unaligned_le32(h1, &mac[4]); + put_unaligned_le32(h2, &mac[8]); + 
put_unaligned_le32(h3, &mac[12]); +} + +static void poly1305_init(struct poly1305_ctx *ctx, + const uint8_t key[POLY1305_KEY_SIZE]) +{ + ctx->nonce[0] = get_unaligned_le32(&key[16]); + ctx->nonce[1] = get_unaligned_le32(&key[20]); + ctx->nonce[2] = get_unaligned_le32(&key[24]); + ctx->nonce[3] = get_unaligned_le32(&key[28]); + + poly1305_init_core(&ctx->state, key); + + ctx->num = 0; +} + +static void poly1305_update(struct poly1305_ctx *ctx, const uint8_t *input, + size_t len) +{ + const size_t num = ctx->num; + size_t rem; + + if (num) { + rem = POLY1305_BLOCK_SIZE - num; + if (len < rem) { + memcpy(ctx->data + num, input, len); + ctx->num = num + len; + return; + } + memcpy(ctx->data + num, input, rem); + poly1305_blocks_core(&ctx->state, ctx->data, + POLY1305_BLOCK_SIZE, 1); + input += rem; + len -= rem; + } + + rem = len % POLY1305_BLOCK_SIZE; + len -= rem; + + if (len >= POLY1305_BLOCK_SIZE) { + poly1305_blocks_core(&ctx->state, input, len, 1); + input += len; + } + + if (rem) + memcpy(ctx->data, input, rem); + + ctx->num = rem; +} + +static void poly1305_final(struct poly1305_ctx *ctx, + uint8_t mac[POLY1305_MAC_SIZE]) +{ + size_t num = ctx->num; + + if (num) { + ctx->data[num++] = 1; + while (num < POLY1305_BLOCK_SIZE) + ctx->data[num++] = 0; + poly1305_blocks_core(&ctx->state, ctx->data, + POLY1305_BLOCK_SIZE, 0); + } + + poly1305_emit_core(&ctx->state, mac, ctx->nonce); + + explicit_bzero(ctx, sizeof(*ctx)); +} + + +static const uint8_t pad0[16] = { 0 }; + +void +chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + struct poly1305_ctx poly1305_state; + struct chacha20_ctx chacha20_state; + union { + uint8_t block0[POLY1305_KEY_SIZE]; + uint64_t lens[2]; + } b = { { 0 } }; + + chacha20_init(&chacha20_state, key, nonce); + chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0)); + poly1305_init(&poly1305_state, b.block0); + + poly1305_update(&poly1305_state, ad, ad_len); + poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf); + + chacha20(&chacha20_state, dst, src, src_len); + + poly1305_update(&poly1305_state, dst, src_len); + poly1305_update(&poly1305_state, pad0, (0x10 - src_len) & 0xf); + + b.lens[0] = cpu_to_le64(ad_len); + b.lens[1] = cpu_to_le64(src_len); + poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens)); + + poly1305_final(&poly1305_state, dst + src_len); + + explicit_bzero(&chacha20_state, sizeof(chacha20_state)); + explicit_bzero(&b, sizeof(b)); +} + +bool +chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + struct poly1305_ctx poly1305_state; + struct chacha20_ctx chacha20_state; + bool ret; + size_t dst_len; + union { + uint8_t block0[POLY1305_KEY_SIZE]; + uint8_t mac[POLY1305_MAC_SIZE]; + uint64_t lens[2]; + } b = { { 0 } }; + + if (src_len < POLY1305_MAC_SIZE) + return false; + + chacha20_init(&chacha20_state, key, nonce); + chacha20(&chacha20_state, b.block0, b.block0, sizeof(b.block0)); + poly1305_init(&poly1305_state, b.block0); + + poly1305_update(&poly1305_state, ad, ad_len); + poly1305_update(&poly1305_state, pad0, (0x10 - ad_len) & 0xf); + + dst_len = src_len - POLY1305_MAC_SIZE; + poly1305_update(&poly1305_state, src, dst_len); + poly1305_update(&poly1305_state, pad0, (0x10 - dst_len) & 0xf); + + b.lens[0] = cpu_to_le64(ad_len); + 
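+	/* Together, these two stores form the 16-byte trailer that RFC 8439
+	 * authenticates last: le64(ad_len) followed by le64(ciphertext_len). */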
b.lens[1] = cpu_to_le64(dst_len); + poly1305_update(&poly1305_state, (uint8_t *)b.lens, sizeof(b.lens)); + + poly1305_final(&poly1305_state, b.mac); + + ret = timingsafe_bcmp(b.mac, src + dst_len, POLY1305_MAC_SIZE) == 0; + if (ret) + chacha20(&chacha20_state, dst, src, dst_len); + + explicit_bzero(&chacha20_state, sizeof(chacha20_state)); + explicit_bzero(&b, sizeof(b)); + + return ret; +} + +void +xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + uint32_t derived_key[CHACHA20_KEY_WORDS]; + + hchacha20(derived_key, nonce, key); + cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key)); + chacha20poly1305_encrypt(dst, src, src_len, ad, ad_len, + get_unaligned_le64(nonce + 16), + (uint8_t *)derived_key); + explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE); +} + +bool +xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]) +{ + bool ret; + uint32_t derived_key[CHACHA20_KEY_WORDS]; + + hchacha20(derived_key, nonce, key); + cpu_to_le32_array(derived_key, ARRAY_SIZE(derived_key)); + ret = chacha20poly1305_decrypt(dst, src, src_len, ad, ad_len, + get_unaligned_le64(nonce + 16), + (uint8_t *)derived_key); + explicit_bzero(derived_key, CHACHA20POLY1305_KEY_SIZE); + return ret; +} + + +static const uint32_t blake2s_iv[8] = { + 0x6A09E667UL, 0xBB67AE85UL, 0x3C6EF372UL, 0xA54FF53AUL, + 0x510E527FUL, 0x9B05688CUL, 0x1F83D9ABUL, 0x5BE0CD19UL +}; + +static const uint8_t blake2s_sigma[10][16] = { + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 }, + { 14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3 }, + { 11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4 }, + { 7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8 }, + { 9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13 }, + { 2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9 }, + { 12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11 }, + { 13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10 }, + { 6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5 }, + { 10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0 }, +}; + +static inline void blake2s_set_lastblock(struct blake2s_state *state) +{ + state->f[0] = -1; +} + +static inline void blake2s_increment_counter(struct blake2s_state *state, + const uint32_t inc) +{ + state->t[0] += inc; + state->t[1] += (state->t[0] < inc); +} + +static inline void blake2s_init_param(struct blake2s_state *state, + const uint32_t param) +{ + int i; + + memset(state, 0, sizeof(*state)); + for (i = 0; i < 8; ++i) + state->h[i] = blake2s_iv[i]; + state->h[0] ^= param; +} + +void blake2s_init(struct blake2s_state *state, const size_t outlen) +{ + blake2s_init_param(state, 0x01010000 | outlen); + state->outlen = outlen; +} + +void blake2s_init_key(struct blake2s_state *state, const size_t outlen, + const uint8_t *key, const size_t keylen) +{ + uint8_t block[BLAKE2S_BLOCK_SIZE] = { 0 }; + + blake2s_init_param(state, 0x01010000 | keylen << 8 | outlen); + state->outlen = outlen; + memcpy(block, key, keylen); + blake2s_update(state, block, BLAKE2S_BLOCK_SIZE); + explicit_bzero(block, BLAKE2S_BLOCK_SIZE); +} + +static inline void blake2s_compress(struct blake2s_state *state, + const uint8_t *block, size_t nblocks, + const uint32_t inc) +{ + uint32_t m[16]; + 
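+	/* m receives the 16 little-endian message words of this block; v,
+	 * declared next, is the 16-word working vector of RFC 7693: the
+	 * chaining value h, then the IV, with the byte counter t and the
+	 * finalization flags f xored into the last four words. */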
uint32_t v[16]; + int i; + + while (nblocks > 0) { + blake2s_increment_counter(state, inc); + memcpy(m, block, BLAKE2S_BLOCK_SIZE); + le32_to_cpu_array(m, ARRAY_SIZE(m)); + memcpy(v, state->h, 32); + v[ 8] = blake2s_iv[0]; + v[ 9] = blake2s_iv[1]; + v[10] = blake2s_iv[2]; + v[11] = blake2s_iv[3]; + v[12] = blake2s_iv[4] ^ state->t[0]; + v[13] = blake2s_iv[5] ^ state->t[1]; + v[14] = blake2s_iv[6] ^ state->f[0]; + v[15] = blake2s_iv[7] ^ state->f[1]; + +#define G(r, i, a, b, c, d) do { \ + a += b + m[blake2s_sigma[r][2 * i + 0]]; \ + d = ror32(d ^ a, 16); \ + c += d; \ + b = ror32(b ^ c, 12); \ + a += b + m[blake2s_sigma[r][2 * i + 1]]; \ + d = ror32(d ^ a, 8); \ + c += d; \ + b = ror32(b ^ c, 7); \ +} while (0) + +#define ROUND(r) do { \ + G(r, 0, v[0], v[ 4], v[ 8], v[12]); \ + G(r, 1, v[1], v[ 5], v[ 9], v[13]); \ + G(r, 2, v[2], v[ 6], v[10], v[14]); \ + G(r, 3, v[3], v[ 7], v[11], v[15]); \ + G(r, 4, v[0], v[ 5], v[10], v[15]); \ + G(r, 5, v[1], v[ 6], v[11], v[12]); \ + G(r, 6, v[2], v[ 7], v[ 8], v[13]); \ + G(r, 7, v[3], v[ 4], v[ 9], v[14]); \ +} while (0) + ROUND(0); + ROUND(1); + ROUND(2); + ROUND(3); + ROUND(4); + ROUND(5); + ROUND(6); + ROUND(7); + ROUND(8); + ROUND(9); + +#undef G +#undef ROUND + + for (i = 0; i < 8; ++i) + state->h[i] ^= v[i] ^ v[i + 8]; + + block += BLAKE2S_BLOCK_SIZE; + --nblocks; + } +} + +void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen) +{ + const size_t fill = BLAKE2S_BLOCK_SIZE - state->buflen; + + if (!inlen) + return; + if (inlen > fill) { + memcpy(state->buf + state->buflen, in, fill); + blake2s_compress(state, state->buf, 1, BLAKE2S_BLOCK_SIZE); + state->buflen = 0; + in += fill; + inlen -= fill; + } + if (inlen > BLAKE2S_BLOCK_SIZE) { + const size_t nblocks = DIV_ROUND_UP(inlen, BLAKE2S_BLOCK_SIZE); + /* Hash one less (full) block than strictly possible */ + blake2s_compress(state, in, nblocks - 1, BLAKE2S_BLOCK_SIZE); + in += BLAKE2S_BLOCK_SIZE * (nblocks - 1); + inlen -= BLAKE2S_BLOCK_SIZE * (nblocks - 1); + } + memcpy(state->buf + state->buflen, in, inlen); + state->buflen += inlen; +} + +void blake2s_final(struct blake2s_state *state, uint8_t *out) +{ + blake2s_set_lastblock(state); + memset(state->buf + state->buflen, 0, + BLAKE2S_BLOCK_SIZE - state->buflen); /* Padding */ + blake2s_compress(state, state->buf, 1, state->buflen); + cpu_to_le32_array(state->h, ARRAY_SIZE(state->h)); + memcpy(out, state->h, state->outlen); + explicit_bzero(state, sizeof(*state)); +} + +void blake2s(uint8_t *out, const uint8_t *in, const uint8_t *key, + const size_t outlen, const size_t inlen, const size_t keylen) +{ + struct blake2s_state state; + + if (keylen) + blake2s_init_key(&state, outlen, key, keylen); + else + blake2s_init(&state, outlen); + + blake2s_update(&state, in, inlen); + blake2s_final(&state, out); +} + +void blake2s_hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, const size_t outlen, + const size_t inlen, const size_t keylen) +{ + struct blake2s_state state; + uint8_t x_key[BLAKE2S_BLOCK_SIZE] __aligned(sizeof(uint32_t)) = { 0 }; + uint8_t i_hash[BLAKE2S_HASH_SIZE] __aligned(sizeof(uint32_t)); + int i; + + if (keylen > BLAKE2S_BLOCK_SIZE) { + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, key, keylen); + blake2s_final(&state, x_key); + } else + memcpy(x_key, key, keylen); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, in, inlen); + 
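+	/* This is plain HMAC (RFC 2104) instantiated with BLAKE2s rather than
+	 * BLAKE2s's native keyed mode: i_hash = H((key ^ ipad) || in), after
+	 * which x_key is rekeyed with opad and the outer hash is taken. */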
blake2s_final(&state, i_hash); + + for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i) + x_key[i] ^= 0x5c ^ 0x36; + + blake2s_init(&state, BLAKE2S_HASH_SIZE); + blake2s_update(&state, x_key, BLAKE2S_BLOCK_SIZE); + blake2s_update(&state, i_hash, BLAKE2S_HASH_SIZE); + blake2s_final(&state, i_hash); + + memcpy(out, i_hash, outlen); + explicit_bzero(x_key, BLAKE2S_BLOCK_SIZE); + explicit_bzero(i_hash, BLAKE2S_HASH_SIZE); +} + + +/* Below here is fiat's implementation of x25519. + * + * Copyright (C) 2015-2016 The fiat-crypto Authors. + * Copyright (C) 2018-2021 Jason A. Donenfeld . All Rights Reserved. + * + * This is a machine-generated formally verified implementation of Curve25519 + * ECDH from: . Though originally + * machine generated, it has been tweaked to be suitable for use in the kernel. + * It is optimized for 32-bit machines and machines that cannot work efficiently + * with 128-bit integer types. + */ + +/* fe means field element. Here the field is \Z/(2^255-19). An element t, + * entries t[0]...t[9], represents the integer t[0]+2^26 t[1]+2^51 t[2]+2^77 + * t[3]+2^102 t[4]+...+2^230 t[9]. + * fe limbs are bounded by 1.125*2^26,1.125*2^25,1.125*2^26,1.125*2^25,etc. + * Multiplication and carrying produce fe from fe_loose. + */ +typedef struct fe { uint32_t v[10]; } fe; + +/* fe_loose limbs are bounded by 3.375*2^26,3.375*2^25,3.375*2^26,3.375*2^25,etc + * Addition and subtraction produce fe_loose from (fe, fe). + */ +typedef struct fe_loose { uint32_t v[10]; } fe_loose; + +static inline void fe_frombytes_impl(uint32_t h[10], const uint8_t *s) +{ + /* Ignores top bit of s. */ + uint32_t a0 = get_unaligned_le32(s); + uint32_t a1 = get_unaligned_le32(s+4); + uint32_t a2 = get_unaligned_le32(s+8); + uint32_t a3 = get_unaligned_le32(s+12); + uint32_t a4 = get_unaligned_le32(s+16); + uint32_t a5 = get_unaligned_le32(s+20); + uint32_t a6 = get_unaligned_le32(s+24); + uint32_t a7 = get_unaligned_le32(s+28); + h[0] = a0&((1<<26)-1); /* 26 used, 32-26 left. 26 */ + h[1] = (a0>>26) | ((a1&((1<<19)-1))<< 6); /* (32-26) + 19 = 6+19 = 25 */ + h[2] = (a1>>19) | ((a2&((1<<13)-1))<<13); /* (32-19) + 13 = 13+13 = 26 */ + h[3] = (a2>>13) | ((a3&((1<< 6)-1))<<19); /* (32-13) + 6 = 19+ 6 = 25 */ + h[4] = (a3>> 6); /* (32- 6) = 26 */ + h[5] = a4&((1<<25)-1); /* 25 */ + h[6] = (a4>>25) | ((a5&((1<<19)-1))<< 7); /* (32-25) + 19 = 7+19 = 26 */ + h[7] = (a5>>19) | ((a6&((1<<12)-1))<<13); /* (32-19) + 12 = 13+12 = 25 */ + h[8] = (a6>>12) | ((a7&((1<< 6)-1))<<20); /* (32-12) + 6 = 20+ 6 = 26 */ + h[9] = (a7>> 6)&((1<<25)-1); /* 25 */ +} + +static inline void fe_frombytes(fe *h, const uint8_t *s) +{ + fe_frombytes_impl(h->v, s); +} + +static inline uint8_t /*bool*/ +addcarryx_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 25 bits of result and 1 bit of carry + * (26 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a + b + c; + *low = x & ((1 << 25) - 1); + return (x >> 25) & 1; +} + +static inline uint8_t /*bool*/ +addcarryx_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 26 bits of result and 1 bit of carry + * (27 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a + b + c; + *low = x & ((1 << 26) - 1); + return (x >> 26) & 1; +} + +static inline uint8_t /*bool*/ +subborrow_u25(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 25 bits of result and 1 bit of borrow + * (26 total), so a 32-bit intermediate is sufficient. 
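+	 * The borrow is read from bit 31: if the subtraction underflows, it
+	 * wraps modulo 2^32 and sets the high bits of x.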
+ */ + uint32_t x = a - b - c; + *low = x & ((1 << 25) - 1); + return x >> 31; +} + +static inline uint8_t /*bool*/ +subborrow_u26(uint8_t /*bool*/ c, uint32_t a, uint32_t b, uint32_t *low) +{ + /* This function extracts 26 bits of result and 1 bit of borrow + *(27 total), so a 32-bit intermediate is sufficient. + */ + uint32_t x = a - b - c; + *low = x & ((1 << 26) - 1); + return x >> 31; +} + +static inline uint32_t cmovznz32(uint32_t t, uint32_t z, uint32_t nz) +{ + t = -!!t; /* all set if nonzero, 0 if 0 */ + return (t&nz) | ((~t)&z); +} + +static inline void fe_freeze(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x17 = in1[9]; + const uint32_t x18 = in1[8]; + const uint32_t x16 = in1[7]; + const uint32_t x14 = in1[6]; + const uint32_t x12 = in1[5]; + const uint32_t x10 = in1[4]; + const uint32_t x8 = in1[3]; + const uint32_t x6 = in1[2]; + const uint32_t x4 = in1[1]; + const uint32_t x2 = in1[0]; + uint32_t x20; uint8_t/*bool*/ x21 = subborrow_u26(0x0, x2, 0x3ffffed, &x20); + uint32_t x23; uint8_t/*bool*/ x24 = subborrow_u25(x21, x4, 0x1ffffff, &x23); + uint32_t x26; uint8_t/*bool*/ x27 = subborrow_u26(x24, x6, 0x3ffffff, &x26); + uint32_t x29; uint8_t/*bool*/ x30 = subborrow_u25(x27, x8, 0x1ffffff, &x29); + uint32_t x32; uint8_t/*bool*/ x33 = subborrow_u26(x30, x10, 0x3ffffff, &x32); + uint32_t x35; uint8_t/*bool*/ x36 = subborrow_u25(x33, x12, 0x1ffffff, &x35); + uint32_t x38; uint8_t/*bool*/ x39 = subborrow_u26(x36, x14, 0x3ffffff, &x38); + uint32_t x41; uint8_t/*bool*/ x42 = subborrow_u25(x39, x16, 0x1ffffff, &x41); + uint32_t x44; uint8_t/*bool*/ x45 = subborrow_u26(x42, x18, 0x3ffffff, &x44); + uint32_t x47; uint8_t/*bool*/ x48 = subborrow_u25(x45, x17, 0x1ffffff, &x47); + uint32_t x49 = cmovznz32(x48, 0x0, 0xffffffff); + uint32_t x50 = (x49 & 0x3ffffed); + uint32_t x52; uint8_t/*bool*/ x53 = addcarryx_u26(0x0, x20, x50, &x52); + uint32_t x54 = (x49 & 0x1ffffff); + uint32_t x56; uint8_t/*bool*/ x57 = addcarryx_u25(x53, x23, x54, &x56); + uint32_t x58 = (x49 & 0x3ffffff); + uint32_t x60; uint8_t/*bool*/ x61 = addcarryx_u26(x57, x26, x58, &x60); + uint32_t x62 = (x49 & 0x1ffffff); + uint32_t x64; uint8_t/*bool*/ x65 = addcarryx_u25(x61, x29, x62, &x64); + uint32_t x66 = (x49 & 0x3ffffff); + uint32_t x68; uint8_t/*bool*/ x69 = addcarryx_u26(x65, x32, x66, &x68); + uint32_t x70 = (x49 & 0x1ffffff); + uint32_t x72; uint8_t/*bool*/ x73 = addcarryx_u25(x69, x35, x70, &x72); + uint32_t x74 = (x49 & 0x3ffffff); + uint32_t x76; uint8_t/*bool*/ x77 = addcarryx_u26(x73, x38, x74, &x76); + uint32_t x78 = (x49 & 0x1ffffff); + uint32_t x80; uint8_t/*bool*/ x81 = addcarryx_u25(x77, x41, x78, &x80); + uint32_t x82 = (x49 & 0x3ffffff); + uint32_t x84; uint8_t/*bool*/ x85 = addcarryx_u26(x81, x44, x82, &x84); + uint32_t x86 = (x49 & 0x1ffffff); + uint32_t x88; addcarryx_u25(x85, x47, x86, &x88); + out[0] = x52; + out[1] = x56; + out[2] = x60; + out[3] = x64; + out[4] = x68; + out[5] = x72; + out[6] = x76; + out[7] = x80; + out[8] = x84; + out[9] = x88; +} + +static inline void fe_tobytes(uint8_t s[32], const fe *f) +{ + uint32_t h[10]; + fe_freeze(h, f->v); + s[0] = h[0] >> 0; + s[1] = h[0] >> 8; + s[2] = h[0] >> 16; + s[3] = (h[0] >> 24) | (h[1] << 2); + s[4] = h[1] >> 6; + s[5] = h[1] >> 14; + s[6] = (h[1] >> 22) | (h[2] << 3); + s[7] = h[2] >> 5; + s[8] = h[2] >> 13; + s[9] = (h[2] >> 21) | (h[3] << 5); + s[10] = h[3] >> 3; + s[11] = h[3] >> 11; + s[12] = (h[3] >> 19) | (h[4] << 6); + s[13] = h[4] >> 2; + s[14] = h[4] >> 10; + s[15] = h[4] >> 18; + s[16] = h[5] >> 0; + s[17] = 
h[5] >> 8; + s[18] = h[5] >> 16; + s[19] = (h[5] >> 24) | (h[6] << 1); + s[20] = h[6] >> 7; + s[21] = h[6] >> 15; + s[22] = (h[6] >> 23) | (h[7] << 3); + s[23] = h[7] >> 5; + s[24] = h[7] >> 13; + s[25] = (h[7] >> 21) | (h[8] << 4); + s[26] = h[8] >> 4; + s[27] = h[8] >> 12; + s[28] = (h[8] >> 20) | (h[9] << 6); + s[29] = h[9] >> 2; + s[30] = h[9] >> 10; + s[31] = h[9] >> 18; +} + +/* h = f */ +static inline void fe_copy(fe *h, const fe *f) +{ + memmove(h, f, sizeof(uint32_t) * 10); +} + +static inline void fe_copy_lt(fe_loose *h, const fe *f) +{ + memmove(h, f, sizeof(uint32_t) * 10); +} + +/* h = 0 */ +static inline void fe_0(fe *h) +{ + memset(h, 0, sizeof(uint32_t) * 10); +} + +/* h = 1 */ +static inline void fe_1(fe *h) +{ + memset(h, 0, sizeof(uint32_t) * 10); + h->v[0] = 1; +} + +static void fe_add_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + out[0] = (x5 + x23); + out[1] = (x7 + x25); + out[2] = (x9 + x27); + out[3] = (x11 + x29); + out[4] = (x13 + x31); + out[5] = (x15 + x33); + out[6] = (x17 + x35); + out[7] = (x19 + x37); + out[8] = (x21 + x39); + out[9] = (x20 + x38); +} + +/* h = f + g + * Can overlap h with f or g. + */ +static inline void fe_add(fe_loose *h, const fe *f, const fe *g) +{ + fe_add_impl(h->v, f->v, g->v); +} + +static void fe_sub_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + out[0] = ((0x7ffffda + x5) - x23); + out[1] = ((0x3fffffe + x7) - x25); + out[2] = ((0x7fffffe + x9) - x27); + out[3] = ((0x3fffffe + x11) - x29); + out[4] = ((0x7fffffe + x13) - x31); + out[5] = ((0x3fffffe + x15) - x33); + out[6] = ((0x7fffffe + x17) - x35); + out[7] = ((0x3fffffe + x19) - x37); + out[8] = ((0x7fffffe + x21) - x39); + out[9] = ((0x3fffffe + x20) - x38); +} + +/* h = f - g + * Can overlap h with f or g. 
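+ * The per-limb constants added in fe_sub_impl (0x7ffffda, 0x3fffffe, ...)
+ * are 2*p spread across the ten limbs: biasing by 2*p before subtracting
+ * keeps every limb nonnegative, so the difference cannot underflow.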
+ */ +static inline void fe_sub(fe_loose *h, const fe *f, const fe *g) +{ + fe_sub_impl(h->v, f->v, g->v); +} + +static void fe_mul_impl(uint32_t out[10], const uint32_t in1[10], const uint32_t in2[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = in2[9]; + const uint32_t x39 = in2[8]; + const uint32_t x37 = in2[7]; + const uint32_t x35 = in2[6]; + const uint32_t x33 = in2[5]; + const uint32_t x31 = in2[4]; + const uint32_t x29 = in2[3]; + const uint32_t x27 = in2[2]; + const uint32_t x25 = in2[1]; + const uint32_t x23 = in2[0]; + uint64_t x40 = ((uint64_t)x23 * x5); + uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5)); + uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5)); + uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5)); + uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5)); + uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5)); + uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5)); + uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5)); + uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5)); + uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5)); + uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9)); + uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9)); + uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13)); + uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13)); + uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17)); + uint64_t x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17)); + uint64_t x56 = (((uint64_t)x39 * 
x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19)))); + uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21)); + uint64_t x58 = ((uint64_t)(0x2 * x38) * x20); + uint64_t x59 = (x48 + (x58 << 0x4)); + uint64_t x60 = (x59 + (x58 << 0x1)); + uint64_t x61 = (x60 + x58); + uint64_t x62 = (x47 + (x57 << 0x4)); + uint64_t x63 = (x62 + (x57 << 0x1)); + uint64_t x64 = (x63 + x57); + uint64_t x65 = (x46 + (x56 << 0x4)); + uint64_t x66 = (x65 + (x56 << 0x1)); + uint64_t x67 = (x66 + x56); + uint64_t x68 = (x45 + (x55 << 0x4)); + uint64_t x69 = (x68 + (x55 << 0x1)); + uint64_t x70 = (x69 + x55); + uint64_t x71 = (x44 + (x54 << 0x4)); + uint64_t x72 = (x71 + (x54 << 0x1)); + uint64_t x73 = (x72 + x54); + uint64_t x74 = (x43 + (x53 << 0x4)); + uint64_t x75 = (x74 + (x53 << 0x1)); + uint64_t x76 = (x75 + x53); + uint64_t x77 = (x42 + (x52 << 0x4)); + uint64_t x78 = (x77 + (x52 << 0x1)); + uint64_t x79 = (x78 + x52); + uint64_t x80 = (x41 + (x51 << 0x4)); + uint64_t x81 = (x80 + (x51 << 0x1)); + uint64_t x82 = (x81 + x51); + uint64_t x83 = (x40 + (x50 << 0x4)); + uint64_t x84 = (x83 + (x50 << 0x1)); + uint64_t x85 = (x84 + x50); + uint64_t x86 = (x85 >> 0x1a); + uint32_t x87 = ((uint32_t)x85 & 0x3ffffff); + uint64_t x88 = (x86 + x82); + uint64_t x89 = (x88 >> 0x19); + uint32_t x90 = ((uint32_t)x88 & 0x1ffffff); + uint64_t x91 = (x89 + x79); + uint64_t x92 = (x91 >> 0x1a); + uint32_t x93 = ((uint32_t)x91 & 0x3ffffff); + uint64_t x94 = (x92 + x76); + uint64_t x95 = (x94 >> 0x19); + uint32_t x96 = ((uint32_t)x94 & 0x1ffffff); + uint64_t x97 = (x95 + x73); + uint64_t x98 = (x97 >> 0x1a); + uint32_t x99 = ((uint32_t)x97 & 0x3ffffff); + uint64_t x100 = (x98 + x70); + uint64_t x101 = (x100 >> 0x19); + uint32_t x102 = ((uint32_t)x100 & 0x1ffffff); + uint64_t x103 = (x101 + x67); + uint64_t x104 = (x103 >> 0x1a); + uint32_t x105 = ((uint32_t)x103 & 0x3ffffff); + uint64_t x106 = (x104 + x64); + uint64_t x107 = (x106 >> 0x19); + uint32_t x108 = ((uint32_t)x106 & 0x1ffffff); + uint64_t x109 = (x107 + x61); + uint64_t x110 = (x109 >> 0x1a); + uint32_t x111 = ((uint32_t)x109 & 0x3ffffff); + uint64_t x112 = (x110 + x49); + uint64_t x113 = (x112 >> 0x19); + uint32_t x114 = ((uint32_t)x112 & 0x1ffffff); + uint64_t x115 = (x87 + (0x13 * x113)); + uint32_t x116 = (uint32_t) (x115 >> 0x1a); + uint32_t x117 = ((uint32_t)x115 & 0x3ffffff); + uint32_t x118 = (x116 + x90); + uint32_t x119 = (x118 >> 0x19); + uint32_t x120 = (x118 & 0x1ffffff); + out[0] = x117; + out[1] = x120; + out[2] = (x119 + x93); + out[3] = x96; + out[4] = x99; + out[5] = x102; + out[6] = x105; + out[7] = x108; + out[8] = x111; + out[9] = x114; +} + +static inline void fe_mul_ttt(fe *h, const fe *f, const fe *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static inline void fe_mul_tlt(fe *h, const fe_loose *f, const fe *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static inline void +fe_mul_tll(fe *h, const fe_loose *f, const fe_loose *g) +{ + fe_mul_impl(h->v, f->v, g->v); +} + +static void fe_sqr_impl(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x17 = in1[9]; + const uint32_t x18 = in1[8]; + const uint32_t x16 = in1[7]; + const uint32_t x14 = in1[6]; + const uint32_t x12 = in1[5]; + const uint32_t x10 = in1[4]; + const uint32_t x8 = in1[3]; + const uint32_t x6 = in1[2]; + const uint32_t x4 = in1[1]; + const uint32_t x2 = in1[0]; + uint64_t x19 = ((uint64_t)x2 * x2); + uint64_t x20 = ((uint64_t)(0x2 * x2) * x4); + uint64_t x21 = (0x2 * (((uint64_t)x4 * x4) + ((uint64_t)x2 * x6))); + uint64_t x22 = (0x2 * (((uint64_t)x4 * 
x6) + ((uint64_t)x2 * x8))); + uint64_t x23 = ((((uint64_t)x6 * x6) + ((uint64_t)(0x4 * x4) * x8)) + ((uint64_t)(0x2 * x2) * x10)); + uint64_t x24 = (0x2 * ((((uint64_t)x6 * x8) + ((uint64_t)x4 * x10)) + ((uint64_t)x2 * x12))); + uint64_t x25 = (0x2 * (((((uint64_t)x8 * x8) + ((uint64_t)x6 * x10)) + ((uint64_t)x2 * x14)) + ((uint64_t)(0x2 * x4) * x12))); + uint64_t x26 = (0x2 * (((((uint64_t)x8 * x10) + ((uint64_t)x6 * x12)) + ((uint64_t)x4 * x14)) + ((uint64_t)x2 * x16))); + uint64_t x27 = (((uint64_t)x10 * x10) + (0x2 * ((((uint64_t)x6 * x14) + ((uint64_t)x2 * x18)) + (0x2 * (((uint64_t)x4 * x16) + ((uint64_t)x8 * x12)))))); + uint64_t x28 = (0x2 * ((((((uint64_t)x10 * x12) + ((uint64_t)x8 * x14)) + ((uint64_t)x6 * x16)) + ((uint64_t)x4 * x18)) + ((uint64_t)x2 * x17))); + uint64_t x29 = (0x2 * (((((uint64_t)x12 * x12) + ((uint64_t)x10 * x14)) + ((uint64_t)x6 * x18)) + (0x2 * (((uint64_t)x8 * x16) + ((uint64_t)x4 * x17))))); + uint64_t x30 = (0x2 * (((((uint64_t)x12 * x14) + ((uint64_t)x10 * x16)) + ((uint64_t)x8 * x18)) + ((uint64_t)x6 * x17))); + uint64_t x31 = (((uint64_t)x14 * x14) + (0x2 * (((uint64_t)x10 * x18) + (0x2 * (((uint64_t)x12 * x16) + ((uint64_t)x8 * x17)))))); + uint64_t x32 = (0x2 * ((((uint64_t)x14 * x16) + ((uint64_t)x12 * x18)) + ((uint64_t)x10 * x17))); + uint64_t x33 = (0x2 * ((((uint64_t)x16 * x16) + ((uint64_t)x14 * x18)) + ((uint64_t)(0x2 * x12) * x17))); + uint64_t x34 = (0x2 * (((uint64_t)x16 * x18) + ((uint64_t)x14 * x17))); + uint64_t x35 = (((uint64_t)x18 * x18) + ((uint64_t)(0x4 * x16) * x17)); + uint64_t x36 = ((uint64_t)(0x2 * x18) * x17); + uint64_t x37 = ((uint64_t)(0x2 * x17) * x17); + uint64_t x38 = (x27 + (x37 << 0x4)); + uint64_t x39 = (x38 + (x37 << 0x1)); + uint64_t x40 = (x39 + x37); + uint64_t x41 = (x26 + (x36 << 0x4)); + uint64_t x42 = (x41 + (x36 << 0x1)); + uint64_t x43 = (x42 + x36); + uint64_t x44 = (x25 + (x35 << 0x4)); + uint64_t x45 = (x44 + (x35 << 0x1)); + uint64_t x46 = (x45 + x35); + uint64_t x47 = (x24 + (x34 << 0x4)); + uint64_t x48 = (x47 + (x34 << 0x1)); + uint64_t x49 = (x48 + x34); + uint64_t x50 = (x23 + (x33 << 0x4)); + uint64_t x51 = (x50 + (x33 << 0x1)); + uint64_t x52 = (x51 + x33); + uint64_t x53 = (x22 + (x32 << 0x4)); + uint64_t x54 = (x53 + (x32 << 0x1)); + uint64_t x55 = (x54 + x32); + uint64_t x56 = (x21 + (x31 << 0x4)); + uint64_t x57 = (x56 + (x31 << 0x1)); + uint64_t x58 = (x57 + x31); + uint64_t x59 = (x20 + (x30 << 0x4)); + uint64_t x60 = (x59 + (x30 << 0x1)); + uint64_t x61 = (x60 + x30); + uint64_t x62 = (x19 + (x29 << 0x4)); + uint64_t x63 = (x62 + (x29 << 0x1)); + uint64_t x64 = (x63 + x29); + uint64_t x65 = (x64 >> 0x1a); + uint32_t x66 = ((uint32_t)x64 & 0x3ffffff); + uint64_t x67 = (x65 + x61); + uint64_t x68 = (x67 >> 0x19); + uint32_t x69 = ((uint32_t)x67 & 0x1ffffff); + uint64_t x70 = (x68 + x58); + uint64_t x71 = (x70 >> 0x1a); + uint32_t x72 = ((uint32_t)x70 & 0x3ffffff); + uint64_t x73 = (x71 + x55); + uint64_t x74 = (x73 >> 0x19); + uint32_t x75 = ((uint32_t)x73 & 0x1ffffff); + uint64_t x76 = (x74 + x52); + uint64_t x77 = (x76 >> 0x1a); + uint32_t x78 = ((uint32_t)x76 & 0x3ffffff); + uint64_t x79 = (x77 + x49); + uint64_t x80 = (x79 >> 0x19); + uint32_t x81 = ((uint32_t)x79 & 0x1ffffff); + uint64_t x82 = (x80 + x46); + uint64_t x83 = (x82 >> 0x1a); + uint32_t x84 = ((uint32_t)x82 & 0x3ffffff); + uint64_t x85 = (x83 + x43); + uint64_t x86 = (x85 >> 0x19); + uint32_t x87 = ((uint32_t)x85 & 0x1ffffff); + uint64_t x88 = (x86 + x40); + uint64_t x89 = (x88 >> 0x1a); + uint32_t x90 = ((uint32_t)x88 & 
0x3ffffff); + uint64_t x91 = (x89 + x28); + uint64_t x92 = (x91 >> 0x19); + uint32_t x93 = ((uint32_t)x91 & 0x1ffffff); + uint64_t x94 = (x66 + (0x13 * x92)); + uint32_t x95 = (uint32_t) (x94 >> 0x1a); + uint32_t x96 = ((uint32_t)x94 & 0x3ffffff); + uint32_t x97 = (x95 + x69); + uint32_t x98 = (x97 >> 0x19); + uint32_t x99 = (x97 & 0x1ffffff); + out[0] = x96; + out[1] = x99; + out[2] = (x98 + x72); + out[3] = x75; + out[4] = x78; + out[5] = x81; + out[6] = x84; + out[7] = x87; + out[8] = x90; + out[9] = x93; +} + +static inline void fe_sq_tl(fe *h, const fe_loose *f) +{ + fe_sqr_impl(h->v, f->v); +} + +static inline void fe_sq_tt(fe *h, const fe *f) +{ + fe_sqr_impl(h->v, f->v); +} + +static inline void fe_loose_invert(fe *out, const fe_loose *z) +{ + fe t0; + fe t1; + fe t2; + fe t3; + int i; + + fe_sq_tl(&t0, z); + fe_sq_tt(&t1, &t0); + for (i = 1; i < 2; ++i) + fe_sq_tt(&t1, &t1); + fe_mul_tlt(&t1, z, &t1); + fe_mul_ttt(&t0, &t0, &t1); + fe_sq_tt(&t2, &t0); + fe_mul_ttt(&t1, &t1, &t2); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 5; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 10; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t2, &t2, &t1); + fe_sq_tt(&t3, &t2); + for (i = 1; i < 20; ++i) + fe_sq_tt(&t3, &t3); + fe_mul_ttt(&t2, &t3, &t2); + fe_sq_tt(&t2, &t2); + for (i = 1; i < 10; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t2, &t1); + for (i = 1; i < 50; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t2, &t2, &t1); + fe_sq_tt(&t3, &t2); + for (i = 1; i < 100; ++i) + fe_sq_tt(&t3, &t3); + fe_mul_ttt(&t2, &t3, &t2); + fe_sq_tt(&t2, &t2); + for (i = 1; i < 50; ++i) + fe_sq_tt(&t2, &t2); + fe_mul_ttt(&t1, &t2, &t1); + fe_sq_tt(&t1, &t1); + for (i = 1; i < 5; ++i) + fe_sq_tt(&t1, &t1); + fe_mul_ttt(out, &t1, &t0); +} + +static inline void fe_invert(fe *out, const fe *z) +{ + fe_loose l; + fe_copy_lt(&l, z); + fe_loose_invert(out, &l); +} + +/* Replace (f,g) with (g,f) if b == 1; + * replace (f,g) with (f,g) if b == 0. 
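+ * The swap is branch-free: b is stretched into an all-ones or all-zeros
+ * mask and the limbs are xor-swapped through it, so secret scalar bits
+ * never influence control flow in the Montgomery ladder below.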
+ * + * Preconditions: b in {0,1} + */ +static inline void fe_cswap(fe *f, fe *g, unsigned int b) +{ + unsigned i; + b = 0 - b; + for (i = 0; i < 10; i++) { + uint32_t x = f->v[i] ^ g->v[i]; + x &= b; + f->v[i] ^= x; + g->v[i] ^= x; + } +} + +/* NOTE: based on fiat-crypto fe_mul, edited for in2=121666, 0, 0.*/ +static inline void fe_mul_121666_impl(uint32_t out[10], const uint32_t in1[10]) +{ + const uint32_t x20 = in1[9]; + const uint32_t x21 = in1[8]; + const uint32_t x19 = in1[7]; + const uint32_t x17 = in1[6]; + const uint32_t x15 = in1[5]; + const uint32_t x13 = in1[4]; + const uint32_t x11 = in1[3]; + const uint32_t x9 = in1[2]; + const uint32_t x7 = in1[1]; + const uint32_t x5 = in1[0]; + const uint32_t x38 = 0; + const uint32_t x39 = 0; + const uint32_t x37 = 0; + const uint32_t x35 = 0; + const uint32_t x33 = 0; + const uint32_t x31 = 0; + const uint32_t x29 = 0; + const uint32_t x27 = 0; + const uint32_t x25 = 0; + const uint32_t x23 = 121666; + uint64_t x40 = ((uint64_t)x23 * x5); + uint64_t x41 = (((uint64_t)x23 * x7) + ((uint64_t)x25 * x5)); + uint64_t x42 = ((((uint64_t)(0x2 * x25) * x7) + ((uint64_t)x23 * x9)) + ((uint64_t)x27 * x5)); + uint64_t x43 = (((((uint64_t)x25 * x9) + ((uint64_t)x27 * x7)) + ((uint64_t)x23 * x11)) + ((uint64_t)x29 * x5)); + uint64_t x44 = (((((uint64_t)x27 * x9) + (0x2 * (((uint64_t)x25 * x11) + ((uint64_t)x29 * x7)))) + ((uint64_t)x23 * x13)) + ((uint64_t)x31 * x5)); + uint64_t x45 = (((((((uint64_t)x27 * x11) + ((uint64_t)x29 * x9)) + ((uint64_t)x25 * x13)) + ((uint64_t)x31 * x7)) + ((uint64_t)x23 * x15)) + ((uint64_t)x33 * x5)); + uint64_t x46 = (((((0x2 * ((((uint64_t)x29 * x11) + ((uint64_t)x25 * x15)) + ((uint64_t)x33 * x7))) + ((uint64_t)x27 * x13)) + ((uint64_t)x31 * x9)) + ((uint64_t)x23 * x17)) + ((uint64_t)x35 * x5)); + uint64_t x47 = (((((((((uint64_t)x29 * x13) + ((uint64_t)x31 * x11)) + ((uint64_t)x27 * x15)) + ((uint64_t)x33 * x9)) + ((uint64_t)x25 * x17)) + ((uint64_t)x35 * x7)) + ((uint64_t)x23 * x19)) + ((uint64_t)x37 * x5)); + uint64_t x48 = (((((((uint64_t)x31 * x13) + (0x2 * (((((uint64_t)x29 * x15) + ((uint64_t)x33 * x11)) + ((uint64_t)x25 * x19)) + ((uint64_t)x37 * x7)))) + ((uint64_t)x27 * x17)) + ((uint64_t)x35 * x9)) + ((uint64_t)x23 * x21)) + ((uint64_t)x39 * x5)); + uint64_t x49 = (((((((((((uint64_t)x31 * x15) + ((uint64_t)x33 * x13)) + ((uint64_t)x29 * x17)) + ((uint64_t)x35 * x11)) + ((uint64_t)x27 * x19)) + ((uint64_t)x37 * x9)) + ((uint64_t)x25 * x21)) + ((uint64_t)x39 * x7)) + ((uint64_t)x23 * x20)) + ((uint64_t)x38 * x5)); + uint64_t x50 = (((((0x2 * ((((((uint64_t)x33 * x15) + ((uint64_t)x29 * x19)) + ((uint64_t)x37 * x11)) + ((uint64_t)x25 * x20)) + ((uint64_t)x38 * x7))) + ((uint64_t)x31 * x17)) + ((uint64_t)x35 * x13)) + ((uint64_t)x27 * x21)) + ((uint64_t)x39 * x9)); + uint64_t x51 = (((((((((uint64_t)x33 * x17) + ((uint64_t)x35 * x15)) + ((uint64_t)x31 * x19)) + ((uint64_t)x37 * x13)) + ((uint64_t)x29 * x21)) + ((uint64_t)x39 * x11)) + ((uint64_t)x27 * x20)) + ((uint64_t)x38 * x9)); + uint64_t x52 = (((((uint64_t)x35 * x17) + (0x2 * (((((uint64_t)x33 * x19) + ((uint64_t)x37 * x15)) + ((uint64_t)x29 * x20)) + ((uint64_t)x38 * x11)))) + ((uint64_t)x31 * x21)) + ((uint64_t)x39 * x13)); + uint64_t x53 = (((((((uint64_t)x35 * x19) + ((uint64_t)x37 * x17)) + ((uint64_t)x33 * x21)) + ((uint64_t)x39 * x15)) + ((uint64_t)x31 * x20)) + ((uint64_t)x38 * x13)); + uint64_t x54 = (((0x2 * ((((uint64_t)x37 * x19) + ((uint64_t)x33 * x20)) + ((uint64_t)x38 * x15))) + ((uint64_t)x35 * x21)) + ((uint64_t)x39 * x17)); + uint64_t 
x55 = (((((uint64_t)x37 * x21) + ((uint64_t)x39 * x19)) + ((uint64_t)x35 * x20)) + ((uint64_t)x38 * x17)); + uint64_t x56 = (((uint64_t)x39 * x21) + (0x2 * (((uint64_t)x37 * x20) + ((uint64_t)x38 * x19)))); + uint64_t x57 = (((uint64_t)x39 * x20) + ((uint64_t)x38 * x21)); + uint64_t x58 = ((uint64_t)(0x2 * x38) * x20); + uint64_t x59 = (x48 + (x58 << 0x4)); + uint64_t x60 = (x59 + (x58 << 0x1)); + uint64_t x61 = (x60 + x58); + uint64_t x62 = (x47 + (x57 << 0x4)); + uint64_t x63 = (x62 + (x57 << 0x1)); + uint64_t x64 = (x63 + x57); + uint64_t x65 = (x46 + (x56 << 0x4)); + uint64_t x66 = (x65 + (x56 << 0x1)); + uint64_t x67 = (x66 + x56); + uint64_t x68 = (x45 + (x55 << 0x4)); + uint64_t x69 = (x68 + (x55 << 0x1)); + uint64_t x70 = (x69 + x55); + uint64_t x71 = (x44 + (x54 << 0x4)); + uint64_t x72 = (x71 + (x54 << 0x1)); + uint64_t x73 = (x72 + x54); + uint64_t x74 = (x43 + (x53 << 0x4)); + uint64_t x75 = (x74 + (x53 << 0x1)); + uint64_t x76 = (x75 + x53); + uint64_t x77 = (x42 + (x52 << 0x4)); + uint64_t x78 = (x77 + (x52 << 0x1)); + uint64_t x79 = (x78 + x52); + uint64_t x80 = (x41 + (x51 << 0x4)); + uint64_t x81 = (x80 + (x51 << 0x1)); + uint64_t x82 = (x81 + x51); + uint64_t x83 = (x40 + (x50 << 0x4)); + uint64_t x84 = (x83 + (x50 << 0x1)); + uint64_t x85 = (x84 + x50); + uint64_t x86 = (x85 >> 0x1a); + uint32_t x87 = ((uint32_t)x85 & 0x3ffffff); + uint64_t x88 = (x86 + x82); + uint64_t x89 = (x88 >> 0x19); + uint32_t x90 = ((uint32_t)x88 & 0x1ffffff); + uint64_t x91 = (x89 + x79); + uint64_t x92 = (x91 >> 0x1a); + uint32_t x93 = ((uint32_t)x91 & 0x3ffffff); + uint64_t x94 = (x92 + x76); + uint64_t x95 = (x94 >> 0x19); + uint32_t x96 = ((uint32_t)x94 & 0x1ffffff); + uint64_t x97 = (x95 + x73); + uint64_t x98 = (x97 >> 0x1a); + uint32_t x99 = ((uint32_t)x97 & 0x3ffffff); + uint64_t x100 = (x98 + x70); + uint64_t x101 = (x100 >> 0x19); + uint32_t x102 = ((uint32_t)x100 & 0x1ffffff); + uint64_t x103 = (x101 + x67); + uint64_t x104 = (x103 >> 0x1a); + uint32_t x105 = ((uint32_t)x103 & 0x3ffffff); + uint64_t x106 = (x104 + x64); + uint64_t x107 = (x106 >> 0x19); + uint32_t x108 = ((uint32_t)x106 & 0x1ffffff); + uint64_t x109 = (x107 + x61); + uint64_t x110 = (x109 >> 0x1a); + uint32_t x111 = ((uint32_t)x109 & 0x3ffffff); + uint64_t x112 = (x110 + x49); + uint64_t x113 = (x112 >> 0x19); + uint32_t x114 = ((uint32_t)x112 & 0x1ffffff); + uint64_t x115 = (x87 + (0x13 * x113)); + uint32_t x116 = (uint32_t) (x115 >> 0x1a); + uint32_t x117 = ((uint32_t)x115 & 0x3ffffff); + uint32_t x118 = (x116 + x90); + uint32_t x119 = (x118 >> 0x19); + uint32_t x120 = (x118 & 0x1ffffff); + out[0] = x117; + out[1] = x120; + out[2] = (x119 + x93); + out[3] = x96; + out[4] = x99; + out[5] = x102; + out[6] = x105; + out[7] = x108; + out[8] = x111; + out[9] = x114; +} + +static inline void fe_mul121666(fe *h, const fe_loose *f) +{ + fe_mul_121666_impl(h->v, f->v); +} + +static const uint8_t curve25519_null_point[CURVE25519_KEY_SIZE]; + +bool curve25519(uint8_t out[CURVE25519_KEY_SIZE], + const uint8_t scalar[CURVE25519_KEY_SIZE], + const uint8_t point[CURVE25519_KEY_SIZE]) +{ + fe x1, x2, z2, x3, z3; + fe_loose x2l, z2l, x3l; + unsigned swap = 0; + int pos; + uint8_t e[32]; + + memcpy(e, scalar, 32); + curve25519_clamp_secret(e); + + /* The following implementation was transcribed to Coq and proven to + * correspond to unary scalar multiplication in affine coordinates given + * that x1 != 0 is the x coordinate of some point on the curve. 
It was + * also checked in Coq that doing a ladderstep with x1 = x3 = 0 gives + * z2' = z3' = 0, and z2 = z3 = 0 gives z2' = z3' = 0. The statement was + * quantified over the underlying field, so it applies to Curve25519 + * itself and the quadratic twist of Curve25519. It was not proven in + * Coq that prime-field arithmetic correctly simulates extension-field + * arithmetic on prime-field values. The decoding of the byte array + * representation of e was not considered. + * + * Specification of Montgomery curves in affine coordinates: + * + * + * Proof that these form a group that is isomorphic to a Weierstrass + * curve: + * + * + * Coq transcription and correctness proof of the loop + * (where scalarbits=255): + * + * + * preconditions: 0 <= e < 2^255 (not necessarily e < order), + * fe_invert(0) = 0 + */ + fe_frombytes(&x1, point); + fe_1(&x2); + fe_0(&z2); + fe_copy(&x3, &x1); + fe_1(&z3); + + for (pos = 254; pos >= 0; --pos) { + fe tmp0, tmp1; + fe_loose tmp0l, tmp1l; + /* loop invariant as of right before the test, for the case + * where x1 != 0: + * pos >= -1; if z2 = 0 then x2 is nonzero; if z3 = 0 then x3 + * is nonzero + * let r := e >> (pos+1) in the following equalities of + * projective points: + * to_xz (r*P) === if swap then (x3, z3) else (x2, z2) + * to_xz ((r+1)*P) === if swap then (x2, z2) else (x3, z3) + * x1 is the nonzero x coordinate of the nonzero + * point (r*P-(r+1)*P) + */ + unsigned b = 1 & (e[pos / 8] >> (pos & 7)); + swap ^= b; + fe_cswap(&x2, &x3, swap); + fe_cswap(&z2, &z3, swap); + swap = b; + /* Coq transcription of ladderstep formula (called from + * transcribed loop): + * + * + * x1 != 0 + * x1 = 0 + */ + fe_sub(&tmp0l, &x3, &z3); + fe_sub(&tmp1l, &x2, &z2); + fe_add(&x2l, &x2, &z2); + fe_add(&z2l, &x3, &z3); + fe_mul_tll(&z3, &tmp0l, &x2l); + fe_mul_tll(&z2, &z2l, &tmp1l); + fe_sq_tl(&tmp0, &tmp1l); + fe_sq_tl(&tmp1, &x2l); + fe_add(&x3l, &z3, &z2); + fe_sub(&z2l, &z3, &z2); + fe_mul_ttt(&x2, &tmp1, &tmp0); + fe_sub(&tmp1l, &tmp1, &tmp0); + fe_sq_tl(&z2, &z2l); + fe_mul121666(&z3, &tmp1l); + fe_sq_tl(&x3, &x3l); + fe_add(&tmp0l, &tmp0, &z3); + fe_mul_ttt(&z3, &x1, &z2); + fe_mul_tll(&z2, &tmp1l, &tmp0l); + } + /* here pos=-1, so r=e, so to_xz (e*P) === if swap then (x3, z3) + * else (x2, z2) + */ + fe_cswap(&x2, &x3, swap); + fe_cswap(&z2, &z3, swap); + + fe_invert(&z2, &z2); + fe_mul_ttt(&x2, &x2, &z2); + fe_tobytes(out, &x2); + + explicit_bzero(&x1, sizeof(x1)); + explicit_bzero(&x2, sizeof(x2)); + explicit_bzero(&z2, sizeof(z2)); + explicit_bzero(&x3, sizeof(x3)); + explicit_bzero(&z3, sizeof(z3)); + explicit_bzero(&x2l, sizeof(x2l)); + explicit_bzero(&z2l, sizeof(z2l)); + explicit_bzero(&x3l, sizeof(x3l)); + explicit_bzero(&e, sizeof(e)); + + return timingsafe_bcmp(out, curve25519_null_point, CURVE25519_KEY_SIZE) != 0; +} diff --git a/src/crypto.h b/src/crypto.h new file mode 100644 index 0000000..0ac23f9 --- /dev/null +++ b/src/crypto.h @@ -0,0 +1,103 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved. 
+ */ + +#ifndef _WG_CRYPTO +#define _WG_CRYPTO + +#include + +enum chacha20poly1305_lengths { + XCHACHA20POLY1305_NONCE_SIZE = 24, + CHACHA20POLY1305_KEY_SIZE = 32, + CHACHA20POLY1305_AUTHTAG_SIZE = 16 +}; + +void +chacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]); + +bool +chacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, const size_t src_len, + const uint8_t *ad, const size_t ad_len, + const uint64_t nonce, + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]); + +void +xchacha20poly1305_encrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]); + +bool +xchacha20poly1305_decrypt(uint8_t *dst, const uint8_t *src, + const size_t src_len, const uint8_t *ad, + const size_t ad_len, + const uint8_t nonce[XCHACHA20POLY1305_NONCE_SIZE], + const uint8_t key[CHACHA20POLY1305_KEY_SIZE]); + + +enum blake2s_lengths { + BLAKE2S_BLOCK_SIZE = 64, + BLAKE2S_HASH_SIZE = 32, + BLAKE2S_KEY_SIZE = 32 +}; + +struct blake2s_state { + uint32_t h[8]; + uint32_t t[2]; + uint32_t f[2]; + uint8_t buf[BLAKE2S_BLOCK_SIZE]; + unsigned int buflen; + unsigned int outlen; +}; + +void blake2s_init(struct blake2s_state *state, const size_t outlen); + +void blake2s_init_key(struct blake2s_state *state, const size_t outlen, + const uint8_t *key, const size_t keylen); + +void blake2s_update(struct blake2s_state *state, const uint8_t *in, size_t inlen); + +void blake2s_final(struct blake2s_state *state, uint8_t *out); + +void blake2s(uint8_t *out, const uint8_t *in, const uint8_t *key, + const size_t outlen, const size_t inlen, const size_t keylen); + +void blake2s_hmac(uint8_t *out, const uint8_t *in, const uint8_t *key, + const size_t outlen, const size_t inlen, const size_t keylen); + +enum curve25519_lengths { + CURVE25519_KEY_SIZE = 32 +}; + +bool curve25519(uint8_t mypublic[static CURVE25519_KEY_SIZE], + const uint8_t secret[static CURVE25519_KEY_SIZE], + const uint8_t basepoint[static CURVE25519_KEY_SIZE]); + +static inline bool +curve25519_generate_public(uint8_t pub[static CURVE25519_KEY_SIZE], + const uint8_t secret[static CURVE25519_KEY_SIZE]) +{ + static const uint8_t basepoint[CURVE25519_KEY_SIZE] = { 9 }; + + return curve25519(pub, secret, basepoint); +} + +static inline void curve25519_clamp_secret(uint8_t secret[static CURVE25519_KEY_SIZE]) +{ + secret[0] &= 248; + secret[31] = (secret[31] & 127) | 64; +} + +static inline void curve25519_generate_secret(uint8_t secret[CURVE25519_KEY_SIZE]) +{ + arc4random_buf(secret, CURVE25519_KEY_SIZE); + curve25519_clamp_secret(secret); +} + +#endif diff --git a/src/if_wg.c b/src/if_wg.c new file mode 100644 index 0000000..1fa7b4c --- /dev/null +++ b/src/if_wg.c @@ -0,0 +1,3451 @@ +/* SPDX-License-Identifier: BSD-2-Clause-FreeBSD + * + * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved. 
+ * Copyright (C) 2019-2021 Matt Dunwoodie + * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate) + * Copyright (c) 2021 Kyle Evans + */ + +/* TODO audit imports */ +#include "opt_inet.h" +#include "opt_inet6.h" + +#include +__FBSDID("$FreeBSD$"); + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "support.h" +#include "wg_noise.h" +#include "wg_cookie.h" +#include "if_wg.h" + +/* It'd be nice to use IF_MAXMTU, but that means more complicated mbuf allocations, + * so instead just do the biggest mbuf we can easily allocate minus the usual maximum + * IPv6 overhead of 80 bytes. If somebody wants bigger frames, we can revisit this. */ +#define MAX_MTU (MJUM16BYTES - 80) + +#define DEFAULT_MTU 1420 + +#define MAX_STAGED_PKT 128 +#define MAX_QUEUED_PKT 1024 +#define MAX_QUEUED_PKT_MASK (MAX_QUEUED_PKT - 1) + +#define MAX_QUEUED_HANDSHAKES 4096 + +#define HASHTABLE_PEER_SIZE (1 << 11) +#define HASHTABLE_INDEX_SIZE (1 << 13) +#define MAX_PEERS_PER_IFACE (1 << 20) + +#define REKEY_TIMEOUT 5 +#define REKEY_TIMEOUT_JITTER 334 /* 1/3 sec, round for arc4random_uniform */ +#define KEEPALIVE_TIMEOUT 10 +#define MAX_TIMER_HANDSHAKES (90 / REKEY_TIMEOUT) +#define NEW_HANDSHAKE_TIMEOUT (REKEY_TIMEOUT + KEEPALIVE_TIMEOUT) +#define UNDERLOAD_TIMEOUT 1 + +#define DPRINTF(sc, ...) if (wireguard_debug) if_printf(sc->sc_ifp, ##__VA_ARGS__) + +/* First byte indicating packet type on the wire */ +#define WG_PKT_INITIATION htole32(1) +#define WG_PKT_RESPONSE htole32(2) +#define WG_PKT_COOKIE htole32(3) +#define WG_PKT_DATA htole32(4) + +#define WG_PKT_WITH_PADDING(n) (((n) + (16-1)) & (~(16-1))) +#define WG_KEY_SIZE 32 + +struct wg_pkt_initiation { + uint32_t t; + uint32_t s_idx; + uint8_t ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN]; + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]; + struct cookie_macs m; +}; + +struct wg_pkt_response { + uint32_t t; + uint32_t s_idx; + uint32_t r_idx; + uint8_t ue[NOISE_PUBLIC_KEY_LEN]; + uint8_t en[0 + NOISE_AUTHTAG_LEN]; + struct cookie_macs m; +}; + +struct wg_pkt_cookie { + uint32_t t; + uint32_t r_idx; + uint8_t nonce[COOKIE_NONCE_SIZE]; + uint8_t ec[COOKIE_ENCRYPTED_SIZE]; +}; + +struct wg_pkt_data { + uint32_t t; + uint32_t r_idx; + uint8_t nonce[sizeof(uint64_t)]; + uint8_t buf[]; +}; + +struct wg_endpoint { + union { + struct sockaddr r_sa; + struct sockaddr_in r_sin; +#ifdef INET6 + struct sockaddr_in6 r_sin6; +#endif + } e_remote; + union { + struct in_addr l_in; +#ifdef INET6 + struct in6_pktinfo l_pktinfo6; +#define l_in6 l_pktinfo6.ipi6_addr +#endif + } e_local; +}; + +struct wg_tag { + struct m_tag t_tag; + struct wg_endpoint t_endpoint; + struct wg_peer *t_peer; + struct mbuf *t_mbuf; + int t_done; + int t_mtu; +}; + +struct wg_index { + LIST_ENTRY(wg_index) i_entry; + SLIST_ENTRY(wg_index) i_unused_entry; + uint32_t i_key; + struct noise_remote *i_value; +}; + +struct wg_timers { + /* t_lock is for blocking wg_timers_event_* when setting t_disabled. 
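+	 * Event paths take it read-only and bail out when t_disabled is
+	 * set; wg_timers_disable() takes it write to flip t_disabled, so
+	 * no event can re-arm a callout once teardown has begun. The
+	 * pattern, roughly:
+	 *
+	 *	rw_rlock(&t->t_lock);
+	 *	if (!t->t_disabled)
+	 *		callout_reset(...);
+	 *	rw_runlock(&t->t_lock);
+	 *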
*/ + struct rwlock t_lock; + + int t_disabled; + int t_need_another_keepalive; + uint16_t t_persistent_keepalive_interval; + struct callout t_new_handshake; + struct callout t_send_keepalive; + struct callout t_retry_handshake; + struct callout t_zero_key_material; + struct callout t_persistent_keepalive; + + struct mtx t_handshake_mtx; + struct timespec t_handshake_last_sent; + struct timespec t_handshake_complete; + volatile int t_handshake_retries; +}; + +struct wg_aip { + struct radix_node r_nodes[2]; + CK_LIST_ENTRY(wg_aip) r_entry; + struct sockaddr_storage r_addr; + struct sockaddr_storage r_mask; + struct wg_peer *r_peer; +}; + +struct wg_queue { + struct mtx q_mtx; + struct mbufq q; +}; + +struct wg_peer { + CK_LIST_ENTRY(wg_peer) p_hash_entry; + CK_LIST_ENTRY(wg_peer) p_entry; + uint64_t p_id; + struct wg_softc *p_sc; + + struct noise_remote p_remote; + struct cookie_maker p_cookie; + struct wg_timers p_timers; + + struct rwlock p_endpoint_lock; + struct wg_endpoint p_endpoint; + + SLIST_HEAD(,wg_index) p_unused_index; + struct wg_index p_index[3]; + + struct wg_queue p_stage_queue; + struct wg_queue p_encap_queue; + struct wg_queue p_decap_queue; + + struct grouptask p_clear_secrets; + struct grouptask p_send_initiation; + struct grouptask p_send_keepalive; + struct grouptask p_send; + struct grouptask p_recv; + + counter_u64_t p_tx_bytes; + counter_u64_t p_rx_bytes; + + CK_LIST_HEAD(, wg_aip) p_aips; + struct mtx p_lock; + struct epoch_context p_ctx; +}; + +enum route_direction { + /* TODO OpenBSD doesn't use IN/OUT, instead passes the address buffer + * directly to route_lookup. */ + IN, + OUT, +}; + +struct wg_aip_table { + size_t t_count; + struct radix_node_head *t_ip; + struct radix_node_head *t_ip6; +}; + +struct wg_allowedip { + uint16_t family; + union { + struct in_addr ip4; + struct in6_addr ip6; + }; + uint8_t cidr; +}; + +struct wg_hashtable { + struct mtx h_mtx; + SIPHASH_KEY h_secret; + CK_LIST_HEAD(, wg_peer) h_peers_list; + CK_LIST_HEAD(, wg_peer) *h_peers; + u_long h_peers_mask; + size_t h_num_peers; +}; + +struct wg_socket { + struct mtx so_mtx; + struct socket *so_so4; + struct socket *so_so6; + uint32_t so_user_cookie; + in_port_t so_port; +}; + +struct wg_softc { + LIST_ENTRY(wg_softc) sc_entry; + struct ifnet *sc_ifp; + int sc_flags; + + struct ucred *sc_ucred; + struct wg_socket sc_socket; + struct wg_hashtable sc_hashtable; + struct wg_aip_table sc_aips; + + struct mbufq sc_handshake_queue; + struct grouptask sc_handshake; + + struct noise_local sc_local; + struct cookie_checker sc_cookie; + + struct buf_ring *sc_encap_ring; + struct buf_ring *sc_decap_ring; + + struct grouptask *sc_encrypt; + struct grouptask *sc_decrypt; + + struct rwlock sc_index_lock; + LIST_HEAD(,wg_index) *sc_index; + u_long sc_index_mask; + + struct sx sc_lock; + volatile u_int sc_peer_count; +}; + +#define WGF_DYING 0x0001 + +/* TODO the following defines are freebsd specific, we should see what is + * necessary and cleanup from there (i suspect a lot can be junked). 
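+ * ENOKEY, for instance, is just aliased to ENOTCAPABLE where the tree
+ * does not define it, and GROUPTASK_DRAIN() below is a thin wrapper
+ * around gtaskqueue_drain(), used at peer teardown to wait for any
+ * in-flight task run to finish.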
*/ + +#ifndef ENOKEY +#define ENOKEY ENOTCAPABLE +#endif + +#if __FreeBSD_version > 1300000 +typedef void timeout_t (void *); +#endif + +#define GROUPTASK_DRAIN(gtask) \ + gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task) + +#define MTAG_WIREGUARD 0xBEAD +#define M_ENQUEUED M_PROTO1 + +static int clone_count; +static uma_zone_t ratelimit_zone; +static int wireguard_debug; +static volatile unsigned long peer_counter = 0; +static const char wgname[] = "wg"; +static unsigned wg_osd_jail_slot; + +static struct sx wg_sx; +SX_SYSINIT(wg_sx, &wg_sx, "wg_sx"); + +static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list); + +SYSCTL_NODE(_net, OID_AUTO, wg, CTLFLAG_RW, 0, "WireGuard"); +SYSCTL_INT(_net_wg, OID_AUTO, debug, CTLFLAG_RWTUN, &wireguard_debug, 0, + "enable debug logging"); + +TASKQGROUP_DECLARE(if_io_tqg); + +MALLOC_DEFINE(M_WG, "WG", "wireguard"); +VNET_DEFINE_STATIC(struct if_clone *, wg_cloner); + + +#define V_wg_cloner VNET(wg_cloner) +#define WG_CAPS IFCAP_LINKSTATE +#define ph_family PH_loc.eight[5] + +struct wg_timespec64 { + uint64_t tv_sec; + uint64_t tv_nsec; +}; + +struct wg_peer_export { + struct sockaddr_storage endpoint; + struct timespec last_handshake; + uint8_t public_key[WG_KEY_SIZE]; + uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN]; + size_t endpoint_sz; + struct wg_allowedip *aip; + uint64_t rx_bytes; + uint64_t tx_bytes; + int aip_count; + uint16_t persistent_keepalive; +}; + +static struct wg_tag *wg_tag_get(struct mbuf *); +static struct wg_endpoint *wg_mbuf_endpoint_get(struct mbuf *); +static int wg_socket_init(struct wg_softc *, in_port_t); +static int wg_socket_bind(struct socket *, struct socket *, in_port_t *); +static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *); +static void wg_socket_uninit(struct wg_softc *); +static void wg_socket_set_cookie(struct wg_softc *, uint32_t); +static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *); +static void wg_timers_event_data_sent(struct wg_timers *); +static void wg_timers_event_data_received(struct wg_timers *); +static void wg_timers_event_any_authenticated_packet_sent(struct wg_timers *); +static void wg_timers_event_any_authenticated_packet_received(struct wg_timers *); +static void wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *); +static void wg_timers_event_handshake_initiated(struct wg_timers *); +static void wg_timers_event_handshake_responded(struct wg_timers *); +static void wg_timers_event_handshake_complete(struct wg_timers *); +static void wg_timers_event_session_derived(struct wg_timers *); +static void wg_timers_event_want_initiation(struct wg_timers *); +static void wg_timers_event_reset_handshake_last_sent(struct wg_timers *); +static void wg_timers_run_send_initiation(struct wg_timers *, int); +static void wg_timers_run_retry_handshake(struct wg_timers *); +static void wg_timers_run_send_keepalive(struct wg_timers *); +static void wg_timers_run_new_handshake(struct wg_timers *); +static void wg_timers_run_zero_key_material(struct wg_timers *); +static void wg_timers_run_persistent_keepalive(struct wg_timers *); +static void wg_timers_init(struct wg_timers *); +static void wg_timers_enable(struct wg_timers *); +static void wg_timers_disable(struct wg_timers *); +static void wg_timers_set_persistent_keepalive(struct wg_timers *, uint16_t); +static void wg_timers_get_last_handshake(struct wg_timers *, struct timespec *); +static int wg_timers_expired_handshake_last_sent(struct wg_timers *); +static int 
wg_timers_check_handshake_last_sent(struct wg_timers *); +static void wg_queue_init(struct wg_queue *, const char *); +static void wg_queue_deinit(struct wg_queue *); +static void wg_queue_purge(struct wg_queue *); +static struct mbuf *wg_queue_dequeue(struct wg_queue *, struct wg_tag **); +static int wg_queue_len(struct wg_queue *); +static int wg_queue_in(struct wg_peer *, struct mbuf *); +static void wg_queue_out(struct wg_peer *); +static void wg_queue_stage(struct wg_peer *, struct mbuf *); +static int wg_aip_init(struct wg_aip_table *); +static void wg_aip_destroy(struct wg_aip_table *); +static void wg_aip_populate_aip4(struct wg_aip *, const struct in_addr *, uint8_t); +static void wg_aip_populate_aip6(struct wg_aip *, const struct in6_addr *, uint8_t); +static int wg_aip_add(struct wg_aip_table *, struct wg_peer *, const struct wg_allowedip *); +static int wg_peer_remove(struct radix_node *, void *); +static void wg_peer_remove_all(struct wg_softc *); +static int wg_aip_delete(struct wg_aip_table *, struct wg_peer *); +static struct wg_peer *wg_aip_lookup(struct wg_aip_table *, struct mbuf *, enum route_direction); +static void wg_hashtable_init(struct wg_hashtable *); +static void wg_hashtable_destroy(struct wg_hashtable *); +static void wg_hashtable_peer_insert(struct wg_hashtable *, struct wg_peer *); +static struct wg_peer *wg_peer_lookup(struct wg_softc *, const uint8_t [32]); +static void wg_hashtable_peer_remove(struct wg_hashtable *, struct wg_peer *); +static int wg_cookie_validate_packet(struct cookie_checker *, struct mbuf *, int); +static struct wg_peer *wg_peer_alloc(struct wg_softc *); +static void wg_peer_free_deferred(epoch_context_t); +static void wg_peer_destroy(struct wg_peer *); +static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t); +static void wg_send_initiation(struct wg_peer *); +static void wg_send_response(struct wg_peer *); +static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct mbuf *); +static void wg_peer_set_endpoint_from_tag(struct wg_peer *, struct wg_tag *); +static void wg_peer_clear_src(struct wg_peer *); +static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *); +static void wg_deliver_out(struct wg_peer *); +static void wg_deliver_in(struct wg_peer *); +static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t); +static void wg_send_keepalive(struct wg_peer *); +static void wg_handshake(struct wg_softc *, struct mbuf *); +static void wg_encap(struct wg_softc *, struct mbuf *); +static void wg_decap(struct wg_softc *, struct mbuf *); +static void wg_softc_handshake_receive(struct wg_softc *); +static void wg_softc_decrypt(struct wg_softc *); +static void wg_softc_encrypt(struct wg_softc *); +static struct noise_remote *wg_remote_get(struct wg_softc *, uint8_t [NOISE_PUBLIC_KEY_LEN]); +static uint32_t wg_index_set(struct wg_softc *, struct noise_remote *); +static struct noise_remote *wg_index_get(struct wg_softc *, uint32_t); +static void wg_index_drop(struct wg_softc *, uint32_t); +static int wg_update_endpoint_addrs(struct wg_endpoint *, const struct sockaddr *, struct ifnet *); +static void wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *); +static void wg_encrypt_dispatch(struct wg_softc *); +static void wg_decrypt_dispatch(struct wg_softc *); +static void crypto_taskq_setup(struct wg_softc *); +static void crypto_taskq_destroy(struct wg_softc *); +static int wg_clone_create(struct if_clone *, int, caddr_t); +static void 
wg_qflush(struct ifnet *); +static int wg_transmit(struct ifnet *, struct mbuf *); +static int wg_output(struct ifnet *, struct mbuf *, const struct sockaddr *, struct route *); +static void wg_clone_destroy(struct ifnet *); +static int wg_peer_to_export(struct wg_peer *, struct wg_peer_export *); +static bool wgc_privileged(struct wg_softc *); +static int wgc_get(struct wg_softc *, struct wg_data_io *); +static int wgc_set(struct wg_softc *, struct wg_data_io *); +static int wg_up(struct wg_softc *); +static void wg_down(struct wg_softc *); +static void wg_reassign(struct ifnet *, struct vnet *, char *unused); +static void wg_init(void *); +static int wg_ioctl(struct ifnet *, u_long, caddr_t); +static void vnet_wg_init(const void *); +static void vnet_wg_uninit(const void *); +static void wg_module_init(void); +static void wg_module_deinit(void); + +/* TODO Peer */ +static struct wg_peer * +wg_peer_alloc(struct wg_softc *sc) +{ + struct wg_peer *peer; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + peer = malloc(sizeof(*peer), M_WG, M_WAITOK|M_ZERO); + peer->p_sc = sc; + peer->p_id = peer_counter++; + CK_LIST_INIT(&peer->p_aips); + + rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint"); + wg_queue_init(&peer->p_stage_queue, "stageq"); + wg_queue_init(&peer->p_encap_queue, "txq"); + wg_queue_init(&peer->p_decap_queue, "rxq"); + + GROUPTASK_INIT(&peer->p_send_initiation, 0, (gtask_fn_t *)wg_send_initiation, peer); + taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_initiation, peer, NULL, NULL, "wg initiation"); + GROUPTASK_INIT(&peer->p_send_keepalive, 0, (gtask_fn_t *)wg_send_keepalive, peer); + taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send_keepalive, peer, NULL, NULL, "wg keepalive"); + GROUPTASK_INIT(&peer->p_clear_secrets, 0, (gtask_fn_t *)noise_remote_clear, &peer->p_remote); + taskqgroup_attach(qgroup_if_io_tqg, &peer->p_clear_secrets, + &peer->p_remote, NULL, NULL, "wg clear secrets"); + + GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer); + taskqgroup_attach(qgroup_if_io_tqg, &peer->p_send, peer, NULL, NULL, "wg send"); + GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer); + taskqgroup_attach(qgroup_if_io_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv"); + + wg_timers_init(&peer->p_timers); + + peer->p_tx_bytes = counter_u64_alloc(M_WAITOK); + peer->p_rx_bytes = counter_u64_alloc(M_WAITOK); + + SLIST_INIT(&peer->p_unused_index); + SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[0], + i_unused_entry); + SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[1], + i_unused_entry); + SLIST_INSERT_HEAD(&peer->p_unused_index, &peer->p_index[2], + i_unused_entry); + + return (peer); +} + +#define WG_HASHTABLE_PEER_FOREACH(peer, i, ht) \ + for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \ + LIST_FOREACH(peer, &(ht)->h_peers[i], p_hash_entry) +#define WG_HASHTABLE_PEER_FOREACH_SAFE(peer, i, ht, tpeer) \ + for (i = 0; i < HASHTABLE_PEER_SIZE; i++) \ + CK_LIST_FOREACH_SAFE(peer, &(ht)->h_peers[i], p_hash_entry, tpeer) +static void +wg_hashtable_init(struct wg_hashtable *ht) +{ + mtx_init(&ht->h_mtx, "hash lock", NULL, MTX_DEF); + arc4random_buf(&ht->h_secret, sizeof(ht->h_secret)); + ht->h_num_peers = 0; + ht->h_peers = hashinit(HASHTABLE_PEER_SIZE, M_DEVBUF, + &ht->h_peers_mask); +} + +static void +wg_hashtable_destroy(struct wg_hashtable *ht) +{ + MPASS(ht->h_num_peers == 0); + mtx_destroy(&ht->h_mtx); + hashdestroy(ht->h_peers, M_DEVBUF, ht->h_peers_mask); +} + +static void +wg_hashtable_peer_insert(struct wg_hashtable *ht, struct wg_peer *peer) +{ + 
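+	/* Bucket choice is a keyed SipHash of the peer's public key;
+	 * since h_secret is filled from arc4random_buf() at init time,
+	 * placement is not predictable (or collidable) from outside. */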
uint64_t key; + + key = siphash24(&ht->h_secret, peer->p_remote.r_public, + sizeof(peer->p_remote.r_public)); + + mtx_lock(&ht->h_mtx); + ht->h_num_peers++; + CK_LIST_INSERT_HEAD(&ht->h_peers[key & ht->h_peers_mask], peer, p_hash_entry); + CK_LIST_INSERT_HEAD(&ht->h_peers_list, peer, p_entry); + mtx_unlock(&ht->h_mtx); +} + +static struct wg_peer * +wg_peer_lookup(struct wg_softc *sc, + const uint8_t pubkey[WG_KEY_SIZE]) +{ + struct wg_hashtable *ht = &sc->sc_hashtable; + uint64_t key; + struct wg_peer *i = NULL; + + key = siphash24(&ht->h_secret, pubkey, WG_KEY_SIZE); + + mtx_lock(&ht->h_mtx); + CK_LIST_FOREACH(i, &ht->h_peers[key & ht->h_peers_mask], p_hash_entry) { + if (timingsafe_bcmp(i->p_remote.r_public, pubkey, + WG_KEY_SIZE) == 0) + break; + } + mtx_unlock(&ht->h_mtx); + + return i; +} + +static void +wg_hashtable_peer_remove(struct wg_hashtable *ht, struct wg_peer *peer) +{ + mtx_lock(&ht->h_mtx); + ht->h_num_peers--; + CK_LIST_REMOVE(peer, p_hash_entry); + CK_LIST_REMOVE(peer, p_entry); + mtx_unlock(&ht->h_mtx); +} + +static void +wg_peer_free_deferred(epoch_context_t ctx) +{ + struct wg_peer *peer = __containerof(ctx, struct wg_peer, p_ctx); + counter_u64_free(peer->p_tx_bytes); + counter_u64_free(peer->p_rx_bytes); + rw_destroy(&peer->p_timers.t_lock); + rw_destroy(&peer->p_endpoint_lock); + free(peer, M_WG); +} + +static void +wg_peer_destroy(struct wg_peer *peer) +{ + /* Callers should already have called: + * wg_hashtable_peer_remove(&sc->sc_hashtable, peer); + */ + wg_aip_delete(&peer->p_sc->sc_aips, peer); + MPASS(CK_LIST_EMPTY(&peer->p_aips)); + + /* We disable all timers, so we can't call the following tasks. */ + wg_timers_disable(&peer->p_timers); + + /* Ensure the tasks have finished running */ + GROUPTASK_DRAIN(&peer->p_clear_secrets); + GROUPTASK_DRAIN(&peer->p_send_initiation); + GROUPTASK_DRAIN(&peer->p_send_keepalive); + GROUPTASK_DRAIN(&peer->p_recv); + GROUPTASK_DRAIN(&peer->p_send); + + taskqgroup_detach(qgroup_if_io_tqg, &peer->p_clear_secrets); + taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_initiation); + taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send_keepalive); + taskqgroup_detach(qgroup_if_io_tqg, &peer->p_recv); + taskqgroup_detach(qgroup_if_io_tqg, &peer->p_send); + + wg_queue_deinit(&peer->p_decap_queue); + wg_queue_deinit(&peer->p_encap_queue); + wg_queue_deinit(&peer->p_stage_queue); + + /* Final cleanup */ + --peer->p_sc->sc_peer_count; + noise_remote_clear(&peer->p_remote); + DPRINTF(peer->p_sc, "Peer %llu destroyed\n", (unsigned long long)peer->p_id); + NET_EPOCH_CALL(wg_peer_free_deferred, &peer->p_ctx); +} + +static void +wg_peer_set_endpoint_from_tag(struct wg_peer *peer, struct wg_tag *t) +{ + struct wg_endpoint *e = &t->t_endpoint; + + MPASS(e->e_remote.r_sa.sa_family != 0); + if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0) + return; + + peer->p_endpoint = *e; +} + +static void +wg_peer_clear_src(struct wg_peer *peer) +{ + rw_rlock(&peer->p_endpoint_lock); + bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local)); + rw_runlock(&peer->p_endpoint_lock); +} + +static void +wg_peer_get_endpoint(struct wg_peer *p, struct wg_endpoint *e) +{ + memcpy(e, &p->p_endpoint, sizeof(*e)); +} + +/* Allowed IP */ +static int +wg_aip_init(struct wg_aip_table *tbl) +{ + int rc; + + tbl->t_count = 0; + rc = rn_inithead((void **)&tbl->t_ip, + offsetof(struct sockaddr_in, sin_addr) * NBBY); + + if (rc == 0) + return (ENOMEM); + RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip); +#ifdef INET6 + rc = rn_inithead((void **)&tbl->t_ip6, + offsetof(struct 
sockaddr_in6, sin6_addr) * NBBY); + if (rc == 0) { + free(tbl->t_ip, M_RTABLE); + return (ENOMEM); + } + RADIX_NODE_HEAD_LOCK_INIT(tbl->t_ip6); +#endif + return (0); +} + +static void +wg_aip_destroy(struct wg_aip_table *tbl) +{ + RADIX_NODE_HEAD_DESTROY(tbl->t_ip); + free(tbl->t_ip, M_RTABLE); +#ifdef INET6 + RADIX_NODE_HEAD_DESTROY(tbl->t_ip6); + free(tbl->t_ip6, M_RTABLE); +#endif +} + +static void +wg_aip_populate_aip4(struct wg_aip *aip, const struct in_addr *addr, + uint8_t mask) +{ + struct sockaddr_in *raddr, *rmask; + uint8_t *p; + unsigned int i; + + raddr = (struct sockaddr_in *)&aip->r_addr; + rmask = (struct sockaddr_in *)&aip->r_mask; + + raddr->sin_len = sizeof(*raddr); + raddr->sin_family = AF_INET; + raddr->sin_addr = *addr; + + rmask->sin_len = sizeof(*rmask); + p = (uint8_t *)&rmask->sin_addr.s_addr; + for (i = 0; i < mask / NBBY; i++) + p[i] = 0xff; + if ((mask % NBBY) != 0) + p[i] = (0xff00 >> (mask % NBBY)) & 0xff; + raddr->sin_addr.s_addr &= rmask->sin_addr.s_addr; +} + +static void +wg_aip_populate_aip6(struct wg_aip *aip, const struct in6_addr *addr, + uint8_t mask) +{ + struct sockaddr_in6 *raddr, *rmask; + + raddr = (struct sockaddr_in6 *)&aip->r_addr; + rmask = (struct sockaddr_in6 *)&aip->r_mask; + + raddr->sin6_len = sizeof(*raddr); + raddr->sin6_family = AF_INET6; + raddr->sin6_addr = *addr; + + rmask->sin6_len = sizeof(*rmask); + in6_prefixlen2mask(&rmask->sin6_addr, mask); + for (int i = 0; i < 4; ++i) + raddr->sin6_addr.__u6_addr.__u6_addr32[i] &= rmask->sin6_addr.__u6_addr.__u6_addr32[i]; +} + +/* wg_aip_take assumes that the caller guarantees the allowed-ip exists. */ +static void +wg_aip_take(struct radix_node_head *root, struct wg_peer *peer, + struct wg_aip *route) +{ + struct radix_node *node; + struct wg_peer *ppeer; + + RADIX_NODE_HEAD_LOCK_ASSERT(root); + + node = root->rnh_lookup(&route->r_addr, &route->r_mask, + &root->rh); + MPASS(node != NULL); + + route = (struct wg_aip *)node; + ppeer = route->r_peer; + if (ppeer != peer) { + route->r_peer = peer; + + CK_LIST_REMOVE(route, r_entry); + CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry); + } +} + +static int +wg_aip_add(struct wg_aip_table *tbl, struct wg_peer *peer, + const struct wg_allowedip *aip) +{ + struct radix_node *node; + struct radix_node_head *root; + struct wg_aip *route; + sa_family_t family; + bool needfree = false; + + family = aip->family; + if (family != AF_INET && family != AF_INET6) { + return (EINVAL); + } + + route = malloc(sizeof(*route), M_WG, M_WAITOK|M_ZERO); + switch (family) { + case AF_INET: + root = tbl->t_ip; + + wg_aip_populate_aip4(route, &aip->ip4, aip->cidr); + break; + case AF_INET6: + root = tbl->t_ip6; + + wg_aip_populate_aip6(route, &aip->ip6, aip->cidr); + break; + } + + route->r_peer = peer; + + RADIX_NODE_HEAD_LOCK(root); + node = root->rnh_addaddr(&route->r_addr, &route->r_mask, &root->rh, + route->r_nodes); + if (node == route->r_nodes) { + tbl->t_count++; + CK_LIST_INSERT_HEAD(&peer->p_aips, route, r_entry); + } else { + needfree = true; + wg_aip_take(root, peer, route); + } + RADIX_NODE_HEAD_UNLOCK(root); + if (needfree) { + free(route, M_WG); + } + return (0); +} + +static struct wg_peer * +wg_aip_lookup(struct wg_aip_table *tbl, struct mbuf *m, + enum route_direction dir) +{ + RADIX_NODE_HEAD_RLOCK_TRACKER; + struct ip *iphdr; + struct ip6_hdr *ip6hdr; + struct radix_node_head *root; + struct radix_node *node; + struct wg_peer *peer = NULL; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; + void *addr; + int version; + + NET_EPOCH_ASSERT(); 
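+	/* Both struct ip and struct ip6_hdr keep the IP version in the
+	 * top nibble of the first byte, so reading ip_v before we know
+	 * which header is actually present is safe. */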
+ iphdr = mtod(m, struct ip *); + version = iphdr->ip_v; + + if (__predict_false(dir != IN && dir != OUT)) + return NULL; + + if (version == 4) { + root = tbl->t_ip; + memset(&sin, 0, sizeof(sin)); + sin.sin_len = sizeof(struct sockaddr_in); + if (dir == IN) + sin.sin_addr = iphdr->ip_src; + else + sin.sin_addr = iphdr->ip_dst; + addr = &sin; + } else if (version == 6) { + ip6hdr = mtod(m, struct ip6_hdr *); + memset(&sin6, 0, sizeof(sin6)); + sin6.sin6_len = sizeof(struct sockaddr_in6); + + root = tbl->t_ip6; + if (dir == IN) + addr = &ip6hdr->ip6_src; + else + addr = &ip6hdr->ip6_dst; + memcpy(&sin6.sin6_addr, addr, sizeof(sin6.sin6_addr)); + addr = &sin6; + } else { + return (NULL); + } + RADIX_NODE_HEAD_RLOCK(root); + if ((node = root->rnh_matchaddr(addr, &root->rh)) != NULL) { + peer = ((struct wg_aip *) node)->r_peer; + } + RADIX_NODE_HEAD_RUNLOCK(root); + return (peer); +} + +struct peer_del_arg { + struct radix_node_head * pda_head; + struct wg_peer *pda_peer; + struct wg_aip_table *pda_tbl; +}; + +static int +wg_peer_remove(struct radix_node *rn, void *arg) +{ + struct peer_del_arg *pda = arg; + struct wg_peer *peer = pda->pda_peer; + struct radix_node_head * rnh = pda->pda_head; + struct wg_aip_table *tbl = pda->pda_tbl; + struct wg_aip *route = (struct wg_aip *)rn; + struct radix_node *x; + + if (route->r_peer != peer) + return (0); + x = (struct radix_node *)rnh->rnh_deladdr(&route->r_addr, + &route->r_mask, &rnh->rh); + if (x != NULL) { + tbl->t_count--; + CK_LIST_REMOVE(route, r_entry); + free(route, M_WG); + } + return (0); +} + +static void +wg_peer_remove_all(struct wg_softc *sc) +{ + struct wg_peer *peer, *tpeer; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + CK_LIST_FOREACH_SAFE(peer, &sc->sc_hashtable.h_peers_list, + p_entry, tpeer) { + wg_hashtable_peer_remove(&sc->sc_hashtable, peer); + wg_peer_destroy(peer); + } +} + +static int +wg_aip_delete(struct wg_aip_table *tbl, struct wg_peer *peer) +{ + struct peer_del_arg pda; + + pda.pda_peer = peer; + pda.pda_tbl = tbl; + RADIX_NODE_HEAD_LOCK(tbl->t_ip); + pda.pda_head = tbl->t_ip; + rn_walktree(&tbl->t_ip->rh, wg_peer_remove, &pda); + RADIX_NODE_HEAD_UNLOCK(tbl->t_ip); + + RADIX_NODE_HEAD_LOCK(tbl->t_ip6); + pda.pda_head = tbl->t_ip6; + rn_walktree(&tbl->t_ip6->rh, wg_peer_remove, &pda); + RADIX_NODE_HEAD_UNLOCK(tbl->t_ip6); + return (0); +} + +static int +wg_socket_init(struct wg_softc *sc, in_port_t port) +{ + struct thread *td; + struct ucred *cred; + struct socket *so4, *so6; + int rc; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + so4 = so6 = NULL; + td = curthread; + if ((cred = sc->sc_ucred) == NULL) + return (EBUSY); + + /* + * For socket creation, we use the creds of the thread that created the + * tunnel rather than the current thread to maintain the semantics that + * WireGuard has on Linux with network namespaces -- that the sockets + * are created in their home vnet so that they can be configured and + * functionally attached to a foreign vnet as the jail's only interface + * to the network. + */ + rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, td); + if (rc != 0) + goto out; + + rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc); + /* + * udp_set_kernel_tunneling can only fail if there is already a tunneling function set. + * This should never happen with a new socket. 
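+	 * Hence the assertion below rather than an error-handling path.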
+ */ + MPASS(rc == 0); + + rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, td); + if (rc != 0) + goto out; + rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc); + MPASS(rc == 0); + + so4->so_user_cookie = so6->so_user_cookie = sc->sc_socket.so_user_cookie; + + rc = wg_socket_bind(so4, so6, &port); + if (rc == 0) { + sc->sc_socket.so_port = port; + wg_socket_set(sc, so4, so6); + } +out: + if (rc != 0) { + if (so4 != NULL) + soclose(so4); + if (so6 != NULL) + soclose(so6); + } + return (rc); +} + +static void wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie) +{ + struct wg_socket *so = &sc->sc_socket; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + so->so_user_cookie = user_cookie; + if (so->so_so4) + so->so_so4->so_user_cookie = user_cookie; + if (so->so_so6) + so->so_so6->so_user_cookie = user_cookie; +} + +static void +wg_socket_uninit(struct wg_softc *sc) +{ + wg_socket_set(sc, NULL, NULL); +} + +static void +wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6) +{ + struct wg_socket *so = &sc->sc_socket; + struct socket *so4, *so6; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + so4 = atomic_load_ptr(&so->so_so4); + so6 = atomic_load_ptr(&so->so_so6); + atomic_store_ptr(&so->so_so4, new_so4); + atomic_store_ptr(&so->so_so6, new_so6); + + if (!so4 && !so6) + return; + NET_EPOCH_WAIT(); + if (so4) + soclose(so4); + if (so6) + soclose(so6); +} + +union wg_sockaddr { + struct sockaddr sa; + struct sockaddr_in in4; + struct sockaddr_in6 in6; +}; + +static int +wg_socket_bind(struct socket *so4, struct socket *so6, in_port_t *requested_port) +{ + int rc; + struct thread *td; + union wg_sockaddr laddr; + struct sockaddr_in *sin; + struct sockaddr_in6 *sin6; + in_port_t port = *requested_port; + + td = curthread; + bzero(&laddr, sizeof(laddr)); + sin = &laddr.in4; + sin->sin_len = sizeof(laddr.in4); + sin->sin_family = AF_INET; + sin->sin_port = htons(port); + sin->sin_addr = (struct in_addr) { 0 }; + + if ((rc = sobind(so4, &laddr.sa, td)) != 0) + return (rc); + + if (port == 0) { + rc = sogetsockaddr(so4, (struct sockaddr **)&sin); + if (rc != 0) + return (rc); + port = ntohs(sin->sin_port); + free(sin, M_SONAME); + } + + sin6 = &laddr.in6; + sin6->sin6_len = sizeof(laddr.in6); + sin6->sin6_family = AF_INET6; + sin6->sin6_port = htons(port); + sin6->sin6_addr = (struct in6_addr) { .s6_addr = { 0 } }; + rc = sobind(so6, &laddr.sa, td); + if (rc != 0) + return (rc); + *requested_port = port; + return (0); +} + +static int +wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m) +{ + struct epoch_tracker et; + struct sockaddr *sa; + struct wg_socket *so = &sc->sc_socket; + struct socket *so4, *so6; + struct mbuf *control = NULL; + int ret = 0; + size_t len = m->m_pkthdr.len; + + /* Get local control address before locking */ + if (e->e_remote.r_sa.sa_family == AF_INET) { + if (e->e_local.l_in.s_addr != INADDR_ANY) + control = sbcreatecontrol((caddr_t)&e->e_local.l_in, + sizeof(struct in_addr), IP_SENDSRCADDR, + IPPROTO_IP); +#ifdef INET6 + } else if (e->e_remote.r_sa.sa_family == AF_INET6) { + if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6)) + control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6, + sizeof(struct in6_pktinfo), IPV6_PKTINFO, + IPPROTO_IPV6); +#endif + } else { + m_freem(m); + return (EAFNOSUPPORT); + } + + /* Get remote address */ + sa = &e->e_remote.r_sa; + + NET_EPOCH_ENTER(et); + so4 = atomic_load_ptr(&so->so_so4); + so6 = atomic_load_ptr(&so->so_so6); + if (e->e_remote.r_sa.sa_family == AF_INET && so4 != 
NULL) + ret = sosend(so4, sa, NULL, m, control, 0, curthread); + else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL) + ret = sosend(so6, sa, NULL, m, control, 0, curthread); + else { + ret = ENOTCONN; + m_freem(control); + m_freem(m); + } + NET_EPOCH_EXIT(et); + if (ret == 0) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len); + } + return (ret); +} + +static void +wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, + size_t len) +{ + struct mbuf *m; + int ret = 0; + +retry: + m = m_gethdr(M_WAITOK, MT_DATA); + m->m_len = 0; + m_copyback(m, 0, len, buf); + + if (ret == 0) { + ret = wg_send(sc, e, m); + /* Retry if we couldn't bind to e->e_local */ + if (ret == EADDRNOTAVAIL) { + bzero(&e->e_local, sizeof(e->e_local)); + goto retry; + } + } else { + ret = wg_send(sc, e, m); + } + if (ret) + DPRINTF(sc, "Unable to send packet: %d\n", ret); +} + +/* TODO Tag */ +static struct wg_tag * +wg_tag_get(struct mbuf *m) +{ + struct m_tag *tag; + + tag = m_tag_find(m, MTAG_WIREGUARD, NULL); + if (tag == NULL) { + tag = m_tag_get(MTAG_WIREGUARD, sizeof(struct wg_tag), M_NOWAIT|M_ZERO); + m_tag_prepend(m, tag); + MPASS(!SLIST_EMPTY(&m->m_pkthdr.tags)); + MPASS(m_tag_locate(m, MTAG_ABI_COMPAT, MTAG_WIREGUARD, NULL) == tag); + } + return (struct wg_tag *)tag; +} + +static struct wg_endpoint * +wg_mbuf_endpoint_get(struct mbuf *m) +{ + struct wg_tag *hdr; + + if ((hdr = wg_tag_get(m)) == NULL) + return (NULL); + + return (&hdr->t_endpoint); +} + +/* Timers */ +static void +wg_timers_init(struct wg_timers *t) +{ + bzero(t, sizeof(*t)); + + t->t_disabled = 1; + rw_init(&t->t_lock, "wg peer timers"); + callout_init(&t->t_retry_handshake, true); + callout_init(&t->t_send_keepalive, true); + callout_init(&t->t_new_handshake, true); + callout_init(&t->t_zero_key_material, true); + callout_init(&t->t_persistent_keepalive, true); +} + +static void +wg_timers_enable(struct wg_timers *t) +{ + rw_wlock(&t->t_lock); + t->t_disabled = 0; + rw_wunlock(&t->t_lock); + wg_timers_run_persistent_keepalive(t); +} + +static void +wg_timers_disable(struct wg_timers *t) +{ + rw_wlock(&t->t_lock); + t->t_disabled = 1; + t->t_need_another_keepalive = 0; + rw_wunlock(&t->t_lock); + + callout_stop(&t->t_retry_handshake); + callout_stop(&t->t_send_keepalive); + callout_stop(&t->t_new_handshake); + callout_stop(&t->t_zero_key_material); + callout_stop(&t->t_persistent_keepalive); +} + +static void +wg_timers_set_persistent_keepalive(struct wg_timers *t, uint16_t interval) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled) { + t->t_persistent_keepalive_interval = interval; + wg_timers_run_persistent_keepalive(t); + } + rw_runlock(&t->t_lock); +} + +static void +wg_timers_get_last_handshake(struct wg_timers *t, struct timespec *time) +{ + rw_rlock(&t->t_lock); + time->tv_sec = t->t_handshake_complete.tv_sec; + time->tv_nsec = t->t_handshake_complete.tv_nsec; + rw_runlock(&t->t_lock); +} + +static int +wg_timers_expired_handshake_last_sent(struct wg_timers *t) +{ + struct timespec uptime; + struct timespec expire = { .tv_sec = REKEY_TIMEOUT, .tv_nsec = 0 }; + + getnanouptime(&uptime); + timespecadd(&t->t_handshake_last_sent, &expire, &expire); + return timespeccmp(&uptime, &expire, >) ? 
ETIMEDOUT : 0; +} + +static int +wg_timers_check_handshake_last_sent(struct wg_timers *t) +{ + int ret; + + rw_wlock(&t->t_lock); + if ((ret = wg_timers_expired_handshake_last_sent(t)) == ETIMEDOUT) + getnanouptime(&t->t_handshake_last_sent); + rw_wunlock(&t->t_lock); + return (ret); +} + +/* Should be called after an authenticated data packet is sent. */ +static void +wg_timers_event_data_sent(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled && !callout_pending(&t->t_new_handshake)) + callout_reset(&t->t_new_handshake, MSEC_2_TICKS( + NEW_HANDSHAKE_TIMEOUT * 1000 + + arc4random_uniform(REKEY_TIMEOUT_JITTER)), + (timeout_t *)wg_timers_run_new_handshake, t); + rw_runlock(&t->t_lock); +} + +/* Should be called after an authenticated data packet is received. */ +static void +wg_timers_event_data_received(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled) { + if (!callout_pending(&t->t_send_keepalive)) { + callout_reset(&t->t_send_keepalive, + MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000), + (timeout_t *)wg_timers_run_send_keepalive, t); + } else { + t->t_need_another_keepalive = 1; + } + } + rw_runlock(&t->t_lock); +} + +/* + * Should be called after any type of authenticated packet is sent, whether + * keepalive, data, or handshake. + */ +static void +wg_timers_event_any_authenticated_packet_sent(struct wg_timers *t) +{ + callout_stop(&t->t_send_keepalive); +} + +/* + * Should be called after any type of authenticated packet is received, whether + * keepalive, data, or handshake. + */ +static void +wg_timers_event_any_authenticated_packet_received(struct wg_timers *t) +{ + callout_stop(&t->t_new_handshake); +} + +/* + * Should be called before a packet with authentication, whether + * keepalive, data, or handshake is sent, or after one is received. + */ +static void +wg_timers_event_any_authenticated_packet_traversal(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled && t->t_persistent_keepalive_interval > 0) + callout_reset(&t->t_persistent_keepalive, + MSEC_2_TICKS(t->t_persistent_keepalive_interval * 1000), + (timeout_t *)wg_timers_run_persistent_keepalive, t); + rw_runlock(&t->t_lock); +} + +/* Should be called after a handshake initiation message is sent. */ +static void +wg_timers_event_handshake_initiated(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled) + callout_reset(&t->t_retry_handshake, MSEC_2_TICKS( + REKEY_TIMEOUT * 1000 + + arc4random_uniform(REKEY_TIMEOUT_JITTER)), + (timeout_t *)wg_timers_run_retry_handshake, t); + rw_runlock(&t->t_lock); +} + +static void +wg_timers_event_handshake_responded(struct wg_timers *t) +{ + rw_wlock(&t->t_lock); + getnanouptime(&t->t_handshake_last_sent); + rw_wunlock(&t->t_lock); +} + +/* + * Should be called after a handshake response message is received and processed + * or when getting key confirmation via the first data message. + */ +static void +wg_timers_event_handshake_complete(struct wg_timers *t) +{ + rw_wlock(&t->t_lock); + if (!t->t_disabled) { + callout_stop(&t->t_retry_handshake); + t->t_handshake_retries = 0; + getnanotime(&t->t_handshake_complete); + wg_timers_run_send_keepalive(t); + } + rw_wunlock(&t->t_lock); +} + +/* + * Should be called after an ephemeral key is created, which is before sending a + * handshake response or after receiving a handshake response. 
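+ * It schedules key material to be zeroed REJECT_AFTER_TIME * 3 from
+ * now, the point past which these keys must never be used again.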
+ */ +static void +wg_timers_event_session_derived(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled) { + callout_reset(&t->t_zero_key_material, + MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000), + (timeout_t *)wg_timers_run_zero_key_material, t); + } + rw_runlock(&t->t_lock); +} + +static void +wg_timers_event_want_initiation(struct wg_timers *t) +{ + rw_rlock(&t->t_lock); + if (!t->t_disabled) + wg_timers_run_send_initiation(t, 0); + rw_runlock(&t->t_lock); +} + +static void +wg_timers_event_reset_handshake_last_sent(struct wg_timers *t) +{ + rw_wlock(&t->t_lock); + t->t_handshake_last_sent.tv_sec -= (REKEY_TIMEOUT + 1); + rw_wunlock(&t->t_lock); +} + +static void +wg_timers_run_send_initiation(struct wg_timers *t, int is_retry) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + if (!is_retry) + t->t_handshake_retries = 0; + if (wg_timers_expired_handshake_last_sent(t) == ETIMEDOUT) + GROUPTASK_ENQUEUE(&peer->p_send_initiation); +} + +static void +wg_timers_run_retry_handshake(struct wg_timers *t) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + + rw_wlock(&t->t_lock); + if (t->t_handshake_retries <= MAX_TIMER_HANDSHAKES) { + t->t_handshake_retries++; + rw_wunlock(&t->t_lock); + + DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete " + "after %d seconds, retrying (try %d)\n", + (unsigned long long)peer->p_id, + REKEY_TIMEOUT, t->t_handshake_retries + 1); + wg_peer_clear_src(peer); + wg_timers_run_send_initiation(t, 1); + } else { + rw_wunlock(&t->t_lock); + + DPRINTF(peer->p_sc, "Handshake for peer %llu did not complete " + "after %d retries, giving up\n", + (unsigned long long) peer->p_id, MAX_TIMER_HANDSHAKES + 2); + + callout_stop(&t->t_send_keepalive); + wg_queue_purge(&peer->p_stage_queue); + if (!callout_pending(&t->t_zero_key_material)) + callout_reset(&t->t_zero_key_material, + MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000), + (timeout_t *)wg_timers_run_zero_key_material, t); + } +} + +static void +wg_timers_run_send_keepalive(struct wg_timers *t) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + + GROUPTASK_ENQUEUE(&peer->p_send_keepalive); + if (t->t_need_another_keepalive) { + t->t_need_another_keepalive = 0; + callout_reset(&t->t_send_keepalive, + MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000), + (timeout_t *)wg_timers_run_send_keepalive, t); + } +} + +static void +wg_timers_run_new_handshake(struct wg_timers *t) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + + DPRINTF(peer->p_sc, "Retrying handshake with peer %llu because we " + "stopped hearing back after %d seconds\n", + (unsigned long long)peer->p_id, NEW_HANDSHAKE_TIMEOUT); + wg_peer_clear_src(peer); + + wg_timers_run_send_initiation(t, 0); +} + +static void +wg_timers_run_zero_key_material(struct wg_timers *t) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + + DPRINTF(peer->p_sc, "Zeroing out all keys for peer %llu, since we " + "haven't received a new one in %d seconds\n", + (unsigned long long)peer->p_id, REJECT_AFTER_TIME * 3); + GROUPTASK_ENQUEUE(&peer->p_clear_secrets); +} + +static void +wg_timers_run_persistent_keepalive(struct wg_timers *t) +{ + struct wg_peer *peer = __containerof(t, struct wg_peer, p_timers); + + if (t->t_persistent_keepalive_interval != 0) + GROUPTASK_ENQUEUE(&peer->p_send_keepalive); +} + +/* TODO Handshake */ +static void +wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len) +{ + struct wg_endpoint endpoint; + + counter_u64_add(peer->p_tx_bytes, 
len); + wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); + wg_timers_event_any_authenticated_packet_sent(&peer->p_timers); + wg_peer_get_endpoint(peer, &endpoint); + wg_send_buf(peer->p_sc, &endpoint, buf, len); +} + +static void +wg_send_initiation(struct wg_peer *peer) +{ + struct wg_pkt_initiation pkt; + struct epoch_tracker et; + + if (wg_timers_check_handshake_last_sent(&peer->p_timers) != ETIMEDOUT) + return; + DPRINTF(peer->p_sc, "Sending handshake initiation to peer %llu\n", + (unsigned long long)peer->p_id); + + NET_EPOCH_ENTER(et); + if (noise_create_initiation(&peer->p_remote, &pkt.s_idx, pkt.ue, + pkt.es, pkt.ets) != 0) + goto out; + pkt.t = WG_PKT_INITIATION; + cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, + sizeof(pkt)-sizeof(pkt.m)); + wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt)); + wg_timers_event_handshake_initiated(&peer->p_timers); +out: + NET_EPOCH_EXIT(et); +} + +static void +wg_send_response(struct wg_peer *peer) +{ + struct wg_pkt_response pkt; + struct epoch_tracker et; + + NET_EPOCH_ENTER(et); + + DPRINTF(peer->p_sc, "Sending handshake response to peer %llu\n", + (unsigned long long)peer->p_id); + + if (noise_create_response(&peer->p_remote, &pkt.s_idx, &pkt.r_idx, + pkt.ue, pkt.en) != 0) + goto out; + if (noise_remote_begin_session(&peer->p_remote) != 0) + goto out; + + wg_timers_event_session_derived(&peer->p_timers); + pkt.t = WG_PKT_RESPONSE; + cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt, + sizeof(pkt)-sizeof(pkt.m)); + wg_timers_event_handshake_responded(&peer->p_timers); + wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt)); +out: + NET_EPOCH_EXIT(et); +} + +static void +wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx, + struct mbuf *m) +{ + struct wg_pkt_cookie pkt; + struct wg_endpoint *e; + + DPRINTF(sc, "Sending cookie response for denied handshake message\n"); + + pkt.t = WG_PKT_COOKIE; + pkt.r_idx = idx; + + e = wg_mbuf_endpoint_get(m); + cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce, + pkt.ec, &e->e_remote.r_sa); + wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt)); +} + +static void +wg_send_keepalive(struct wg_peer *peer) +{ + struct mbuf *m = NULL; + struct wg_tag *t; + struct epoch_tracker et; + + if (wg_queue_len(&peer->p_stage_queue) != 0) { + NET_EPOCH_ENTER(et); + goto send; + } + if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL) + return; + if ((t = wg_tag_get(m)) == NULL) { + m_freem(m); + return; + } + t->t_peer = peer; + t->t_mbuf = NULL; + t->t_done = 0; + t->t_mtu = 0; /* MTU == 0 OK for keepalive */ + + NET_EPOCH_ENTER(et); + wg_queue_stage(peer, m); +send: + wg_queue_out(peer); + NET_EPOCH_EXIT(et); +} + +static int +wg_cookie_validate_packet(struct cookie_checker *checker, struct mbuf *m, + int under_load) +{ + struct wg_pkt_initiation *init; + struct wg_pkt_response *resp; + struct cookie_macs *macs; + struct wg_endpoint *e; + int type, size; + void *data; + + type = *mtod(m, uint32_t *); + data = m->m_data; + e = wg_mbuf_endpoint_get(m); + if (type == WG_PKT_INITIATION) { + init = mtod(m, struct wg_pkt_initiation *); + macs = &init->m; + size = sizeof(*init) - sizeof(*macs); + } else if (type == WG_PKT_RESPONSE) { + resp = mtod(m, struct wg_pkt_response *); + macs = &resp->m; + size = sizeof(*resp) - sizeof(*macs); + } else + return 0; + + return (cookie_checker_validate_macs(checker, macs, data, size, + under_load, &e->e_remote.r_sa)); +} + + +static void +wg_handshake(struct wg_softc *sc, struct mbuf *m) +{ + struct wg_pkt_initiation *init; + struct wg_pkt_response 
*resp; + struct noise_remote *remote; + struct wg_pkt_cookie *cook; + struct wg_peer *peer; + struct wg_tag *t; + + /* This is global, so that our load calculation applies to the whole + * system. We don't care about races with it at all. + */ + static struct timeval wg_last_underload; + static const struct timeval underload_interval = { UNDERLOAD_TIMEOUT, 0 }; + bool packet_needs_cookie = false; + int underload, res; + + underload = mbufq_len(&sc->sc_handshake_queue) >= + MAX_QUEUED_HANDSHAKES / 8; + if (underload) + getmicrouptime(&wg_last_underload); + else if (wg_last_underload.tv_sec != 0) { + if (!ratecheck(&wg_last_underload, &underload_interval)) + underload = 1; + else + bzero(&wg_last_underload, sizeof(wg_last_underload)); + } + + res = wg_cookie_validate_packet(&sc->sc_cookie, m, underload); + + if (res && res != EAGAIN) { + printf("validate_packet got %d\n", res); + goto free; + } + if (res == EINVAL) { + DPRINTF(sc, "Invalid initiation MAC\n"); + goto free; + } else if (res == ECONNREFUSED) { + DPRINTF(sc, "Handshake ratelimited\n"); + goto free; + } else if (res == EAGAIN) { + packet_needs_cookie = true; + } else if (res != 0) { + DPRINTF(sc, "Unexpected handshake ratelimit response: %d\n", res); + goto free; + } + + t = wg_tag_get(m); + switch (*mtod(m, uint32_t *)) { + case WG_PKT_INITIATION: + init = mtod(m, struct wg_pkt_initiation *); + + if (packet_needs_cookie) { + wg_send_cookie(sc, &init->m, init->s_idx, m); + goto free; + } + if (noise_consume_initiation(&sc->sc_local, &remote, + init->s_idx, init->ue, init->es, init->ets) != 0) { + DPRINTF(sc, "Invalid handshake initiation"); + goto free; + } + + peer = __containerof(remote, struct wg_peer, p_remote); + DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", + (unsigned long long)peer->p_id); + counter_u64_add(peer->p_rx_bytes, sizeof(*init)); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*init)); + wg_peer_set_endpoint_from_tag(peer, t); + wg_send_response(peer); + break; + case WG_PKT_RESPONSE: + resp = mtod(m, struct wg_pkt_response *); + + if (packet_needs_cookie) { + wg_send_cookie(sc, &resp->m, resp->s_idx, m); + goto free; + } + + if ((remote = wg_index_get(sc, resp->r_idx)) == NULL) { + DPRINTF(sc, "Unknown handshake response\n"); + goto free; + } + peer = __containerof(remote, struct wg_peer, p_remote); + if (noise_consume_response(remote, resp->s_idx, resp->r_idx, + resp->ue, resp->en) != 0) { + DPRINTF(sc, "Invalid handshake response\n"); + goto free; + } + + DPRINTF(sc, "Receiving handshake response from peer %llu\n", + (unsigned long long)peer->p_id); + counter_u64_add(peer->p_rx_bytes, sizeof(*resp)); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); + if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, sizeof(*resp)); + wg_peer_set_endpoint_from_tag(peer, t); + if (noise_remote_begin_session(&peer->p_remote) == 0) { + wg_timers_event_session_derived(&peer->p_timers); + wg_timers_event_handshake_complete(&peer->p_timers); + } + break; + case WG_PKT_COOKIE: + cook = mtod(m, struct wg_pkt_cookie *); + + if ((remote = wg_index_get(sc, cook->r_idx)) == NULL) { + DPRINTF(sc, "Unknown cookie index\n"); + goto free; + } + + peer = __containerof(remote, struct wg_peer, p_remote); + + if (cookie_maker_consume_payload(&peer->p_cookie, + cook->nonce, cook->ec) != 0) { + DPRINTF(sc, "Could not decrypt cookie response\n"); + goto free; + } + + DPRINTF(sc, "Receiving cookie response\n"); + goto free; + default: + goto free; + } + MPASS(peer != NULL); + 
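+	/* Initiations and responses were authenticated above, so they
+	 * count as authenticated traffic for the timer state machine;
+	 * cookie messages jump straight to free and deliberately do not. */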
wg_timers_event_any_authenticated_packet_received(&peer->p_timers); + wg_timers_event_any_authenticated_packet_traversal(&peer->p_timers); + +free: + m_freem(m); +} + +static void +wg_softc_handshake_receive(struct wg_softc *sc) +{ + struct mbuf *m; + + while ((m = mbufq_dequeue(&sc->sc_handshake_queue)) != NULL) + wg_handshake(sc, m); +} + +/* TODO Encrypt */ +static void +wg_encap(struct wg_softc *sc, struct mbuf *m) +{ + struct wg_pkt_data *data; + size_t padding_len, plaintext_len, out_len; + struct mbuf *mc; + struct wg_peer *peer; + struct wg_tag *t; + uint64_t nonce; + int res, allocation_order; + + NET_EPOCH_ASSERT(); + t = wg_tag_get(m); + peer = t->t_peer; + + plaintext_len = MIN(WG_PKT_WITH_PADDING(m->m_pkthdr.len), t->t_mtu); + padding_len = plaintext_len - m->m_pkthdr.len; + out_len = sizeof(struct wg_pkt_data) + plaintext_len + NOISE_AUTHTAG_LEN; + + if (out_len <= MCLBYTES) + allocation_order = MCLBYTES; + else if (out_len <= MJUMPAGESIZE) + allocation_order = MJUMPAGESIZE; + else if (out_len <= MJUM9BYTES) + allocation_order = MJUM9BYTES; + else if (out_len <= MJUM16BYTES) + allocation_order = MJUM16BYTES; + else + goto error; + + if ((mc = m_getjcl(M_NOWAIT, MT_DATA, M_PKTHDR, allocation_order)) == NULL) + goto error; + + data = mtod(mc, struct wg_pkt_data *); + m_copydata(m, 0, m->m_pkthdr.len, data->buf); + bzero(data->buf + m->m_pkthdr.len, padding_len); + + data->t = WG_PKT_DATA; + + res = noise_remote_encrypt(&peer->p_remote, &data->r_idx, &nonce, + data->buf, plaintext_len); + nonce = htole64(nonce); /* Wire format is little endian. */ + memcpy(data->nonce, &nonce, sizeof(data->nonce)); + + if (__predict_false(res)) { + if (res == EINVAL) { + wg_timers_event_want_initiation(&peer->p_timers); + m_freem(mc); + goto error; + } else if (res == ESTALE) { + wg_timers_event_want_initiation(&peer->p_timers); + } else { + m_freem(mc); + goto error; + } + } + + /* A packet with length 0 is a keepalive packet */ + if (m->m_pkthdr.len == 0) + DPRINTF(sc, "Sending keepalive packet to peer %llu\n", + (unsigned long long)peer->p_id); + /* + * Set the correct output value here since it will be copied + * when we move the pkthdr in send. + */ + mc->m_len = mc->m_pkthdr.len = out_len; + mc->m_flags &= ~(M_MCAST | M_BCAST); + + t->t_mbuf = mc; + error: + /* XXX membar ? */ + t->t_done = 1; + GROUPTASK_ENQUEUE(&peer->p_send); +} + +static void +wg_decap(struct wg_softc *sc, struct mbuf *m) +{ + struct wg_pkt_data *data; + struct wg_peer *peer, *routed_peer; + struct wg_tag *t; + size_t plaintext_len; + uint8_t version; + uint64_t nonce; + int res; + + NET_EPOCH_ASSERT(); + data = mtod(m, struct wg_pkt_data *); + plaintext_len = m->m_pkthdr.len - sizeof(struct wg_pkt_data); + + t = wg_tag_get(m); + peer = t->t_peer; + + memcpy(&nonce, data->nonce, sizeof(nonce)); + nonce = le64toh(nonce); /* Wire format is little endian. 
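+	 * The swapped counter then goes to noise_remote_decrypt(), which
+	 * is expected to enforce the sliding anti-replay window on it.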
*/ + + res = noise_remote_decrypt(&peer->p_remote, data->r_idx, nonce, + data->buf, plaintext_len); + + if (__predict_false(res)) { + if (res == EINVAL) { + goto error; + } else if (res == ECONNRESET) { + wg_timers_event_handshake_complete(&peer->p_timers); + } else if (res == ESTALE) { + wg_timers_event_want_initiation(&peer->p_timers); + } else { + panic("unexpected response: %d\n", res); + } + } + wg_peer_set_endpoint_from_tag(peer, t); + + /* Remove the data header, and crypto mac tail from the packet */ + m_adj(m, sizeof(struct wg_pkt_data)); + m_adj(m, -NOISE_AUTHTAG_LEN); + + /* A packet with length 0 is a keepalive packet */ + if (m->m_pkthdr.len == 0) { + DPRINTF(peer->p_sc, "Receiving keepalive packet from peer " + "%llu\n", (unsigned long long)peer->p_id); + goto done; + } + + version = mtod(m, struct ip *)->ip_v; + if (!((version == 4 && m->m_pkthdr.len >= sizeof(struct ip)) || + (version == 6 && m->m_pkthdr.len >= sizeof(struct ip6_hdr)))) { + DPRINTF(peer->p_sc, "Packet is neither ipv4 nor ipv6 from peer " + "%llu\n", (unsigned long long)peer->p_id); + goto error; + } + + routed_peer = wg_aip_lookup(&peer->p_sc->sc_aips, m, IN); + if (routed_peer != peer) { + DPRINTF(peer->p_sc, "Packet has unallowed src IP from peer " + "%llu\n", (unsigned long long)peer->p_id); + goto error; + } + +done: + t->t_mbuf = m; +error: + t->t_done = 1; + GROUPTASK_ENQUEUE(&peer->p_recv); +} + +static void +wg_softc_decrypt(struct wg_softc *sc) +{ + struct epoch_tracker et; + struct mbuf *m; + + NET_EPOCH_ENTER(et); + while ((m = buf_ring_dequeue_mc(sc->sc_decap_ring)) != NULL) + wg_decap(sc, m); + NET_EPOCH_EXIT(et); +} + +static void +wg_softc_encrypt(struct wg_softc *sc) +{ + struct mbuf *m; + struct epoch_tracker et; + + NET_EPOCH_ENTER(et); + while ((m = buf_ring_dequeue_mc(sc->sc_encap_ring)) != NULL) + wg_encap(sc, m); + NET_EPOCH_EXIT(et); +} + +static void +wg_encrypt_dispatch(struct wg_softc *sc) +{ + for (int i = 0; i < mp_ncpus; i++) { + if (sc->sc_encrypt[i].gt_task.ta_flags & TASK_ENQUEUED) + continue; + GROUPTASK_ENQUEUE(&sc->sc_encrypt[i]); + } +} + +static void +wg_decrypt_dispatch(struct wg_softc *sc) +{ + for (int i = 0; i < mp_ncpus; i++) { + if (sc->sc_decrypt[i].gt_task.ta_flags & TASK_ENQUEUED) + continue; + GROUPTASK_ENQUEUE(&sc->sc_decrypt[i]); + } +} + +static void +wg_deliver_out(struct wg_peer *peer) +{ + struct epoch_tracker et; + struct wg_tag *t; + struct mbuf *m; + struct wg_endpoint endpoint; + size_t len; + int ret; + + NET_EPOCH_ENTER(et); + if (peer->p_sc->sc_ifp->if_link_state == LINK_STATE_DOWN) + goto done; + + wg_peer_get_endpoint(peer, &endpoint); + + while ((m = wg_queue_dequeue(&peer->p_encap_queue, &t)) != NULL) { + /* t_mbuf will contain the encrypted packet */ + if (t->t_mbuf == NULL) { + if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OERRORS, 1); + m_freem(m); + continue; + } + len = t->t_mbuf->m_pkthdr.len; + ret = wg_send(peer->p_sc, &endpoint, t->t_mbuf); + + if (ret == 0) { + wg_timers_event_any_authenticated_packet_traversal( + &peer->p_timers); + wg_timers_event_any_authenticated_packet_sent( + &peer->p_timers); + + if (m->m_pkthdr.len != 0) + wg_timers_event_data_sent(&peer->p_timers); + counter_u64_add(peer->p_tx_bytes, len); + } else if (ret == EADDRNOTAVAIL) { + wg_peer_clear_src(peer); + wg_peer_get_endpoint(peer, &endpoint); + } + m_freem(m); + } +done: + NET_EPOCH_EXIT(et); +} + +static void +wg_deliver_in(struct wg_peer *peer) +{ + struct mbuf *m; + struct ifnet *ifp; + struct wg_softc *sc; + struct epoch_tracker et; + struct wg_tag *t; + 
uint32_t af;
+	int version;
+
+	NET_EPOCH_ENTER(et);
+	sc = peer->p_sc;
+	ifp = sc->sc_ifp;
+
+	while ((m = wg_queue_dequeue(&peer->p_decap_queue, &t)) != NULL) {
+		/* t_mbuf will contain the decrypted packet */
+		if (t->t_mbuf == NULL) {
+			if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
+			m_freem(m);
+			continue;
+		}
+		MPASS(m == t->t_mbuf);
+
+		wg_timers_event_any_authenticated_packet_received(
+		    &peer->p_timers);
+		wg_timers_event_any_authenticated_packet_traversal(
+		    &peer->p_timers);
+
+		counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+		if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
+		if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len + sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
+
+		if (m->m_pkthdr.len == 0) {
+			m_freem(m);
+			continue;
+		}
+
+		m->m_flags &= ~(M_MCAST | M_BCAST);
+		m->m_pkthdr.rcvif = ifp;
+		version = mtod(m, struct ip *)->ip_v;
+		if (version == IPVERSION) {
+			af = AF_INET;
+			BPF_MTAP2(ifp, &af, sizeof(af), m);
+			CURVNET_SET(ifp->if_vnet);
+			ip_input(m);
+			CURVNET_RESTORE();
+		} else if (version == 6) {
+			af = AF_INET6;
+			BPF_MTAP2(ifp, &af, sizeof(af), m);
+			CURVNET_SET(ifp->if_vnet);
+			ip6_input(m);
+			CURVNET_RESTORE();
+		} else
+			m_freem(m);
+
+		wg_timers_event_data_received(&peer->p_timers);
+	}
+	NET_EPOCH_EXIT(et);
+}
+
+static int
+wg_queue_in(struct wg_peer *peer, struct mbuf *m)
+{
+	struct buf_ring *parallel = peer->p_sc->sc_decap_ring;
+	struct wg_queue *serial = &peer->p_decap_queue;
+	struct wg_tag *t;
+	int rc;
+
+	MPASS(wg_tag_get(m) != NULL);
+
+	mtx_lock(&serial->q_mtx);
+	if ((rc = mbufq_enqueue(&serial->q, m)) == ENOBUFS) {
+		m_freem(m);
+		if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
+	} else {
+		m->m_flags |= M_ENQUEUED;
+		rc = buf_ring_enqueue(parallel, m);
+		if (rc == ENOBUFS) {
+			t = wg_tag_get(m);
+			t->t_done = 1;
+		}
+	}
+	mtx_unlock(&serial->q_mtx);
+	return (rc);
+}
+
+static void
+wg_queue_stage(struct wg_peer *peer, struct mbuf *m)
+{
+	struct wg_queue *q = &peer->p_stage_queue;
+	mtx_lock(&q->q_mtx);
+	STAILQ_INSERT_TAIL(&q->q.mq_head, m, m_stailqpkt);
+	q->q.mq_len++;
+	while (mbufq_full(&q->q)) {
+		m = mbufq_dequeue(&q->q);
+		if (m) {
+			m_freem(m);
+			if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
+		}
+	}
+	mtx_unlock(&q->q_mtx);
+}
+
+static void
+wg_queue_out(struct wg_peer *peer)
+{
+	struct buf_ring *parallel = peer->p_sc->sc_encap_ring;
+	struct wg_queue *serial = &peer->p_encap_queue;
+	struct wg_tag *t;
+	struct mbufq staged;
+	struct mbuf *m;
+
+	if (noise_remote_ready(&peer->p_remote) != 0) {
+		if (wg_queue_len(&peer->p_stage_queue))
+			wg_timers_event_want_initiation(&peer->p_timers);
+		return;
+	}
+
+	/* We first "steal" the staged queue to a local queue, so that we can do these
+	 * remaining operations without having to hold the staged queue mutex.
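+	 * As a minimal sketch of the idiom, with names invented for the
+	 * illustration only:
+	 *
+	 *	struct mbufq local;
+	 *	STAILQ_INIT(&local.mq_head);
+	 *	mtx_lock(&src->q_mtx);
+	 *	STAILQ_SWAP(&local.mq_head, &src->q.mq_head, mbuf);
+	 *	local.mq_len = src->q.mq_len;
+	 *	src->q.mq_len = 0;
+	 *	mtx_unlock(&src->q_mtx);
+	 *
+	 * after which "local" can be drained without the lock, exactly as
+	 * done below.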
*/ + STAILQ_INIT(&staged.mq_head); + mtx_lock(&peer->p_stage_queue.q_mtx); + STAILQ_SWAP(&staged.mq_head, &peer->p_stage_queue.q.mq_head, mbuf); + staged.mq_len = peer->p_stage_queue.q.mq_len; + peer->p_stage_queue.q.mq_len = 0; + staged.mq_maxlen = peer->p_stage_queue.q.mq_maxlen; + mtx_unlock(&peer->p_stage_queue.q_mtx); + + while ((m = mbufq_dequeue(&staged)) != NULL) { + if ((t = wg_tag_get(m)) == NULL) { + m_freem(m); + continue; + } + t->t_peer = peer; + mtx_lock(&serial->q_mtx); + if (mbufq_enqueue(&serial->q, m) != 0) { + m_freem(m); + if_inc_counter(peer->p_sc->sc_ifp, IFCOUNTER_OQDROPS, 1); + } else { + m->m_flags |= M_ENQUEUED; + if (buf_ring_enqueue(parallel, m)) { + t = wg_tag_get(m); + t->t_done = 1; + } + } + mtx_unlock(&serial->q_mtx); + } + wg_encrypt_dispatch(peer->p_sc); +} + +static struct mbuf * +wg_queue_dequeue(struct wg_queue *q, struct wg_tag **t) +{ + struct mbuf *m_, *m; + + m = NULL; + mtx_lock(&q->q_mtx); + m_ = mbufq_first(&q->q); + if (m_ != NULL && (*t = wg_tag_get(m_))->t_done) { + m = mbufq_dequeue(&q->q); + m->m_flags &= ~M_ENQUEUED; + } + mtx_unlock(&q->q_mtx); + return (m); +} + +static int +wg_queue_len(struct wg_queue *q) +{ + /* This access races. We might consider adding locking here. */ + return (mbufq_len(&q->q)); +} + +static void +wg_queue_init(struct wg_queue *q, const char *name) +{ + mtx_init(&q->q_mtx, name, NULL, MTX_DEF); + mbufq_init(&q->q, MAX_QUEUED_PKT); +} + +static void +wg_queue_deinit(struct wg_queue *q) +{ + wg_queue_purge(q); + mtx_destroy(&q->q_mtx); +} + +static void +wg_queue_purge(struct wg_queue *q) +{ + mtx_lock(&q->q_mtx); + mbufq_drain(&q->q); + mtx_unlock(&q->q_mtx); +} + +/* TODO Indexes */ +static struct noise_remote * +wg_remote_get(struct wg_softc *sc, uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + struct wg_peer *peer; + + if ((peer = wg_peer_lookup(sc, public)) == NULL) + return (NULL); + return (&peer->p_remote); +} + +static uint32_t +wg_index_set(struct wg_softc *sc, struct noise_remote *remote) +{ + struct wg_index *index, *iter; + struct wg_peer *peer; + uint32_t key; + + /* We can modify this without a lock as wg_index_set, wg_index_drop are + * guaranteed to be serialised (per remote). */ + peer = __containerof(remote, struct wg_peer, p_remote); + index = SLIST_FIRST(&peer->p_unused_index); + MPASS(index != NULL); + SLIST_REMOVE_HEAD(&peer->p_unused_index, i_unused_entry); + + index->i_value = remote; + + rw_wlock(&sc->sc_index_lock); +assign_id: + key = index->i_key = arc4random(); + key &= sc->sc_index_mask; + LIST_FOREACH(iter, &sc->sc_index[key], i_entry) + if (iter->i_key == index->i_key) + goto assign_id; + + LIST_INSERT_HEAD(&sc->sc_index[key], index, i_entry); + + rw_wunlock(&sc->sc_index_lock); + + /* Likewise, no need to lock for index here. 
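+	 * As an illustration of the assign_id loop above: the key is a
+	 * full 32-bit arc4random() value, compared in full, while the
+	 * bucket is selected by key & sc_index_mask. Assuming, purely for
+	 * the example, 1 << 13 hash buckets (mask 0x1fff), a key of
+	 * 0x12345678 lands in bucket 0x1678; a collision on the full key
+	 * simply draws a fresh random value and retries.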
*/ + return index->i_key; +} + +static struct noise_remote * +wg_index_get(struct wg_softc *sc, uint32_t key0) +{ + struct wg_index *iter; + struct noise_remote *remote = NULL; + uint32_t key = key0 & sc->sc_index_mask; + + rw_enter_read(&sc->sc_index_lock); + LIST_FOREACH(iter, &sc->sc_index[key], i_entry) + if (iter->i_key == key0) { + remote = iter->i_value; + break; + } + rw_exit_read(&sc->sc_index_lock); + return remote; +} + +static void +wg_index_drop(struct wg_softc *sc, uint32_t key0) +{ + struct wg_index *iter; + struct wg_peer *peer = NULL; + uint32_t key = key0 & sc->sc_index_mask; + + rw_enter_write(&sc->sc_index_lock); + LIST_FOREACH(iter, &sc->sc_index[key], i_entry) + if (iter->i_key == key0) { + LIST_REMOVE(iter, i_entry); + break; + } + rw_exit_write(&sc->sc_index_lock); + + if (iter == NULL) + return; + + /* We expect a peer */ + peer = __containerof(iter->i_value, struct wg_peer, p_remote); + MPASS(peer != NULL); + SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); +} + +static int +wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa, + struct ifnet *rcvif) +{ + const struct sockaddr_in *sa4; +#ifdef INET6 + const struct sockaddr_in6 *sa6; +#endif + int ret = 0; + + /* + * UDP passes a 2-element sockaddr array: first element is the + * source addr/port, second the destination addr/port. + */ + if (srcsa->sa_family == AF_INET) { + sa4 = (const struct sockaddr_in *)srcsa; + e->e_remote.r_sin = sa4[0]; + e->e_local.l_in = sa4[1].sin_addr; +#ifdef INET6 + } else if (srcsa->sa_family == AF_INET6) { + sa6 = (const struct sockaddr_in6 *)srcsa; + e->e_remote.r_sin6 = sa6[0]; + e->e_local.l_in6 = sa6[1].sin6_addr; +#endif + } else { + ret = EAFNOSUPPORT; + } + + return (ret); +} + +static void +wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, + const struct sockaddr *srcsa, void *_sc) +{ + struct wg_pkt_data *pkt_data; + struct wg_endpoint *e; + struct wg_softc *sc = _sc; + struct mbuf *m; + int pktlen, pkttype; + struct noise_remote *remote; + struct wg_tag *t; + void *data; + + /* Caller provided us with srcsa, no need for this header. */ + m_adj(m0, offset + sizeof(struct udphdr)); + + /* + * Ensure mbuf has at least enough contiguous data to peel off our + * headers at the beginning. 
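+	 * A sketch of what this buys us: after a successful m_defrag()
+	 * the packet sits in one contiguous cluster, so the 4-byte
+	 * message type (and, for data packets, the whole wg_pkt_data
+	 * header) can be read with a plain dereference, as done just
+	 * below:
+	 *
+	 *	uint32_t type = *(const uint32_t *)mtod(m, const void *);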
+ */ + if ((m = m_defrag(m0, M_NOWAIT)) == NULL) { + m_freem(m0); + return; + } + data = mtod(m, void *); + pkttype = *(uint32_t*)data; + t = wg_tag_get(m); + if (t == NULL) { + goto free; + } + e = wg_mbuf_endpoint_get(m); + + if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) { + goto free; + } + + pktlen = m->m_pkthdr.len; + + if ((pktlen == sizeof(struct wg_pkt_initiation) && + pkttype == WG_PKT_INITIATION) || + (pktlen == sizeof(struct wg_pkt_response) && + pkttype == WG_PKT_RESPONSE) || + (pktlen == sizeof(struct wg_pkt_cookie) && + pkttype == WG_PKT_COOKIE)) { + if (mbufq_enqueue(&sc->sc_handshake_queue, m) == 0) { + GROUPTASK_ENQUEUE(&sc->sc_handshake); + } else { + DPRINTF(sc, "Dropping handshake packet\n"); + m_freem(m); + } + } else if (pktlen >= sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN + && pkttype == WG_PKT_DATA) { + + pkt_data = data; + remote = wg_index_get(sc, pkt_data->r_idx); + if (remote == NULL) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1); + m_freem(m); + } else if (buf_ring_count(sc->sc_decap_ring) > MAX_QUEUED_PKT) { + if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1); + m_freem(m); + } else { + t->t_peer = __containerof(remote, struct wg_peer, + p_remote); + t->t_mbuf = NULL; + t->t_done = 0; + + wg_queue_in(t->t_peer, m); + wg_decrypt_dispatch(sc); + } + } else { +free: + m_freem(m); + } +} + +static int +wg_transmit(struct ifnet *ifp, struct mbuf *m) +{ + struct wg_softc *sc; + sa_family_t family; + struct epoch_tracker et; + struct wg_peer *peer; + struct wg_tag *t; + uint32_t af; + int rc; + + /* + * Work around lifetime issue in the ipv6 mld code. + */ + if (__predict_false(ifp->if_flags & IFF_DYING)) + return (ENXIO); + + rc = 0; + sc = ifp->if_softc; + if ((t = wg_tag_get(m)) == NULL) { + rc = ENOBUFS; + goto early_out; + } + af = m->m_pkthdr.ph_family; + BPF_MTAP2(ifp, &af, sizeof(af), m); + + NET_EPOCH_ENTER(et); + peer = wg_aip_lookup(&sc->sc_aips, m, OUT); + if (__predict_false(peer == NULL)) { + rc = ENOKEY; + goto err; + } + + family = peer->p_endpoint.e_remote.r_sa.sa_family; + if (__predict_false(family != AF_INET && family != AF_INET6)) { + DPRINTF(sc, "No valid endpoint has been configured or " + "discovered for peer %llu\n", (unsigned long long)peer->p_id); + + rc = EHOSTUNREACH; + goto err; + } + t->t_peer = peer; + t->t_mbuf = NULL; + t->t_done = 0; + t->t_mtu = ifp->if_mtu; + + wg_queue_stage(peer, m); + wg_queue_out(peer); + NET_EPOCH_EXIT(et); + return (rc); +err: + NET_EPOCH_EXIT(et); +early_out: + if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1); + /* TODO: send ICMP unreachable */ + m_free(m); + return (rc); +} + +static int +wg_output(struct ifnet *ifp, struct mbuf *m, const struct sockaddr *sa, struct route *rt) +{ + m->m_pkthdr.ph_family = sa->sa_family; + return (wg_transmit(ifp, m)); +} + +static int +wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl) +{ + uint8_t public[WG_KEY_SIZE]; + const void *pub_key; + const struct sockaddr *endpoint; + int err; + size_t size; + struct wg_peer *peer = NULL; + bool need_insert = false; + + sx_assert(&sc->sc_lock, SX_XLOCKED); + + if (!nvlist_exists_binary(nvl, "public-key")) { + return (EINVAL); + } + pub_key = nvlist_get_binary(nvl, "public-key", &size); + if (size != WG_KEY_SIZE) { + return (EINVAL); + } + if (noise_local_keys(&sc->sc_local, public, NULL) == 0 && + bcmp(public, pub_key, WG_KEY_SIZE) == 0) { + return (0); // Silently ignored; not actually a failure. 
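+		/*
+		 * (a peer entry carrying our own public key is skipped as
+		 * a no-op above; this appears to mirror WireGuard
+		 * behaviour on other platforms, where configuring
+		 * "yourself" as a peer is not treated as an error)
+		 */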
+ } + peer = wg_peer_lookup(sc, pub_key); + if (nvlist_exists_bool(nvl, "remove") && + nvlist_get_bool(nvl, "remove")) { + if (peer != NULL) { + wg_hashtable_peer_remove(&sc->sc_hashtable, peer); + wg_peer_destroy(peer); + } + return (0); + } + if (nvlist_exists_bool(nvl, "replace-allowedips") && + nvlist_get_bool(nvl, "replace-allowedips") && + peer != NULL) { + + wg_aip_delete(&peer->p_sc->sc_aips, peer); + } + if (peer == NULL) { + if (sc->sc_peer_count >= MAX_PEERS_PER_IFACE) + return (E2BIG); + sc->sc_peer_count++; + + need_insert = true; + peer = wg_peer_alloc(sc); + MPASS(peer != NULL); + noise_remote_init(&peer->p_remote, pub_key, &sc->sc_local); + cookie_maker_init(&peer->p_cookie, pub_key); + } + if (nvlist_exists_binary(nvl, "endpoint")) { + endpoint = nvlist_get_binary(nvl, "endpoint", &size); + if (size > sizeof(peer->p_endpoint.e_remote)) { + err = EINVAL; + goto out; + } + memcpy(&peer->p_endpoint.e_remote, endpoint, size); + } + if (nvlist_exists_binary(nvl, "preshared-key")) { + const void *key; + + key = nvlist_get_binary(nvl, "preshared-key", &size); + if (size != WG_KEY_SIZE) { + err = EINVAL; + goto out; + } + noise_remote_set_psk(&peer->p_remote, key); + } + if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) { + uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval"); + if (pki > UINT16_MAX) { + err = EINVAL; + goto out; + } + wg_timers_set_persistent_keepalive(&peer->p_timers, pki); + } + if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) { + const void *binary; + uint64_t cidr; + const nvlist_t * const * aipl; + struct wg_allowedip aip; + size_t allowedip_count; + + aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", + &allowedip_count); + for (size_t idx = 0; idx < allowedip_count; idx++) { + if (!nvlist_exists_number(aipl[idx], "cidr")) + continue; + cidr = nvlist_get_number(aipl[idx], "cidr"); + if (nvlist_exists_binary(aipl[idx], "ipv4")) { + binary = nvlist_get_binary(aipl[idx], "ipv4", &size); + if (binary == NULL || cidr > 32 || size != sizeof(aip.ip4)) { + err = EINVAL; + goto out; + } + aip.family = AF_INET; + memcpy(&aip.ip4, binary, sizeof(aip.ip4)); + } else if (nvlist_exists_binary(aipl[idx], "ipv6")) { + binary = nvlist_get_binary(aipl[idx], "ipv6", &size); + if (binary == NULL || cidr > 128 || size != sizeof(aip.ip6)) { + err = EINVAL; + goto out; + } + aip.family = AF_INET6; + memcpy(&aip.ip6, binary, sizeof(aip.ip6)); + } else { + continue; + } + aip.cidr = cidr; + + if ((err = wg_aip_add(&sc->sc_aips, peer, &aip)) != 0) { + goto out; + } + } + } + if (need_insert) { + wg_hashtable_peer_insert(&sc->sc_hashtable, peer); + if (sc->sc_ifp->if_link_state == LINK_STATE_UP) + wg_timers_enable(&peer->p_timers); + } + return (0); + +out: + if (need_insert) /* If we fail, only destroy if it was new. 
*/ + wg_peer_destroy(peer); + return (err); +} + +static int +wgc_set(struct wg_softc *sc, struct wg_data_io *wgd) +{ + uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE]; + struct ifnet *ifp; + void *nvlpacked; + nvlist_t *nvl; + ssize_t size; + int err; + + ifp = sc->sc_ifp; + if (wgd->wgd_size == 0 || wgd->wgd_data == NULL) + return (EFAULT); + + sx_xlock(&sc->sc_lock); + + nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK); + err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size); + if (err) + goto out; + nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0); + if (nvl == NULL) { + err = EBADMSG; + goto out; + } + if (nvlist_exists_bool(nvl, "replace-peers") && + nvlist_get_bool(nvl, "replace-peers")) + wg_peer_remove_all(sc); + if (nvlist_exists_number(nvl, "listen-port")) { + uint64_t new_port = nvlist_get_number(nvl, "listen-port"); + if (new_port > UINT16_MAX) { + err = EINVAL; + goto out; + } + if (new_port != sc->sc_socket.so_port) { + if ((ifp->if_drv_flags & IFF_DRV_RUNNING) != 0) { + if ((err = wg_socket_init(sc, new_port)) != 0) + goto out; + } else + sc->sc_socket.so_port = new_port; + } + } + if (nvlist_exists_binary(nvl, "private-key")) { + const void *key = nvlist_get_binary(nvl, "private-key", &size); + if (size != WG_KEY_SIZE) { + err = EINVAL; + goto out; + } + + if (noise_local_keys(&sc->sc_local, NULL, private) != 0 || + timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) { + struct noise_local *local; + struct wg_peer *peer; + struct wg_hashtable *ht = &sc->sc_hashtable; + bool has_identity; + + if (curve25519_generate_public(public, key)) { + /* Peer conflict: remove conflicting peer. */ + if ((peer = wg_peer_lookup(sc, public)) != + NULL) { + wg_hashtable_peer_remove(ht, peer); + wg_peer_destroy(peer); + } + } + + /* + * Set the private key and invalidate all existing + * handshakes. + */ + local = &sc->sc_local; + noise_local_lock_identity(local); + /* Note: we might be removing the private key. */ + has_identity = noise_local_set_private(local, key) == 0; + mtx_lock(&ht->h_mtx); + CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + noise_remote_precompute(&peer->p_remote); + wg_timers_event_reset_handshake_last_sent( + &peer->p_timers); + noise_remote_expire_current(&peer->p_remote); + } + mtx_unlock(&ht->h_mtx); + cookie_checker_update(&sc->sc_cookie, + has_identity ? public : NULL); + noise_local_unlock_identity(local); + } + } + if (nvlist_exists_number(nvl, "user-cookie")) { + uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie"); + if (user_cookie > UINT32_MAX) { + err = EINVAL; + goto out; + } + wg_socket_set_cookie(sc, user_cookie); + } + if (nvlist_exists_nvlist_array(nvl, "peers")) { + size_t peercount; + const nvlist_t * const*nvl_peers; + + nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount); + for (int i = 0; i < peercount; i++) { + err = wg_peer_add(sc, nvl_peers[i]); + if (err != 0) + goto out; + } + } + + nvlist_destroy(nvl); +out: + free(nvlpacked, M_TEMP); + sx_xunlock(&sc->sc_lock); + return (err); +} + +static unsigned int +in_mask2len(struct in_addr *mask) +{ + unsigned int x, y; + uint8_t *p; + + p = (uint8_t *)mask; + for (x = 0; x < sizeof(*mask); x++) { + if (p[x] != 0xff) + break; + } + y = 0; + if (x < sizeof(*mask)) { + for (y = 0; y < NBBY; y++) { + if ((p[x] & (0x80 >> y)) == 0) + break; + } + } + return x * NBBY + y; +} + +static int +wg_peer_to_export(struct wg_peer *peer, struct wg_peer_export *exp) +{ + struct wg_endpoint *ep; + struct wg_aip *rt; + struct noise_remote *remote; + int i; + + /* Non-sleepable context. 
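+	 * (hence the M_NOWAIT allocation below: this function runs inside
+	 * an epoch read section, where a sleeping M_WAITOK allocation
+	 * would be a bug; schematically:
+	 *
+	 *	NET_EPOCH_ENTER(et);
+	 *	p = malloc(sz, M_TEMP, M_NOWAIT);	(may fail, never sleeps)
+	 *	NET_EPOCH_EXIT(et);
+	 *
+	 * so the ENOMEM path below is expected and handled.)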
 */
+	NET_EPOCH_ASSERT();
+
+	bzero(&exp->endpoint, sizeof(exp->endpoint));
+	remote = &peer->p_remote;
+	ep = &peer->p_endpoint;
+	if (ep->e_remote.r_sa.sa_family != 0) {
+		exp->endpoint_sz = (ep->e_remote.r_sa.sa_family == AF_INET) ?
+		    sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
+
+		memcpy(&exp->endpoint, &ep->e_remote, exp->endpoint_sz);
+	}
+
+	/* We always export the keys here; the preshared key is filtered
+	 * out for unprivileged readers when marshalled to an nvlist. */
+	(void)noise_remote_keys(remote, exp->public_key, exp->preshared_key);
+	exp->persistent_keepalive =
+	    peer->p_timers.t_persistent_keepalive_interval;
+	wg_timers_get_last_handshake(&peer->p_timers, &exp->last_handshake);
+	exp->rx_bytes = counter_u64_fetch(peer->p_rx_bytes);
+	exp->tx_bytes = counter_u64_fetch(peer->p_tx_bytes);
+
+	exp->aip_count = 0;
+	CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+		exp->aip_count++;
+	}
+
+	/* Early success; no allowed-ips to copy out. */
+	if (exp->aip_count == 0)
+		return (0);
+
+	exp->aip = malloc(exp->aip_count * sizeof(*exp->aip), M_TEMP, M_NOWAIT);
+	if (exp->aip == NULL)
+		return (ENOMEM);
+
+	i = 0;
+	CK_LIST_FOREACH(rt, &peer->p_aips, r_entry) {
+		exp->aip[i].family = rt->r_addr.ss_family;
+		if (exp->aip[i].family == AF_INET) {
+			struct sockaddr_in *sin =
+			    (struct sockaddr_in *)&rt->r_addr;
+
+			exp->aip[i].ip4 = sin->sin_addr;
+
+			sin = (struct sockaddr_in *)&rt->r_mask;
+			exp->aip[i].cidr = in_mask2len(&sin->sin_addr);
+		} else if (exp->aip[i].family == AF_INET6) {
+			struct sockaddr_in6 *sin6 =
+			    (struct sockaddr_in6 *)&rt->r_addr;
+
+			exp->aip[i].ip6 = sin6->sin6_addr;
+
+			sin6 = (struct sockaddr_in6 *)&rt->r_mask;
+			exp->aip[i].cidr = in6_mask2len(&sin6->sin6_addr, NULL);
+		}
+		i++;
+		if (i == exp->aip_count)
+			break;
+	}
+
+	/* Again, the AllowedIPs list might have shrunk; update the count. */
+	exp->aip_count = i;
+
+	return (0);
+}
+
+static nvlist_t *
+wg_peer_export_to_nvl(struct wg_softc *sc, struct wg_peer_export *exp)
+{
+	struct wg_timespec64 ts64;
+	nvlist_t *nvl, **nvl_aips;
+	size_t i;
+	uint16_t family;
+
+	nvl_aips = NULL;
+	if ((nvl = nvlist_create(0)) == NULL)
+		return (NULL);
+
+	nvlist_add_binary(nvl, "public-key", exp->public_key,
+	    sizeof(exp->public_key));
+	if (wgc_privileged(sc))
+		nvlist_add_binary(nvl, "preshared-key", exp->preshared_key,
+		    sizeof(exp->preshared_key));
+	if (exp->endpoint_sz != 0)
+		nvlist_add_binary(nvl, "endpoint", &exp->endpoint,
+		    exp->endpoint_sz);
+
+	if (exp->aip_count != 0) {
+		nvl_aips = mallocarray(exp->aip_count, sizeof(*nvl_aips),
+		    M_WG, M_WAITOK | M_ZERO);
+	}
+
+	for (i = 0; i < exp->aip_count; i++) {
+		nvl_aips[i] = nvlist_create(0);
+		if (nvl_aips[i] == NULL)
+			goto err;
+		family = exp->aip[i].family;
+		nvlist_add_number(nvl_aips[i], "cidr", exp->aip[i].cidr);
+		if (family == AF_INET)
+			nvlist_add_binary(nvl_aips[i], "ipv4",
+			    &exp->aip[i].ip4, sizeof(exp->aip[i].ip4));
+		else if (family == AF_INET6)
+			nvlist_add_binary(nvl_aips[i], "ipv6",
+			    &exp->aip[i].ip6, sizeof(exp->aip[i].ip6));
+	}
+
+	if (i != 0) {
+		nvlist_add_nvlist_array(nvl, "allowed-ips",
+		    (const nvlist_t *const *)nvl_aips, i);
+	}
+
+	for (i = 0; i < exp->aip_count; ++i)
+		nvlist_destroy(nvl_aips[i]);
+
+	free(nvl_aips, M_WG);
+	nvl_aips = NULL;
+
+	ts64.tv_sec = exp->last_handshake.tv_sec;
+	ts64.tv_nsec = exp->last_handshake.tv_nsec;
+	nvlist_add_binary(nvl, "last-handshake-time", &ts64, sizeof(ts64));
+
+	if (exp->persistent_keepalive != 0)
+		nvlist_add_number(nvl, "persistent-keepalive-interval",
+		    exp->persistent_keepalive);
+
+	if (exp->rx_bytes != 0)
+		nvlist_add_number(nvl, "rx-bytes", exp->rx_bytes);
+	if (exp->tx_bytes != 0)
nvlist_add_number(nvl, "tx-bytes", exp->tx_bytes); + + return (nvl); +err: + for (i = 0; i < exp->aip_count && nvl_aips[i] != NULL; i++) { + nvlist_destroy(nvl_aips[i]); + } + + free(nvl_aips, M_WG); + nvlist_destroy(nvl); + return (NULL); +} + +static int +wg_marshal_peers(struct wg_softc *sc, nvlist_t **nvlp, nvlist_t ***nvl_arrayp, int *peer_countp) +{ + struct wg_peer *peer; + int err, i, peer_count; + nvlist_t *nvl, **nvl_array; + struct epoch_tracker et; + struct wg_peer_export *wpe; + + nvl = NULL; + nvl_array = NULL; + if (nvl_arrayp) + *nvl_arrayp = NULL; + if (nvlp) + *nvlp = NULL; + if (peer_countp) + *peer_countp = 0; + peer_count = sc->sc_hashtable.h_num_peers; + if (peer_count == 0) { + return (ENOENT); + } + + if (nvlp && (nvl = nvlist_create(0)) == NULL) + return (ENOMEM); + + err = i = 0; + nvl_array = malloc(peer_count*sizeof(void*), M_TEMP, M_WAITOK | M_ZERO); + wpe = malloc(peer_count*sizeof(*wpe), M_TEMP, M_WAITOK | M_ZERO); + + NET_EPOCH_ENTER(et); + CK_LIST_FOREACH(peer, &sc->sc_hashtable.h_peers_list, p_entry) { + if ((err = wg_peer_to_export(peer, &wpe[i])) != 0) { + break; + } + + i++; + if (i == peer_count) + break; + } + NET_EPOCH_EXIT(et); + + if (err != 0) + goto out; + + /* Update the peer count, in case we found fewer entries. */ + *peer_countp = peer_count = i; + if (peer_count == 0) { + err = ENOENT; + goto out; + } + + for (i = 0; i < peer_count; i++) { + int idx; + + /* + * Peers are added to the list in reverse order, effectively, + * because it's simpler/quicker to add at the head every time. + * + * Export them in reverse order. No worries if we fail mid-way + * through, the cleanup below will DTRT. + */ + idx = peer_count - i - 1; + nvl_array[idx] = wg_peer_export_to_nvl(sc, &wpe[i]); + if (nvl_array[idx] == NULL) { + break; + } + } + + if (i < peer_count) { + /* Error! */ + *peer_countp = 0; + err = ENOMEM; + } else if (nvl) { + nvlist_add_nvlist_array(nvl, "peers", + (const nvlist_t * const *)nvl_array, peer_count); + if ((err = nvlist_error(nvl))) { + goto out; + } + *nvlp = nvl; + } + *nvl_arrayp = nvl_array; + out: + if (err != 0) { + /* Note that nvl_array is populated in reverse order. 
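+	 * As a worked example: with peer_count == 3 the export loop above
+	 * stores wpe[0] at nvl_array[2], wpe[1] at nvl_array[1] and
+	 * wpe[2] at nvl_array[0] (idx = peer_count - i - 1). Destroying
+	 * indices 0..peer_count-1 here therefore visits every slot, and
+	 * slots that were never filled are NULL thanks to M_ZERO.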
*/ + for (i = 0; i < peer_count; i++) { + nvlist_destroy(nvl_array[i]); + } + + free(nvl_array, M_TEMP); + if (nvl != NULL) + nvlist_destroy(nvl); + } + + for (i = 0; i < peer_count; i++) + free(wpe[i].aip, M_TEMP); + free(wpe, M_TEMP); + return (err); +} + +static int +wgc_get(struct wg_softc *sc, struct wg_data_io *wgd) +{ + nvlist_t *nvl, **nvl_array; + void *packed; + size_t size; + int peer_count, err; + + nvl = nvlist_create(0); + if (nvl == NULL) + return (ENOMEM); + + sx_slock(&sc->sc_lock); + + err = 0; + packed = NULL; + if (sc->sc_socket.so_port != 0) + nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port); + if (sc->sc_socket.so_user_cookie != 0) + nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie); + if (sc->sc_local.l_has_identity) { + nvlist_add_binary(nvl, "public-key", sc->sc_local.l_public, WG_KEY_SIZE); + if (wgc_privileged(sc)) + nvlist_add_binary(nvl, "private-key", sc->sc_local.l_private, WG_KEY_SIZE); + } + if (sc->sc_hashtable.h_num_peers > 0) { + err = wg_marshal_peers(sc, NULL, &nvl_array, &peer_count); + if (err) + goto out_nvl; + nvlist_add_nvlist_array(nvl, "peers", + (const nvlist_t * const *)nvl_array, peer_count); + } + packed = nvlist_pack(nvl, &size); + if (packed == NULL) { + err = ENOMEM; + goto out_nvl; + } + if (wgd->wgd_size == 0) { + wgd->wgd_size = size; + goto out_packed; + } + if (wgd->wgd_size < size) { + err = ENOSPC; + goto out_packed; + } + if (wgd->wgd_data == NULL) { + err = EFAULT; + goto out_packed; + } + err = copyout(packed, wgd->wgd_data, size); + wgd->wgd_size = size; + +out_packed: + free(packed, M_NVLIST); +out_nvl: + nvlist_destroy(nvl); + sx_sunlock(&sc->sc_lock); + return (err); +} + +static int +wg_ioctl(struct ifnet *ifp, u_long cmd, caddr_t data) +{ + struct wg_data_io *wgd = (struct wg_data_io *)data; + struct ifreq *ifr = (struct ifreq *)data; + struct wg_softc *sc = ifp->if_softc; + int ret = 0; + + switch (cmd) { + case SIOCSWG: + ret = priv_check(curthread, PRIV_NET_WG); + if (ret == 0) + ret = wgc_set(sc, wgd); + break; + case SIOCGWG: + ret = wgc_get(sc, wgd); + break; + /* Interface IOCTLs */ + case SIOCSIFADDR: + /* + * This differs from *BSD norms, but is more uniform with how + * WireGuard behaves elsewhere. + */ + break; + case SIOCSIFFLAGS: + if ((ifp->if_flags & IFF_UP) != 0) + ret = wg_up(sc); + else + wg_down(sc); + break; + case SIOCSIFMTU: + if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU) + ret = EINVAL; + else + ifp->if_mtu = ifr->ifr_mtu; + break; + case SIOCADDMULTI: + case SIOCDELMULTI: + break; + default: + ret = ENOTTY; + } + + return ret; +} + +static int +wg_up(struct wg_softc *sc) +{ + struct wg_hashtable *ht = &sc->sc_hashtable; + struct ifnet *ifp = sc->sc_ifp; + struct wg_peer *peer; + int rc = EBUSY; + + sx_xlock(&sc->sc_lock); + /* Jail's being removed, no more wg_up(). */ + if ((sc->sc_flags & WGF_DYING) != 0) + goto out; + + /* Silent success if we're already running. 
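+	 * (e.g. two back-to-back SIOCSIFFLAGS requests with IFF_UP set,
+	 * as issued by repeated "ifconfig wg0 up" invocations, both
+	 * return 0 rather than EBUSY)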
*/ + rc = 0; + if (ifp->if_drv_flags & IFF_DRV_RUNNING) + goto out; + ifp->if_drv_flags |= IFF_DRV_RUNNING; + + rc = wg_socket_init(sc, sc->sc_socket.so_port); + if (rc == 0) { + mtx_lock(&ht->h_mtx); + CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + wg_timers_enable(&peer->p_timers); + wg_queue_out(peer); + } + mtx_unlock(&ht->h_mtx); + + if_link_state_change(sc->sc_ifp, LINK_STATE_UP); + } else { + ifp->if_drv_flags &= ~IFF_DRV_RUNNING; + } +out: + sx_xunlock(&sc->sc_lock); + return (rc); +} + +static void +wg_down(struct wg_softc *sc) +{ + struct wg_hashtable *ht = &sc->sc_hashtable; + struct ifnet *ifp = sc->sc_ifp; + struct wg_peer *peer; + + sx_xlock(&sc->sc_lock); + if (!(ifp->if_drv_flags & IFF_DRV_RUNNING)) { + sx_xunlock(&sc->sc_lock); + return; + } + ifp->if_drv_flags &= ~IFF_DRV_RUNNING; + + mtx_lock(&ht->h_mtx); + CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + wg_queue_purge(&peer->p_stage_queue); + wg_timers_disable(&peer->p_timers); + } + mtx_unlock(&ht->h_mtx); + + mbufq_drain(&sc->sc_handshake_queue); + + mtx_lock(&ht->h_mtx); + CK_LIST_FOREACH(peer, &ht->h_peers_list, p_entry) { + noise_remote_clear(&peer->p_remote); + wg_timers_event_reset_handshake_last_sent(&peer->p_timers); + } + mtx_unlock(&ht->h_mtx); + + if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN); + wg_socket_uninit(sc); + + sx_xunlock(&sc->sc_lock); +} + +static void +crypto_taskq_setup(struct wg_softc *sc) +{ + + sc->sc_encrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK); + sc->sc_decrypt = malloc(sizeof(struct grouptask)*mp_ncpus, M_WG, M_WAITOK); + + for (int i = 0; i < mp_ncpus; i++) { + GROUPTASK_INIT(&sc->sc_encrypt[i], 0, + (gtask_fn_t *)wg_softc_encrypt, sc); + taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt"); + GROUPTASK_INIT(&sc->sc_decrypt[i], 0, + (gtask_fn_t *)wg_softc_decrypt, sc); + taskqgroup_attach_cpu(qgroup_if_io_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt"); + } +} + +static void +crypto_taskq_destroy(struct wg_softc *sc) +{ + for (int i = 0; i < mp_ncpus; i++) { + taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_encrypt[i]); + taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_decrypt[i]); + } + free(sc->sc_encrypt, M_WG); + free(sc->sc_decrypt, M_WG); +} + +static int +wg_clone_create(struct if_clone *ifc, int unit, caddr_t params) +{ + struct wg_softc *sc; + struct ifnet *ifp; + struct noise_upcall noise_upcall; + + sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO); + sc->sc_ucred = crhold(curthread->td_ucred); + ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD); + ifp->if_softc = sc; + if_initname(ifp, wgname, unit); + + noise_upcall.u_arg = sc; + noise_upcall.u_remote_get = + (struct noise_remote *(*)(void *, uint8_t *))wg_remote_get; + noise_upcall.u_index_set = + (uint32_t (*)(void *, struct noise_remote *))wg_index_set; + noise_upcall.u_index_drop = + (void (*)(void *, uint32_t))wg_index_drop; + noise_local_init(&sc->sc_local, &noise_upcall); + cookie_checker_init(&sc->sc_cookie, ratelimit_zone); + + sc->sc_socket.so_port = 0; + + atomic_add_int(&clone_count, 1); + ifp->if_capabilities = ifp->if_capenable = WG_CAPS; + + mbufq_init(&sc->sc_handshake_queue, MAX_QUEUED_HANDSHAKES); + sx_init(&sc->sc_lock, "wg softc lock"); + rw_init(&sc->sc_index_lock, "wg index lock"); + sc->sc_peer_count = 0; + sc->sc_encap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); + sc->sc_decap_ring = buf_ring_alloc(MAX_QUEUED_PKT, M_WG, M_WAITOK, NULL); + GROUPTASK_INIT(&sc->sc_handshake, 0, + (gtask_fn_t 
*)wg_softc_handshake_receive, sc);
+	taskqgroup_attach(qgroup_if_io_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
+	crypto_taskq_setup(sc);
+
+	wg_hashtable_init(&sc->sc_hashtable);
+	sc->sc_index = hashinit(HASHTABLE_INDEX_SIZE, M_DEVBUF, &sc->sc_index_mask);
+	wg_aip_init(&sc->sc_aips);
+
+	if_setmtu(ifp, ETHERMTU - 80);
+	ifp->if_flags = IFF_BROADCAST | IFF_MULTICAST | IFF_NOARP;
+	ifp->if_init = wg_init;
+	ifp->if_reassign = wg_reassign;
+	ifp->if_qflush = wg_qflush;
+	ifp->if_transmit = wg_transmit;
+	ifp->if_output = wg_output;
+	ifp->if_ioctl = wg_ioctl;
+
+	if_attach(ifp);
+	bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
+
+	sx_xlock(&wg_sx);
+	LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
+	sx_xunlock(&wg_sx);
+
+	return 0;
+}
+
+static void
+wg_clone_destroy(struct ifnet *ifp)
+{
+	struct wg_softc *sc = ifp->if_softc;
+	struct ucred *cred;
+
+	sx_xlock(&wg_sx);
+	sx_xlock(&sc->sc_lock);
+	sc->sc_flags |= WGF_DYING;
+	cred = sc->sc_ucred;
+	sc->sc_ucred = NULL;
+	sx_xunlock(&sc->sc_lock);
+	LIST_REMOVE(sc, sc_entry);
+	sx_xunlock(&wg_sx);
+
+	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
+
+	sx_xlock(&sc->sc_lock);
+	wg_socket_uninit(sc);
+	sx_xunlock(&sc->sc_lock);
+
+	/*
+	 * There is no guarantee that all traffic has passed until the epoch
+	 * has elapsed with the socket closed.
+	 */
+	NET_EPOCH_WAIT();
+
+	taskqgroup_drain_all(qgroup_if_io_tqg);
+	sx_xlock(&sc->sc_lock);
+	wg_peer_remove_all(sc);
+	epoch_drain_callbacks(net_epoch_preempt);
+	sx_xunlock(&sc->sc_lock);
+	sx_destroy(&sc->sc_lock);
+	rw_destroy(&sc->sc_index_lock);
+	taskqgroup_detach(qgroup_if_io_tqg, &sc->sc_handshake);
+	crypto_taskq_destroy(sc);
+	buf_ring_free(sc->sc_encap_ring, M_WG);
+	buf_ring_free(sc->sc_decap_ring, M_WG);
+
+	wg_aip_destroy(&sc->sc_aips);
+	wg_hashtable_destroy(&sc->sc_hashtable);
+
+	if (cred != NULL)
+		crfree(cred);
+	if_detach(sc->sc_ifp);
+	if_free(sc->sc_ifp);
+	/* Ensure any local/private keys are cleaned up */
+	explicit_bzero(sc, sizeof(*sc));
+	free(sc, M_WG);
+
+	atomic_add_int(&clone_count, -1);
+}
+
+static void
+wg_qflush(struct ifnet *ifp __unused)
+{
+}
+
+/*
+ * Privileged information (private-key, preshared-key) is only exported for
+ * root and jailed root by default.
+ */
+static bool
+wgc_privileged(struct wg_softc *sc)
+{
+	struct thread *td;
+
+	td = curthread;
+	return (priv_check(td, PRIV_NET_WG) == 0);
+}
+
+static void
+wg_reassign(struct ifnet *ifp, struct vnet *new_vnet __unused,
+    char *unused __unused)
+{
+	struct wg_softc *sc;
+
+	sc = ifp->if_softc;
+	wg_down(sc);
+}
+
+static void
+wg_init(void *xsc)
+{
+	struct wg_softc *sc;
+
+	sc = xsc;
+	wg_up(sc);
+}
+
+static void
+vnet_wg_init(const void *unused __unused)
+{
+
+	V_wg_cloner = if_clone_simple(wgname, wg_clone_create, wg_clone_destroy,
+	    0);
+}
+VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+    vnet_wg_init, NULL);
+
+static void
+vnet_wg_uninit(const void *unused __unused)
+{
+
+	if_clone_detach(V_wg_cloner);
+}
+VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
+    vnet_wg_uninit, NULL);
+
+static int
+wg_prison_remove(void *obj, void *data __unused)
+{
+	const struct prison *pr = obj;
+	struct wg_softc *sc;
+	struct ucred *cred;
+	bool dying;
+
+	/*
+	 * Do a pass through all if_wg interfaces and release creds on any from
+	 * the jail that are supposed to be going away. This will, in turn, let
+	 * the jail die so that we don't end up with Schrödinger's jail.
+ */ + sx_slock(&wg_sx); + LIST_FOREACH(sc, &wg_list, sc_entry) { + cred = NULL; + + sx_xlock(&sc->sc_lock); + dying = (sc->sc_flags & WGF_DYING) != 0; + if (!dying && sc->sc_ucred != NULL && + sc->sc_ucred->cr_prison == pr) { + /* Home jail is going away. */ + cred = sc->sc_ucred; + sc->sc_ucred = NULL; + + sc->sc_flags |= WGF_DYING; + } + + /* + * If this is our foreign vnet going away, we'll also down the + * link and kill the socket because traffic needs to stop. Any + * address will be revoked in the rehoming process. + */ + if (cred != NULL || (!dying && + sc->sc_ifp->if_vnet == pr->pr_vnet)) { + if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN); + /* Have to kill the sockets, as they also hold refs. */ + wg_socket_uninit(sc); + } + + sx_xunlock(&sc->sc_lock); + + if (cred != NULL) { + CURVNET_SET(sc->sc_ifp->if_vnet); + if_purgeaddrs(sc->sc_ifp); + CURVNET_RESTORE(); + crfree(cred); + } + } + sx_sunlock(&wg_sx); + + return (0); +} + +static void +wg_module_init(void) +{ + osd_method_t methods[PR_MAXMETHOD] = { + [PR_METHOD_REMOVE] = wg_prison_remove, + }; + + ratelimit_zone = uma_zcreate("wg ratelimit", sizeof(struct ratelimit), + NULL, NULL, NULL, NULL, 0, 0); + wg_osd_jail_slot = osd_jail_register(NULL, methods); +} + +static void +wg_module_deinit(void) +{ + + uma_zdestroy(ratelimit_zone); + osd_jail_deregister(wg_osd_jail_slot); + + MPASS(LIST_EMPTY(&wg_list)); +} + +static int +wg_module_event_handler(module_t mod, int what, void *arg) +{ + + switch (what) { + case MOD_LOAD: + wg_module_init(); + break; + case MOD_UNLOAD: + if (atomic_load_int(&clone_count) == 0) + wg_module_deinit(); + else + return (EBUSY); + break; + default: + return (EOPNOTSUPP); + } + return (0); +} + +static moduledata_t wg_moduledata = { + "wg", + wg_module_event_handler, + NULL +}; + +DECLARE_MODULE(wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY); +MODULE_VERSION(wg, 1); +MODULE_DEPEND(wg, crypto, 1, 1, 1); diff --git a/src/if_wg.h b/src/if_wg.h new file mode 100644 index 0000000..f137c93 --- /dev/null +++ b/src/if_wg.h @@ -0,0 +1,37 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (c) 2019 Matt Dunwoodie + * + * Permission to use, copy, modify, and distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + * $FreeBSD$ + */ + +#ifndef __IF_WG_H__ +#define __IF_WG_H__ + +#include +#include + +struct wg_data_io { + char wgd_name[IFNAMSIZ]; + void *wgd_data; + size_t wgd_size; +}; + +#define WG_KEY_SIZE 32 + +#define SIOCSWG _IOWR('i', 210, struct wg_data_io) +#define SIOCGWG _IOWR('i', 211, struct wg_data_io) + +#endif /* __IF_WG_H__ */ diff --git a/src/support.h b/src/support.h new file mode 100644 index 0000000..f613f40 --- /dev/null +++ b/src/support.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2021 Jason A. Donenfeld . All Rights Reserved. 
+ * Copyright (C) 2021 Matt Dunwoodie
+ */
+
+#ifndef _WG_SUPPORT
+#define _WG_SUPPORT
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+/* TODO the following is openbsd compat defines to allow us to copy the wg_*
+ * files from openbsd (almost) verbatim. this will greatly ease maintenance
+ * across the platforms. it should be moved to its own file. the only thing
+ * we're missing from this is struct pool (freebsd: uma_zone_t), which isn't a
+ * showstopper, but is something worth considering in the future.
+ * - md */
+
+#define rw_assert_wrlock(x) rw_assert(x, RA_WLOCKED)
+#define rw_enter_write rw_wlock
+#define rw_exit_write rw_wunlock
+#define rw_enter_read rw_rlock
+#define rw_exit_read rw_runlock
+#define rw_exit rw_unlock
+
+#define RW_DOWNGRADE 1
+#define rw_enter(x, y) do {			\
+	CTASSERT(y == RW_DOWNGRADE);		\
+	rw_downgrade(x);			\
+} while (0)
+
+MALLOC_DECLARE(M_WG);
+
+#include
+typedef struct {
+	uint64_t k0;
+	uint64_t k1;
+} SIPHASH_KEY;
+
+static inline uint64_t
+siphash24(const SIPHASH_KEY *key, const void *src, size_t len)
+{
+	SIPHASH_CTX ctx;
+
+	return (SipHashX(&ctx, 2, 4, (const uint8_t *)key, src, len));
+}
+#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
+
+#endif
diff --git a/src/wg_cookie.c b/src/wg_cookie.c
new file mode 100644
index 0000000..bf0ce37
--- /dev/null
+++ b/src/wg_cookie.c
@@ -0,0 +1,427 @@
+/* SPDX-License-Identifier: ISC
+ *
+ * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved.
+ * Copyright (C) 2019-2021 Matt Dunwoodie
+ */
+
+#include
+#include
+#include
+#include
+#include	/* Because systm doesn't include M_NOWAIT, M_DEVBUF */
+#include
+
+#include "support.h"
+#include "wg_cookie.h"
+
+static void	cookie_precompute_key(uint8_t *,
+		    const uint8_t[COOKIE_INPUT_SIZE], const char *);
+static void	cookie_macs_mac1(struct cookie_macs *, const void *, size_t,
+		    const uint8_t[COOKIE_KEY_SIZE]);
+static void	cookie_macs_mac2(struct cookie_macs *, const void *, size_t,
+		    const uint8_t[COOKIE_COOKIE_SIZE]);
+static int	cookie_timer_expired(struct timespec *, time_t, long);
+static void	cookie_checker_make_cookie(struct cookie_checker *,
+		    uint8_t[COOKIE_COOKIE_SIZE], struct sockaddr *);
+static int	ratelimit_init(struct ratelimit *, uma_zone_t);
+static void	ratelimit_deinit(struct ratelimit *);
+static void	ratelimit_gc(struct ratelimit *, int);
+static int	ratelimit_allow(struct ratelimit *, struct sockaddr *);
+
+/* Public Functions */
+void
+cookie_maker_init(struct cookie_maker *cp, const uint8_t key[COOKIE_INPUT_SIZE])
+{
+	bzero(cp, sizeof(*cp));
+	cookie_precompute_key(cp->cp_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+	cookie_precompute_key(cp->cp_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+	rw_init(&cp->cp_lock, "cookie_maker");
+}
+
+int
+cookie_checker_init(struct cookie_checker *cc, uma_zone_t zone)
+{
+	int res;
+	bzero(cc, sizeof(*cc));
+
+	rw_init(&cc->cc_key_lock, "cookie_checker_key");
+	rw_init(&cc->cc_secret_lock, "cookie_checker_secret");
+
+	if ((res = ratelimit_init(&cc->cc_ratelimit_v4, zone)) != 0)
+		return res;
+#ifdef INET6
+	if ((res = ratelimit_init(&cc->cc_ratelimit_v6, zone)) != 0) {
+		ratelimit_deinit(&cc->cc_ratelimit_v4);
+		return res;
+	}
+#endif
+	return 0;
+}
+
+void
+cookie_checker_update(struct cookie_checker *cc,
+    const uint8_t key[COOKIE_INPUT_SIZE])
+{
+	rw_enter_write(&cc->cc_key_lock);
+	if (key) {
+		cookie_precompute_key(cc->cc_mac1_key, key, COOKIE_MAC1_KEY_LABEL);
+		cookie_precompute_key(cc->cc_cookie_key, key, COOKIE_COOKIE_KEY_LABEL);
+	} else {
+		bzero(cc->cc_mac1_key, sizeof(cc->cc_mac1_key));
+		bzero(cc->cc_cookie_key, sizeof(cc->cc_cookie_key));
+	}
+	rw_exit_write(&cc->cc_key_lock);
+}
+
+void
+cookie_checker_deinit(struct cookie_checker *cc)
+{
+	ratelimit_deinit(&cc->cc_ratelimit_v4);
+#ifdef INET6
+	ratelimit_deinit(&cc->cc_ratelimit_v6);
+#endif
+}
+
+void
+cookie_checker_create_payload(struct cookie_checker *cc,
+    struct cookie_macs *cm, uint8_t nonce[COOKIE_NONCE_SIZE],
+    uint8_t ecookie[COOKIE_ENCRYPTED_SIZE], struct sockaddr *sa)
+{
+	uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+	cookie_checker_make_cookie(cc, cookie, sa);
+	arc4random_buf(nonce, COOKIE_NONCE_SIZE);
+
+	rw_enter_read(&cc->cc_key_lock);
+	xchacha20poly1305_encrypt(ecookie, cookie, COOKIE_COOKIE_SIZE,
+	    cm->mac1, COOKIE_MAC_SIZE, nonce, cc->cc_cookie_key);
+	rw_exit_read(&cc->cc_key_lock);
+
+	explicit_bzero(cookie, sizeof(cookie));
+}
+
+int
+cookie_maker_consume_payload(struct cookie_maker *cp,
+    uint8_t nonce[COOKIE_NONCE_SIZE], uint8_t ecookie[COOKIE_ENCRYPTED_SIZE])
+{
+	int ret = 0;
+	uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+	rw_enter_write(&cp->cp_lock);
+
+	if (cp->cp_mac1_valid == 0) {
+		ret = ETIMEDOUT;
+		goto error;
+	}
+
+	if (xchacha20poly1305_decrypt(cookie, ecookie, COOKIE_ENCRYPTED_SIZE,
+	    cp->cp_mac1_last, COOKIE_MAC_SIZE, nonce, cp->cp_cookie_key) == 0) {
+		ret = EINVAL;
+		goto error;
+	}
+
+	memcpy(cp->cp_cookie, cookie, COOKIE_COOKIE_SIZE);
+	getnanouptime(&cp->cp_birthdate);
+	cp->cp_mac1_valid = 0;
+
+error:
+	rw_exit_write(&cp->cp_lock);
+	return ret;
+}
+
+void
+cookie_maker_mac(struct cookie_maker *cp, struct cookie_macs *cm, void *buf,
+    size_t len)
+{
+	rw_enter_read(&cp->cp_lock);
+
+	cookie_macs_mac1(cm, buf, len, cp->cp_mac1_key);
+
+	memcpy(cp->cp_mac1_last, cm->mac1, COOKIE_MAC_SIZE);
+	cp->cp_mac1_valid = 1;
+
+	if (!cookie_timer_expired(&cp->cp_birthdate,
+	    COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY, 0))
+		cookie_macs_mac2(cm, buf, len, cp->cp_cookie);
+	else
+		bzero(cm->mac2, COOKIE_MAC_SIZE);
+
+	rw_exit_read(&cp->cp_lock);
+}
+
+int
+cookie_checker_validate_macs(struct cookie_checker *cc, struct cookie_macs *cm,
+    void *buf, size_t len, int busy, struct sockaddr *sa)
+{
+	struct cookie_macs our_cm;
+	uint8_t cookie[COOKIE_COOKIE_SIZE];
+
+	/* Validate incoming MACs */
+	rw_enter_read(&cc->cc_key_lock);
+	cookie_macs_mac1(&our_cm, buf, len, cc->cc_mac1_key);
+	rw_exit_read(&cc->cc_key_lock);
+
+	/* If mac1 is invalid, we want to drop the packet */
+	if (timingsafe_bcmp(our_cm.mac1, cm->mac1, COOKIE_MAC_SIZE) != 0)
+		return EINVAL;
+
+	if (busy != 0) {
+		cookie_checker_make_cookie(cc, cookie, sa);
+		cookie_macs_mac2(&our_cm, buf, len, cookie);
+
+		/* If the mac2 is invalid, we want to send a cookie response */
+		if (timingsafe_bcmp(our_cm.mac2, cm->mac2, COOKIE_MAC_SIZE) != 0)
+			return EAGAIN;
+
+		/* If the mac2 is valid, we may want to rate limit the peer.
+		 * ratelimit_allow will return either 0 or ECONNREFUSED,
+		 * implying there is no ratelimiting, or we should ratelimit
+		 * (refuse) respectively.
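+	 * The resulting contract for callers of this function, restating
+	 * the returns above:
+	 *
+	 *	0		both MACs valid and not rate limited
+	 *	EINVAL		bad mac1: drop the packet silently
+	 *	EAGAIN		bad mac2 while busy: send a cookie reply
+	 *	ECONNREFUSED	valid MACs but rate limited: drop
+	 *
+	 * (plus EAFNOSUPPORT below for an unknown address family).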
*/ + if (sa->sa_family == AF_INET) + return ratelimit_allow(&cc->cc_ratelimit_v4, sa); +#ifdef INET6 + else if (sa->sa_family == AF_INET6) + return ratelimit_allow(&cc->cc_ratelimit_v6, sa); +#endif + else + return EAFNOSUPPORT; + } + return 0; +} + +/* Private functions */ +static void +cookie_precompute_key(uint8_t *key, const uint8_t input[COOKIE_INPUT_SIZE], + const char *label) +{ + struct blake2s_state blake; + + blake2s_init(&blake, COOKIE_KEY_SIZE); + blake2s_update(&blake, label, strlen(label)); + blake2s_update(&blake, input, COOKIE_INPUT_SIZE); + /* TODO we shouldn't need to provide outlen to _final. we can align + * this with openbsd after fixing the blake library. */ + blake2s_final(&blake, key); +} + +static void +cookie_macs_mac1(struct cookie_macs *cm, const void *buf, size_t len, + const uint8_t key[COOKIE_KEY_SIZE]) +{ + struct blake2s_state state; + blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_KEY_SIZE); + blake2s_update(&state, buf, len); + blake2s_final(&state, cm->mac1); +} + +static void +cookie_macs_mac2(struct cookie_macs *cm, const void *buf, size_t len, + const uint8_t key[COOKIE_COOKIE_SIZE]) +{ + struct blake2s_state state; + blake2s_init_key(&state, COOKIE_MAC_SIZE, key, COOKIE_COOKIE_SIZE); + blake2s_update(&state, buf, len); + blake2s_update(&state, cm->mac1, COOKIE_MAC_SIZE); + blake2s_final(&state, cm->mac2); +} + +static int +cookie_timer_expired(struct timespec *birthdate, time_t sec, long nsec) +{ + struct timespec uptime; + struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec }; + + if (birthdate->tv_sec == 0 && birthdate->tv_nsec == 0) + return ETIMEDOUT; + + getnanouptime(&uptime); + timespecadd(birthdate, &expire, &expire); + return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; +} + +static void +cookie_checker_make_cookie(struct cookie_checker *cc, + uint8_t cookie[COOKIE_COOKIE_SIZE], struct sockaddr *sa) +{ + struct blake2s_state state; + + rw_enter_write(&cc->cc_secret_lock); + if (cookie_timer_expired(&cc->cc_secret_birthdate, + COOKIE_SECRET_MAX_AGE, 0)) { + arc4random_buf(cc->cc_secret, COOKIE_SECRET_SIZE); + getnanouptime(&cc->cc_secret_birthdate); + } + blake2s_init_key(&state, COOKIE_COOKIE_SIZE, cc->cc_secret, + COOKIE_SECRET_SIZE); + rw_exit_write(&cc->cc_secret_lock); + + if (sa->sa_family == AF_INET) { + blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_addr, + sizeof(struct in_addr)); + blake2s_update(&state, (uint8_t *)&satosin(sa)->sin_port, + sizeof(in_port_t)); + blake2s_final(&state, cookie); +#ifdef INET6 + } else if (sa->sa_family == AF_INET6) { + blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_addr, + sizeof(struct in6_addr)); + blake2s_update(&state, (uint8_t *)&satosin6(sa)->sin6_port, + sizeof(in_port_t)); + blake2s_final(&state, cookie); +#endif + } else { + arc4random_buf(cookie, COOKIE_COOKIE_SIZE); + } +} + +static int +ratelimit_init(struct ratelimit *rl, uma_zone_t zone) +{ + rw_init(&rl->rl_lock, "ratelimit_lock"); + arc4random_buf(&rl->rl_secret, sizeof(rl->rl_secret)); + rl->rl_table = hashinit_flags(RATELIMIT_SIZE, M_DEVBUF, + &rl->rl_table_mask, M_NOWAIT); + rl->rl_zone = zone; + rl->rl_table_num = 0; + return rl->rl_table == NULL ? 
ENOBUFS : 0;
+}
+
+static void
+ratelimit_deinit(struct ratelimit *rl)
+{
+	rw_enter_write(&rl->rl_lock);
+	ratelimit_gc(rl, 1);
+	hashdestroy(rl->rl_table, M_DEVBUF, rl->rl_table_mask);
+	rw_exit_write(&rl->rl_lock);
+}
+
+static void
+ratelimit_gc(struct ratelimit *rl, int force)
+{
+	size_t i;
+	struct ratelimit_entry *r, *tr;
+	struct timespec expiry;
+
+	rw_assert_wrlock(&rl->rl_lock);
+
+	if (force) {
+		for (i = 0; i < RATELIMIT_SIZE; i++) {
+			LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+				rl->rl_table_num--;
+				LIST_REMOVE(r, r_entry);
+				uma_zfree(rl->rl_zone, r);
+			}
+		}
+		return;
+	}
+
+	if ((cookie_timer_expired(&rl->rl_last_gc, ELEMENT_TIMEOUT, 0) &&
+	    rl->rl_table_num > 0)) {
+		getnanouptime(&rl->rl_last_gc);
+		getnanouptime(&expiry);
+		expiry.tv_sec -= ELEMENT_TIMEOUT;
+
+		for (i = 0; i < RATELIMIT_SIZE; i++) {
+			LIST_FOREACH_SAFE(r, &rl->rl_table[i], r_entry, tr) {
+				if (timespeccmp(&r->r_last_time, &expiry, <)) {
+					rl->rl_table_num--;
+					LIST_REMOVE(r, r_entry);
+					uma_zfree(rl->rl_zone, r);
+				}
+			}
+		}
+	}
+}
+
+static int
+ratelimit_allow(struct ratelimit *rl, struct sockaddr *sa)
+{
+	uint64_t key, tokens;
+	struct timespec diff;
+	struct ratelimit_entry *r;
+	int ret = ECONNREFUSED;
+
+	if (sa->sa_family == AF_INET)
+		/* TODO siphash24 is the FreeBSD siphash, OK? */
+		key = siphash24(&rl->rl_secret, &satosin(sa)->sin_addr,
+		    IPV4_MASK_SIZE);
+#ifdef INET6
+	else if (sa->sa_family == AF_INET6)
+		key = siphash24(&rl->rl_secret, &satosin6(sa)->sin6_addr,
+		    IPV6_MASK_SIZE);
+#endif
+	else
+		return ret;
+
+	rw_enter_write(&rl->rl_lock);
+
+	LIST_FOREACH(r, &rl->rl_table[key & rl->rl_table_mask], r_entry) {
+		if (r->r_af != sa->sa_family)
+			continue;
+
+		if (r->r_af == AF_INET && bcmp(&r->r_in,
+		    &satosin(sa)->sin_addr, IPV4_MASK_SIZE) != 0)
+			continue;
+
+#ifdef INET6
+		if (r->r_af == AF_INET6 && bcmp(&r->r_in6,
+		    &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE) != 0)
+			continue;
+#endif
+
+		/* If we get to here, we've found an entry for the endpoint.
+		 * We apply a standard token bucket, by calculating the time
+		 * lapsed since our last_time, adding that, ensuring that we
+		 * cap the tokens at TOKEN_MAX. If the endpoint has no tokens
+		 * left (that is, tokens < INITIATION_COST) then we block the
+		 * request, otherwise we subtract the INITIATION_COST and
+		 * return OK. */
+		diff = r->r_last_time;
+		getnanouptime(&r->r_last_time);
+		timespecsub(&r->r_last_time, &diff, &diff);
+
+		tokens = r->r_tokens + diff.tv_sec * NSEC_PER_SEC + diff.tv_nsec;
+
+		if (tokens > TOKEN_MAX)
+			tokens = TOKEN_MAX;
+
+		if (tokens >= INITIATION_COST) {
+			r->r_tokens = tokens - INITIATION_COST;
+			goto ok;
+		} else {
+			r->r_tokens = tokens;
+			goto error;
+		}
+	}
+
+	/* If we get to here, we didn't have an entry for the endpoint.
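+	 * A worked example from the constants in wg_cookie.h: TOKEN_MAX
+	 * is INITIATION_COST * INITIATIONS_BURSTABLE, and a fresh bucket
+	 * starts at TOKEN_MAX - INITIATION_COST, so, counting the
+	 * initiation that creates it, a new endpoint may send
+	 * INITIATIONS_BURSTABLE = 5 initiations back to back before
+	 * refilling at INITIATIONS_PER_SECOND = 20. The steady-state
+	 * update above reduces to:
+	 *
+	 *	tokens = MIN(tokens + elapsed_nsec, TOKEN_MAX);
+	 *	allowed = (tokens >= INITIATION_COST);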
*/ + ratelimit_gc(rl, 0); + + /* Hard limit on number of entries */ + if (rl->rl_table_num >= RATELIMIT_SIZE_MAX) + goto error; + + /* Goto error if out of memory */ + if ((r = uma_zalloc(rl->rl_zone, M_NOWAIT)) == NULL) + goto error; + + rl->rl_table_num++; + + /* Insert entry into the hashtable and ensure it's initialised */ + LIST_INSERT_HEAD(&rl->rl_table[key & rl->rl_table_mask], r, r_entry); + r->r_af = sa->sa_family; + if (r->r_af == AF_INET) + memcpy(&r->r_in, &satosin(sa)->sin_addr, IPV4_MASK_SIZE); +#ifdef INET6 + else if (r->r_af == AF_INET6) + memcpy(&r->r_in6, &satosin6(sa)->sin6_addr, IPV6_MASK_SIZE); +#endif + + getnanouptime(&r->r_last_time); + r->r_tokens = TOKEN_MAX - INITIATION_COST; +ok: + ret = 0; +error: + rw_exit_write(&rl->rl_lock); + return ret; +} diff --git a/src/wg_cookie.h b/src/wg_cookie.h new file mode 100644 index 0000000..c7338d8 --- /dev/null +++ b/src/wg_cookie.h @@ -0,0 +1,114 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie + */ + +#ifndef __COOKIE_H__ +#define __COOKIE_H__ + +#include +#include +#include +#include + +#include + +#include "crypto.h" + +#define COOKIE_MAC_SIZE 16 +#define COOKIE_KEY_SIZE 32 +#define COOKIE_NONCE_SIZE XCHACHA20POLY1305_NONCE_SIZE +#define COOKIE_COOKIE_SIZE 16 +#define COOKIE_SECRET_SIZE 32 +#define COOKIE_INPUT_SIZE 32 +#define COOKIE_ENCRYPTED_SIZE (COOKIE_COOKIE_SIZE + COOKIE_MAC_SIZE) + +#define COOKIE_MAC1_KEY_LABEL "mac1----" +#define COOKIE_COOKIE_KEY_LABEL "cookie--" +#define COOKIE_SECRET_MAX_AGE 120 +#define COOKIE_SECRET_LATENCY 5 + +/* Constants for initiation rate limiting */ +#define RATELIMIT_SIZE (1 << 13) +#define RATELIMIT_SIZE_MAX (RATELIMIT_SIZE * 8) +#define NSEC_PER_SEC 1000000000LL +#define INITIATIONS_PER_SECOND 20 +#define INITIATIONS_BURSTABLE 5 +#define INITIATION_COST (NSEC_PER_SEC / INITIATIONS_PER_SECOND) +#define TOKEN_MAX (INITIATION_COST * INITIATIONS_BURSTABLE) +#define ELEMENT_TIMEOUT 1 +#define IPV4_MASK_SIZE 4 /* Use all 4 bytes of IPv4 address */ +#define IPV6_MASK_SIZE 8 /* Use top 8 bytes (/64) of IPv6 address */ + +struct cookie_macs { + uint8_t mac1[COOKIE_MAC_SIZE]; + uint8_t mac2[COOKIE_MAC_SIZE]; +}; + +struct ratelimit_entry { + LIST_ENTRY(ratelimit_entry) r_entry; + sa_family_t r_af; + union { + struct in_addr r_in; +#ifdef INET6 + struct in6_addr r_in6; +#endif + }; + struct timespec r_last_time; /* nanouptime */ + uint64_t r_tokens; +}; + +struct ratelimit { + SIPHASH_KEY rl_secret; + uma_zone_t rl_zone; + + struct rwlock rl_lock; + LIST_HEAD(, ratelimit_entry) *rl_table; + u_long rl_table_mask; + size_t rl_table_num; + struct timespec rl_last_gc; /* nanouptime */ +}; + +struct cookie_maker { + uint8_t cp_mac1_key[COOKIE_KEY_SIZE]; + uint8_t cp_cookie_key[COOKIE_KEY_SIZE]; + + struct rwlock cp_lock; + uint8_t cp_cookie[COOKIE_COOKIE_SIZE]; + struct timespec cp_birthdate; /* nanouptime */ + int cp_mac1_valid; + uint8_t cp_mac1_last[COOKIE_MAC_SIZE]; +}; + +struct cookie_checker { + struct ratelimit cc_ratelimit_v4; +#ifdef INET6 + struct ratelimit cc_ratelimit_v6; +#endif + + struct rwlock cc_key_lock; + uint8_t cc_mac1_key[COOKIE_KEY_SIZE]; + uint8_t cc_cookie_key[COOKIE_KEY_SIZE]; + + struct rwlock cc_secret_lock; + struct timespec cc_secret_birthdate; /* nanouptime */ + uint8_t cc_secret[COOKIE_SECRET_SIZE]; +}; + +void cookie_maker_init(struct cookie_maker *, const uint8_t[COOKIE_INPUT_SIZE]); +int cookie_checker_init(struct cookie_checker *, uma_zone_t); +void 
cookie_checker_update(struct cookie_checker *, + const uint8_t[COOKIE_INPUT_SIZE]); +void cookie_checker_deinit(struct cookie_checker *); +void cookie_checker_create_payload(struct cookie_checker *, + struct cookie_macs *cm, uint8_t[COOKIE_NONCE_SIZE], + uint8_t [COOKIE_ENCRYPTED_SIZE], struct sockaddr *); +int cookie_maker_consume_payload(struct cookie_maker *, + uint8_t[COOKIE_NONCE_SIZE], uint8_t[COOKIE_ENCRYPTED_SIZE]); +void cookie_maker_mac(struct cookie_maker *, struct cookie_macs *, + void *, size_t); +int cookie_checker_validate_macs(struct cookie_checker *, + struct cookie_macs *, void *, size_t, int, struct sockaddr *); + +#endif /* __COOKIE_H__ */ diff --git a/src/wg_noise.c b/src/wg_noise.c new file mode 100644 index 0000000..42dcc87 --- /dev/null +++ b/src/wg_noise.c @@ -0,0 +1,952 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie + */ + +#include +#include +#include +#include + +#include "support.h" +#include "wg_noise.h" + +/* Private functions */ +static struct noise_keypair * + noise_remote_keypair_allocate(struct noise_remote *); +static void + noise_remote_keypair_free(struct noise_remote *, + struct noise_keypair *); +static uint32_t noise_remote_handshake_index_get(struct noise_remote *); +static void noise_remote_handshake_index_drop(struct noise_remote *); + +static uint64_t noise_counter_send(struct noise_counter *); +static int noise_counter_recv(struct noise_counter *, uint64_t); + +static void noise_kdf(uint8_t *, uint8_t *, uint8_t *, const uint8_t *, + size_t, size_t, size_t, size_t, + const uint8_t [NOISE_HASH_LEN]); +static int noise_mix_dh( + uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); +static int noise_mix_ss( + uint8_t ck[NOISE_HASH_LEN], + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t ss[NOISE_PUBLIC_KEY_LEN]); +static void noise_mix_hash( + uint8_t [NOISE_HASH_LEN], + const uint8_t *, + size_t); +static void noise_mix_psk( + uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_SYMMETRIC_KEY_LEN], + const uint8_t [NOISE_SYMMETRIC_KEY_LEN]); +static void noise_param_init( + uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_HASH_LEN], + const uint8_t [NOISE_PUBLIC_KEY_LEN]); + +static void noise_msg_encrypt(uint8_t *, const uint8_t *, size_t, + uint8_t [NOISE_SYMMETRIC_KEY_LEN], + uint8_t [NOISE_HASH_LEN]); +static int noise_msg_decrypt(uint8_t *, const uint8_t *, size_t, + uint8_t [NOISE_SYMMETRIC_KEY_LEN], + uint8_t [NOISE_HASH_LEN]); +static void noise_msg_ephemeral( + uint8_t [NOISE_HASH_LEN], + uint8_t [NOISE_HASH_LEN], + const uint8_t src[NOISE_PUBLIC_KEY_LEN]); + +static void noise_tai64n_now(uint8_t [NOISE_TIMESTAMP_LEN]); +static int noise_timer_expired(struct timespec *, time_t, long); + +/* Set/Get noise parameters */ +void +noise_local_init(struct noise_local *l, struct noise_upcall *upcall) +{ + bzero(l, sizeof(*l)); + rw_init(&l->l_identity_lock, "noise_local_identity"); + l->l_upcall = *upcall; +} + +void +noise_local_lock_identity(struct noise_local *l) +{ + rw_enter_write(&l->l_identity_lock); +} + +void +noise_local_unlock_identity(struct noise_local *l) +{ + rw_exit_write(&l->l_identity_lock); +} + +int +noise_local_set_private(struct noise_local *l, + const uint8_t private[NOISE_PUBLIC_KEY_LEN]) +{ + rw_assert_wrlock(&l->l_identity_lock); + + memcpy(l->l_private, private, NOISE_PUBLIC_KEY_LEN); + 
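	/*
+	 * (the clamp below is the standard X25519 scalar conditioning;
+	 * as an illustrative restatement, not additional code, it is
+	 * equivalent to:
+	 *
+	 *	l->l_private[0] &= 248;
+	 *	l->l_private[31] &= 127;
+	 *	l->l_private[31] |= 64;
+	 *
+	 * clearing the cofactor bits and fixing the top bit, as required
+	 * of X25519 private keys)
+	 */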
curve25519_clamp_secret(l->l_private); + l->l_has_identity = curve25519_generate_public(l->l_public, private); + + return l->l_has_identity ? 0 : ENXIO; +} + +int +noise_local_keys(struct noise_local *l, uint8_t public[NOISE_PUBLIC_KEY_LEN], + uint8_t private[NOISE_PUBLIC_KEY_LEN]) +{ + int ret = 0; + rw_enter_read(&l->l_identity_lock); + if (l->l_has_identity) { + if (public != NULL) + memcpy(public, l->l_public, NOISE_PUBLIC_KEY_LEN); + if (private != NULL) + memcpy(private, l->l_private, NOISE_PUBLIC_KEY_LEN); + } else { + ret = ENXIO; + } + rw_exit_read(&l->l_identity_lock); + return ret; +} + +void +noise_remote_init(struct noise_remote *r, + const uint8_t public[NOISE_PUBLIC_KEY_LEN], struct noise_local *l) +{ + bzero(r, sizeof(*r)); + memcpy(r->r_public, public, NOISE_PUBLIC_KEY_LEN); + rw_init(&r->r_handshake_lock, "noise_handshake"); + rw_init(&r->r_keypair_lock, "noise_keypair"); + + SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[0], kp_entry); + SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[1], kp_entry); + SLIST_INSERT_HEAD(&r->r_unused_keypairs, &r->r_keypair[2], kp_entry); + + KASSERT(l != NULL, ("must provide local")); + r->r_local = l; + + rw_enter_write(&l->l_identity_lock); + noise_remote_precompute(r); + rw_exit_write(&l->l_identity_lock); +} + +int +noise_remote_set_psk(struct noise_remote *r, + const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + int same; + rw_enter_write(&r->r_handshake_lock); + same = !timingsafe_bcmp(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); + if (!same) { + memcpy(r->r_psk, psk, NOISE_SYMMETRIC_KEY_LEN); + } + rw_exit_write(&r->r_handshake_lock); + return same ? EEXIST : 0; +} + +int +noise_remote_keys(struct noise_remote *r, uint8_t public[NOISE_PUBLIC_KEY_LEN], + uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + static uint8_t null_psk[NOISE_SYMMETRIC_KEY_LEN]; + int ret; + + if (public != NULL) + memcpy(public, r->r_public, NOISE_PUBLIC_KEY_LEN); + + rw_enter_read(&r->r_handshake_lock); + if (psk != NULL) + memcpy(psk, r->r_psk, NOISE_SYMMETRIC_KEY_LEN); + ret = timingsafe_bcmp(r->r_psk, null_psk, NOISE_SYMMETRIC_KEY_LEN); + rw_exit_read(&r->r_handshake_lock); + + /* If r_psk != null_psk return 0, else ENOENT (no psk) */ + return ret ? 
0 : ENOENT; +} + +void +noise_remote_precompute(struct noise_remote *r) +{ + struct noise_local *l = r->r_local; + rw_assert_wrlock(&l->l_identity_lock); + if (!l->l_has_identity) + bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + else if (!curve25519(r->r_ss, l->l_private, r->r_public)) + bzero(r->r_ss, NOISE_PUBLIC_KEY_LEN); + + rw_enter_write(&r->r_handshake_lock); + noise_remote_handshake_index_drop(r); + explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); + rw_exit_write(&r->r_handshake_lock); +} + +/* Handshake functions */ +int +noise_create_initiation(struct noise_remote *r, uint32_t *s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) +{ + struct noise_handshake *hs = &r->r_handshake; + struct noise_local *l = r->r_local; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + int ret = EINVAL; + + rw_enter_read(&l->l_identity_lock); + rw_enter_write(&r->r_handshake_lock); + if (!l->l_has_identity) + goto error; + noise_param_init(hs->hs_ck, hs->hs_hash, r->r_public); + + /* e */ + curve25519_generate_secret(hs->hs_e); + if (curve25519_generate_public(ue, hs->hs_e) == 0) + goto error; + noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue); + + /* es */ + if (noise_mix_dh(hs->hs_ck, key, hs->hs_e, r->r_public) != 0) + goto error; + + /* s */ + noise_msg_encrypt(es, l->l_public, + NOISE_PUBLIC_KEY_LEN, key, hs->hs_hash); + + /* ss */ + if (noise_mix_ss(hs->hs_ck, key, r->r_ss) != 0) + goto error; + + /* {t} */ + noise_tai64n_now(ets); + noise_msg_encrypt(ets, ets, + NOISE_TIMESTAMP_LEN, key, hs->hs_hash); + + noise_remote_handshake_index_drop(r); + hs->hs_state = CREATED_INITIATION; + hs->hs_local_index = noise_remote_handshake_index_get(r); + *s_idx = hs->hs_local_index; + ret = 0; +error: + rw_exit_write(&r->r_handshake_lock); + rw_exit_read(&l->l_identity_lock); + explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN); + return ret; +} + +int +noise_consume_initiation(struct noise_local *l, struct noise_remote **rp, + uint32_t s_idx, uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]) +{ + struct noise_remote *r; + struct noise_handshake hs; + uint8_t key[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t timestamp[NOISE_TIMESTAMP_LEN]; + int ret = EINVAL; + + rw_enter_read(&l->l_identity_lock); + if (!l->l_has_identity) + goto error; + noise_param_init(hs.hs_ck, hs.hs_hash, l->l_public); + + /* e */ + noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue); + + /* es */ + if (noise_mix_dh(hs.hs_ck, key, l->l_private, ue) != 0) + goto error; + + /* s */ + if (noise_msg_decrypt(r_public, es, + NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) + goto error; + + /* Lookup the remote we received from */ + if ((r = l->l_upcall.u_remote_get(l->l_upcall.u_arg, r_public)) == NULL) + goto error; + + /* ss */ + if (noise_mix_ss(hs.hs_ck, key, r->r_ss) != 0) + goto error; + + /* {t} */ + if (noise_msg_decrypt(timestamp, ets, + NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0) + goto error; + + hs.hs_state = CONSUMED_INITIATION; + hs.hs_local_index = 0; + hs.hs_remote_index = s_idx; + memcpy(hs.hs_e, ue, NOISE_PUBLIC_KEY_LEN); + + /* We have successfully computed the same results, now we ensure that + * this is not an initiation replay, or a flood attack */ + rw_enter_write(&r->r_handshake_lock); + + /* Replay */ + if (memcmp(timestamp, r->r_timestamp, NOISE_TIMESTAMP_LEN) > 0) + 
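/* TAI64N timestamps are big-endian, so memcmp() compares them
+		 * chronologically; only a strictly newer timestamp is accepted,
+		 * which is what makes a replayed initiation detectable. */
+		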
memcpy(r->r_timestamp, timestamp, NOISE_TIMESTAMP_LEN);
+	else
+		goto error_set;
+	/* Flood attack */
+	if (noise_timer_expired(&r->r_last_init, 0, REJECT_INTERVAL))
+		getnanouptime(&r->r_last_init);
+	else
+		goto error_set;
+
+	/* Ok, we're happy to accept this initiation now */
+	noise_remote_handshake_index_drop(r);
+	r->r_handshake = hs;
+	*rp = r;
+	ret = 0;
+error_set:
+	rw_exit_write(&r->r_handshake_lock);
+error:
+	rw_exit_read(&l->l_identity_lock);
+	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+	explicit_bzero(&hs, sizeof(hs));
+	return ret;
+}
+
+int
+noise_create_response(struct noise_remote *r, uint32_t *s_idx, uint32_t *r_idx,
+    uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+	struct noise_handshake *hs = &r->r_handshake;
+	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+	uint8_t e[NOISE_PUBLIC_KEY_LEN];
+	int ret = EINVAL;
+
+	rw_enter_read(&r->r_local->l_identity_lock);
+	rw_enter_write(&r->r_handshake_lock);
+
+	if (hs->hs_state != CONSUMED_INITIATION)
+		goto error;
+
+	/* e */
+	curve25519_generate_secret(e);
+	if (curve25519_generate_public(ue, e) == 0)
+		goto error;
+	noise_msg_ephemeral(hs->hs_ck, hs->hs_hash, ue);
+
+	/* ee */
+	if (noise_mix_dh(hs->hs_ck, NULL, e, hs->hs_e) != 0)
+		goto error;
+
+	/* se */
+	if (noise_mix_dh(hs->hs_ck, NULL, e, r->r_public) != 0)
+		goto error;
+
+	/* psk */
+	noise_mix_psk(hs->hs_ck, hs->hs_hash, key, r->r_psk);
+
+	/* {} */
+	noise_msg_encrypt(en, NULL, 0, key, hs->hs_hash);
+
+	hs->hs_state = CREATED_RESPONSE;
+	hs->hs_local_index = noise_remote_handshake_index_get(r);
+	*r_idx = hs->hs_remote_index;
+	*s_idx = hs->hs_local_index;
+	ret = 0;
+error:
+	rw_exit_write(&r->r_handshake_lock);
+	rw_exit_read(&r->r_local->l_identity_lock);
+	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+	explicit_bzero(e, NOISE_PUBLIC_KEY_LEN);
+	return ret;
+}
+
+int
+noise_consume_response(struct noise_remote *r, uint32_t s_idx, uint32_t r_idx,
+    uint8_t ue[NOISE_PUBLIC_KEY_LEN], uint8_t en[0 + NOISE_AUTHTAG_LEN])
+{
+	struct noise_local *l = r->r_local;
+	struct noise_handshake hs;
+	uint8_t key[NOISE_SYMMETRIC_KEY_LEN];
+	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN];
+	int ret = EINVAL;
+
+	rw_enter_read(&l->l_identity_lock);
+	if (!l->l_has_identity)
+		goto error;
+
+	rw_enter_read(&r->r_handshake_lock);
+	hs = r->r_handshake;
+	memcpy(preshared_key, r->r_psk, NOISE_SYMMETRIC_KEY_LEN);
+	rw_exit_read(&r->r_handshake_lock);
+
+	if (hs.hs_state != CREATED_INITIATION ||
+	    hs.hs_local_index != r_idx)
+		goto error;
+
+	/* e */
+	noise_msg_ephemeral(hs.hs_ck, hs.hs_hash, ue);
+
+	/* ee */
+	if (noise_mix_dh(hs.hs_ck, NULL, hs.hs_e, ue) != 0)
+		goto error;
+
+	/* se */
+	if (noise_mix_dh(hs.hs_ck, NULL, l->l_private, ue) != 0)
+		goto error;
+
+	/* psk */
+	noise_mix_psk(hs.hs_ck, hs.hs_hash, key, preshared_key);
+
+	/* {} */
+	if (noise_msg_decrypt(NULL, en,
+	    0 + NOISE_AUTHTAG_LEN, key, hs.hs_hash) != 0)
+		goto error;
+
+	hs.hs_remote_index = s_idx;
+
+	rw_enter_write(&r->r_handshake_lock);
+	if (r->r_handshake.hs_state == hs.hs_state &&
+	    r->r_handshake.hs_local_index == hs.hs_local_index) {
+		r->r_handshake = hs;
+		r->r_handshake.hs_state = CONSUMED_RESPONSE;
+		ret = 0;
+	}
+	rw_exit_write(&r->r_handshake_lock);
+error:
+	rw_exit_read(&l->l_identity_lock);
+	explicit_bzero(&hs, sizeof(hs));
+	explicit_bzero(key, NOISE_SYMMETRIC_KEY_LEN);
+	return ret;
+}
+
+int
+noise_remote_begin_session(struct noise_remote *r)
+{
+	struct noise_handshake *hs = &r->r_handshake;
+	struct noise_keypair kp, *next, *current, *previous;
+
+	
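/* The two KDF calls below differ only in output order: the
+	 * initiator's send key is the responder's receive key and vice
+	 * versa. */
+	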
rw_enter_write(&r->r_handshake_lock); + + /* We now derive the keypair from the handshake */ + if (hs->hs_state == CONSUMED_RESPONSE) { + kp.kp_is_initiator = 1; + noise_kdf(kp.kp_send, kp.kp_recv, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + hs->hs_ck); + } else if (hs->hs_state == CREATED_RESPONSE) { + kp.kp_is_initiator = 0; + noise_kdf(kp.kp_recv, kp.kp_send, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + hs->hs_ck); + } else { + rw_exit_write(&r->r_handshake_lock); + return EINVAL; + } + + kp.kp_valid = 1; + kp.kp_local_index = hs->hs_local_index; + kp.kp_remote_index = hs->hs_remote_index; + getnanouptime(&kp.kp_birthdate); + bzero(&kp.kp_ctr, sizeof(kp.kp_ctr)); + rw_init(&kp.kp_ctr.c_lock, "noise_counter"); + + /* Now we need to add_new_keypair */ + rw_enter_write(&r->r_keypair_lock); + next = r->r_next; + current = r->r_current; + previous = r->r_previous; + + if (kp.kp_is_initiator) { + if (next != NULL) { + r->r_next = NULL; + r->r_previous = next; + noise_remote_keypair_free(r, current); + } else { + r->r_previous = current; + } + + noise_remote_keypair_free(r, previous); + + r->r_current = noise_remote_keypair_allocate(r); + *r->r_current = kp; + } else { + noise_remote_keypair_free(r, next); + r->r_previous = NULL; + noise_remote_keypair_free(r, previous); + + r->r_next = noise_remote_keypair_allocate(r); + *r->r_next = kp; + } + rw_exit_write(&r->r_keypair_lock); + + explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); + rw_exit_write(&r->r_handshake_lock); + + explicit_bzero(&kp, sizeof(kp)); + return 0; +} + +void +noise_remote_clear(struct noise_remote *r) +{ + rw_enter_write(&r->r_handshake_lock); + noise_remote_handshake_index_drop(r); + explicit_bzero(&r->r_handshake, sizeof(r->r_handshake)); + rw_exit_write(&r->r_handshake_lock); + + rw_enter_write(&r->r_keypair_lock); + noise_remote_keypair_free(r, r->r_next); + noise_remote_keypair_free(r, r->r_current); + noise_remote_keypair_free(r, r->r_previous); + r->r_next = NULL; + r->r_current = NULL; + r->r_previous = NULL; + rw_exit_write(&r->r_keypair_lock); +} + +void +noise_remote_expire_current(struct noise_remote *r) +{ + rw_enter_write(&r->r_keypair_lock); + if (r->r_next != NULL) + r->r_next->kp_valid = 0; + if (r->r_current != NULL) + r->r_current->kp_valid = 0; + rw_exit_write(&r->r_keypair_lock); +} + +int +noise_remote_ready(struct noise_remote *r) +{ + struct noise_keypair *kp; + int ret; + + rw_enter_read(&r->r_keypair_lock); + /* kp_ctr isn't locked here, we're happy to accept a racy read. */ + if ((kp = r->r_current) == NULL || + !kp->kp_valid || + noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || + kp->kp_ctr.c_send >= REJECT_AFTER_MESSAGES) + ret = EINVAL; + else + ret = 0; + rw_exit_read(&r->r_keypair_lock); + return ret; +} + +int +noise_remote_encrypt(struct noise_remote *r, uint32_t *r_idx, uint64_t *nonce, + uint8_t *buf, size_t buflen) +{ + struct noise_keypair *kp; + int ret = EINVAL; + + rw_enter_read(&r->r_keypair_lock); + if ((kp = r->r_current) == NULL) + goto error; + + /* We confirm that our values are within our tolerances. We want: + * - a valid keypair + * - our keypair to be less than REJECT_AFTER_TIME seconds old + * - our receive counter to be less than REJECT_AFTER_MESSAGES + * - our send counter to be less than REJECT_AFTER_MESSAGES + * + * kp_ctr isn't locked here, we're happy to accept a racy read. 
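The
+	 * nonce is still unique, though: noise_counter_send() increments
+	 * c_send under c_lock, so no two encryptions share a nonce.
+	 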
*/ + if (!kp->kp_valid || + noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES || + ((*nonce = noise_counter_send(&kp->kp_ctr)) > REJECT_AFTER_MESSAGES)) + goto error; + + /* We encrypt into the same buffer, so the caller must ensure that buf + * has NOISE_AUTHTAG_LEN bytes to store the MAC. The nonce and index + * are passed back out to the caller through the provided data pointer. */ + *r_idx = kp->kp_remote_index; + chacha20poly1305_encrypt(buf, buf, buflen, + NULL, 0, *nonce, kp->kp_send); + + /* If our values are still within tolerances, but we are approaching + * the tolerances, we notify the caller with ESTALE that they should + * establish a new keypair. The current keypair can continue to be used + * until the tolerances are hit. We notify if: + * - our send counter is valid and not less than REKEY_AFTER_MESSAGES + * - we're the initiator and our keypair is older than + * REKEY_AFTER_TIME seconds */ + ret = ESTALE; + if ((kp->kp_valid && *nonce >= REKEY_AFTER_MESSAGES) || + (kp->kp_is_initiator && + noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME, 0))) + goto error; + + ret = 0; +error: + rw_exit_read(&r->r_keypair_lock); + return ret; +} + +int +noise_remote_decrypt(struct noise_remote *r, uint32_t r_idx, uint64_t nonce, + uint8_t *buf, size_t buflen) +{ + struct noise_keypair *kp; + int ret = EINVAL; + + /* We retrieve the keypair corresponding to the provided index. We + * attempt the current keypair first as that is most likely. We also + * want to make sure that the keypair is valid as it would be + * catastrophic to decrypt against a zero'ed keypair. */ + rw_enter_read(&r->r_keypair_lock); + + if (r->r_current != NULL && r->r_current->kp_local_index == r_idx) { + kp = r->r_current; + } else if (r->r_previous != NULL && r->r_previous->kp_local_index == r_idx) { + kp = r->r_previous; + } else if (r->r_next != NULL && r->r_next->kp_local_index == r_idx) { + kp = r->r_next; + } else { + goto error; + } + + /* We confirm that our values are within our tolerances. These values + * are the same as the encrypt routine. + * + * kp_ctr isn't locked here, we're happy to accept a racy read. */ + if (noise_timer_expired(&kp->kp_birthdate, REJECT_AFTER_TIME, 0) || + kp->kp_ctr.c_recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* Decrypt, then validate the counter. We don't want to validate the + * counter before decrypting as we do not know the message is authentic + * prior to decryption. */ + if (chacha20poly1305_decrypt(buf, buf, buflen, + NULL, 0, nonce, kp->kp_recv) == 0) + goto error; + + if (noise_counter_recv(&kp->kp_ctr, nonce) != 0) + goto error; + + /* If we've received the handshake confirming data packet then move the + * next keypair into current. If we do slide the next keypair in, then + * we skip the REKEY_AFTER_TIME_RECV check. This is safe to do as a + * data packet can't confirm a session that we are an INITIATOR of. */ + if (kp == r->r_next) { + rw_exit_read(&r->r_keypair_lock); + rw_enter_write(&r->r_keypair_lock); + if (kp == r->r_next && kp->kp_local_index == r_idx) { + noise_remote_keypair_free(r, r->r_previous); + r->r_previous = r->r_current; + r->r_current = r->r_next; + r->r_next = NULL; + + ret = ECONNRESET; + goto error; + } + rw_enter(&r->r_keypair_lock, RW_DOWNGRADE); + } + + /* Similar to when we encrypt, we want to notify the caller when we + * are approaching our tolerances. We notify if: + * - we're the initiator and the current keypair is older than + * REKEY_AFTER_TIME_RECV seconds. 
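Only
+	 * sessions we initiated rekey on receive, which keeps both peers
+	 * from starting new handshakes at the same moment.
+	 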
*/ + ret = ESTALE; + kp = r->r_current; + if (kp != NULL && + kp->kp_valid && + kp->kp_is_initiator && + noise_timer_expired(&kp->kp_birthdate, REKEY_AFTER_TIME_RECV, 0)) + goto error; + + ret = 0; + +error: + rw_exit(&r->r_keypair_lock); + return ret; +} + +/* Private functions - these should not be called outside this file under any + * circumstances. */ +static struct noise_keypair * +noise_remote_keypair_allocate(struct noise_remote *r) +{ + struct noise_keypair *kp; + kp = SLIST_FIRST(&r->r_unused_keypairs); + SLIST_REMOVE_HEAD(&r->r_unused_keypairs, kp_entry); + return kp; +} + +static void +noise_remote_keypair_free(struct noise_remote *r, struct noise_keypair *kp) +{ + struct noise_upcall *u = &r->r_local->l_upcall; + if (kp != NULL) { + SLIST_INSERT_HEAD(&r->r_unused_keypairs, kp, kp_entry); + u->u_index_drop(u->u_arg, kp->kp_local_index); + bzero(kp->kp_send, sizeof(kp->kp_send)); + bzero(kp->kp_recv, sizeof(kp->kp_recv)); + } +} + +static uint32_t +noise_remote_handshake_index_get(struct noise_remote *r) +{ + struct noise_upcall *u = &r->r_local->l_upcall; + return u->u_index_set(u->u_arg, r); +} + +static void +noise_remote_handshake_index_drop(struct noise_remote *r) +{ + struct noise_handshake *hs = &r->r_handshake; + struct noise_upcall *u = &r->r_local->l_upcall; + rw_assert_wrlock(&r->r_handshake_lock); + if (hs->hs_state != HS_ZEROED) + u->u_index_drop(u->u_arg, hs->hs_local_index); +} + +static uint64_t +noise_counter_send(struct noise_counter *ctr) +{ + uint64_t ret; + rw_enter_write(&ctr->c_lock); + ret = ctr->c_send++; + rw_exit_write(&ctr->c_lock); + return ret; +} + +static int +noise_counter_recv(struct noise_counter *ctr, uint64_t recv) +{ + uint64_t i, top, index_recv, index_ctr; + unsigned long bit; + int ret = EEXIST; + + rw_enter_write(&ctr->c_lock); + + /* Check that the recv counter is valid */ + if (ctr->c_recv >= REJECT_AFTER_MESSAGES || + recv >= REJECT_AFTER_MESSAGES) + goto error; + + /* If the packet is out of the window, invalid */ + if (recv + COUNTER_WINDOW_SIZE < ctr->c_recv) + goto error; + + /* If the new counter is ahead of the current counter, we'll need to + * zero out the bitmap that has previously been used */ + index_recv = recv / COUNTER_BITS; + index_ctr = ctr->c_recv / COUNTER_BITS; + + if (recv > ctr->c_recv) { + top = MIN(index_recv - index_ctr, COUNTER_NUM); + for (i = 1; i <= top; i++) + ctr->c_backtrack[ + (i + index_ctr) & (COUNTER_NUM - 1)] = 0; + ctr->c_recv = recv; + } + + index_recv %= COUNTER_NUM; + bit = 1ul << (recv % COUNTER_BITS); + + if (ctr->c_backtrack[index_recv] & bit) + goto error; + + ctr->c_backtrack[index_recv] |= bit; + + ret = 0; +error: + rw_exit_write(&ctr->c_lock); + return ret; +} + +static void +noise_kdf(uint8_t *a, uint8_t *b, uint8_t *c, const uint8_t *x, + size_t a_len, size_t b_len, size_t c_len, size_t x_len, + const uint8_t ck[NOISE_HASH_LEN]) +{ + uint8_t out[BLAKE2S_HASH_SIZE + 1]; + uint8_t sec[BLAKE2S_HASH_SIZE]; + +#ifdef DIAGNOSTIC + MPASS(a_len <= BLAKE2S_HASH_SIZE && b_len <= BLAKE2S_HASH_SIZE && + c_len <= BLAKE2S_HASH_SIZE); + MPASS(!(b || b_len || c || c_len) || (a && a_len)); + MPASS(!(c || c_len) || (b && b_len)); +#endif + + /* Extract entropy from "x" into sec */ + blake2s_hmac(sec, x, ck, BLAKE2S_HASH_SIZE, x_len, NOISE_HASH_LEN); + + if (a == NULL || a_len == 0) + goto out; + + /* Expand first key: key = sec, data = 0x1 */ + out[0] = 1; + blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, 1, BLAKE2S_HASH_SIZE); + memcpy(a, out, a_len); + + if (b == NULL || b_len == 0) + goto out; + + /* 
Expand second key: key = sec, data = "a" || 0x2 */ + out[BLAKE2S_HASH_SIZE] = 2; + blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, + BLAKE2S_HASH_SIZE); + memcpy(b, out, b_len); + + if (c == NULL || c_len == 0) + goto out; + + /* Expand third key: key = sec, data = "b" || 0x3 */ + out[BLAKE2S_HASH_SIZE] = 3; + blake2s_hmac(out, out, sec, BLAKE2S_HASH_SIZE, BLAKE2S_HASH_SIZE + 1, + BLAKE2S_HASH_SIZE); + memcpy(c, out, c_len); + +out: + /* Clear sensitive data from stack */ + explicit_bzero(sec, BLAKE2S_HASH_SIZE); + explicit_bzero(out, BLAKE2S_HASH_SIZE + 1); +} + +static int +noise_mix_dh(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t private[NOISE_PUBLIC_KEY_LEN], + const uint8_t public[NOISE_PUBLIC_KEY_LEN]) +{ + uint8_t dh[NOISE_PUBLIC_KEY_LEN]; + + if (!curve25519(dh, private, public)) + return EINVAL; + noise_kdf(ck, key, NULL, dh, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); + explicit_bzero(dh, NOISE_PUBLIC_KEY_LEN); + return 0; +} + +static int +noise_mix_ss(uint8_t ck[NOISE_HASH_LEN], uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t ss[NOISE_PUBLIC_KEY_LEN]) +{ + static uint8_t null_point[NOISE_PUBLIC_KEY_LEN]; + if (timingsafe_bcmp(ss, null_point, NOISE_PUBLIC_KEY_LEN) == 0) + return ENOENT; + noise_kdf(ck, key, NULL, ss, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, ck); + return 0; +} + +static void +noise_mix_hash(uint8_t hash[NOISE_HASH_LEN], const uint8_t *src, + size_t src_len) +{ + struct blake2s_state blake; + + blake2s_init(&blake, NOISE_HASH_LEN); + blake2s_update(&blake, hash, NOISE_HASH_LEN); + blake2s_update(&blake, src, src_len); + blake2s_final(&blake, hash); +} + +static void +noise_mix_psk(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], + const uint8_t psk[NOISE_SYMMETRIC_KEY_LEN]) +{ + uint8_t tmp[NOISE_HASH_LEN]; + + noise_kdf(ck, tmp, key, psk, + NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, + NOISE_SYMMETRIC_KEY_LEN, ck); + noise_mix_hash(hash, tmp, NOISE_HASH_LEN); + explicit_bzero(tmp, NOISE_HASH_LEN); +} + +static void +noise_param_init(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + const uint8_t s[NOISE_PUBLIC_KEY_LEN]) +{ + struct blake2s_state blake; + + blake2s(ck, (uint8_t *)NOISE_HANDSHAKE_NAME, NULL, + NOISE_HASH_LEN, strlen(NOISE_HANDSHAKE_NAME), 0); + blake2s_init(&blake, NOISE_HASH_LEN); + blake2s_update(&blake, ck, NOISE_HASH_LEN); + blake2s_update(&blake, (uint8_t *)NOISE_IDENTIFIER_NAME, + strlen(NOISE_IDENTIFIER_NAME)); + blake2s_final(&blake, hash); + + noise_mix_hash(hash, s, NOISE_PUBLIC_KEY_LEN); +} + +static void +noise_msg_encrypt(uint8_t *dst, const uint8_t *src, size_t src_len, + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN]) +{ + /* Nonce always zero for Noise_IK */ + chacha20poly1305_encrypt(dst, src, src_len, + hash, NOISE_HASH_LEN, 0, key); + noise_mix_hash(hash, dst, src_len + NOISE_AUTHTAG_LEN); +} + +static int +noise_msg_decrypt(uint8_t *dst, const uint8_t *src, size_t src_len, + uint8_t key[NOISE_SYMMETRIC_KEY_LEN], uint8_t hash[NOISE_HASH_LEN]) +{ + /* Nonce always zero for Noise_IK */ + if (!chacha20poly1305_decrypt(dst, src, src_len, + hash, NOISE_HASH_LEN, 0, key)) + return EINVAL; + noise_mix_hash(hash, src, src_len); + return 0; +} + +static void +noise_msg_ephemeral(uint8_t ck[NOISE_HASH_LEN], uint8_t hash[NOISE_HASH_LEN], + const uint8_t src[NOISE_PUBLIC_KEY_LEN]) +{ + noise_mix_hash(hash, src, 
NOISE_PUBLIC_KEY_LEN); + noise_kdf(ck, NULL, NULL, src, NOISE_HASH_LEN, 0, 0, + NOISE_PUBLIC_KEY_LEN, ck); +} + +static void +noise_tai64n_now(uint8_t output[NOISE_TIMESTAMP_LEN]) +{ + struct timespec time; + uint64_t sec; + uint32_t nsec; + + getnanotime(&time); + + /* Round down the nsec counter to limit precise timing leak. */ + time.tv_nsec &= REJECT_INTERVAL_MASK; + + /* https://cr.yp.to/libtai/tai64.html */ + sec = htobe64(0x400000000000000aULL + time.tv_sec); + nsec = htobe32(time.tv_nsec); + + /* memcpy to output buffer, assuming output could be unaligned. */ + memcpy(output, &sec, sizeof(sec)); + memcpy(output + sizeof(sec), &nsec, sizeof(nsec)); +} + +static int +noise_timer_expired(struct timespec *birthdate, time_t sec, long nsec) +{ + struct timespec uptime; + struct timespec expire = { .tv_sec = sec, .tv_nsec = nsec }; + + /* We don't really worry about a zeroed birthdate, to avoid the extra + * check on every encrypt/decrypt. This does mean that r_last_init + * check may fail if getnanouptime is < REJECT_INTERVAL from 0. */ + + getnanouptime(&uptime); + timespecadd(birthdate, &expire, &expire); + return timespeccmp(&uptime, &expire, >) ? ETIMEDOUT : 0; +} diff --git a/src/wg_noise.h b/src/wg_noise.h new file mode 100644 index 0000000..198ddd1 --- /dev/null +++ b/src/wg_noise.h @@ -0,0 +1,180 @@ +/* SPDX-License-Identifier: ISC + * + * Copyright (C) 2015-2021 Jason A. Donenfeld . All Rights Reserved. + * Copyright (C) 2019-2021 Matt Dunwoodie + */ + +#ifndef __NOISE_H__ +#define __NOISE_H__ + +#include +#include +#include + +#include "crypto.h" + +#define NOISE_PUBLIC_KEY_LEN CURVE25519_KEY_SIZE +#define NOISE_SYMMETRIC_KEY_LEN CHACHA20POLY1305_KEY_SIZE +#define NOISE_TIMESTAMP_LEN (sizeof(uint64_t) + sizeof(uint32_t)) +#define NOISE_AUTHTAG_LEN CHACHA20POLY1305_AUTHTAG_SIZE +#define NOISE_HASH_LEN BLAKE2S_HASH_SIZE + +/* Protocol string constants */ +#define NOISE_HANDSHAKE_NAME "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" +#define NOISE_IDENTIFIER_NAME "WireGuard v1 zx2c4 Jason@zx2c4.com" + +/* Constants for the counter */ +#define COUNTER_BITS_TOTAL 8192 +#define COUNTER_BITS (sizeof(unsigned long) * 8) +#define COUNTER_NUM (COUNTER_BITS_TOTAL / COUNTER_BITS) +#define COUNTER_WINDOW_SIZE (COUNTER_BITS_TOTAL - COUNTER_BITS) + +/* Constants for the keypair */ +#define REKEY_AFTER_MESSAGES (1ull << 60) +#define REJECT_AFTER_MESSAGES (UINT64_MAX - COUNTER_WINDOW_SIZE - 1) +#define REKEY_AFTER_TIME 120 +#define REKEY_AFTER_TIME_RECV 165 +#define REJECT_AFTER_TIME 180 +#define REJECT_INTERVAL (1000000000 / 50) /* fifty times per sec */ +/* 24 = floor(log2(REJECT_INTERVAL)) */ +#define REJECT_INTERVAL_MASK (~((1ull<<24)-1)) + +enum noise_state_hs { + HS_ZEROED = 0, + CREATED_INITIATION, + CONSUMED_INITIATION, + CREATED_RESPONSE, + CONSUMED_RESPONSE, +}; + +struct noise_handshake { + enum noise_state_hs hs_state; + uint32_t hs_local_index; + uint32_t hs_remote_index; + uint8_t hs_e[NOISE_PUBLIC_KEY_LEN]; + uint8_t hs_hash[NOISE_HASH_LEN]; + uint8_t hs_ck[NOISE_HASH_LEN]; +}; + +struct noise_counter { + struct rwlock c_lock; + uint64_t c_send; + uint64_t c_recv; + unsigned long c_backtrack[COUNTER_NUM]; +}; + +struct noise_keypair { + SLIST_ENTRY(noise_keypair) kp_entry; + int kp_valid; + int kp_is_initiator; + uint32_t kp_local_index; + uint32_t kp_remote_index; + uint8_t kp_send[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t kp_recv[NOISE_SYMMETRIC_KEY_LEN]; + struct timespec kp_birthdate; /* nanouptime */ + struct noise_counter kp_ctr; +}; + +struct noise_remote { + uint8_t 
r_public[NOISE_PUBLIC_KEY_LEN]; + struct noise_local *r_local; + uint8_t r_ss[NOISE_PUBLIC_KEY_LEN]; + + struct rwlock r_handshake_lock; + struct noise_handshake r_handshake; + uint8_t r_psk[NOISE_SYMMETRIC_KEY_LEN]; + uint8_t r_timestamp[NOISE_TIMESTAMP_LEN]; + struct timespec r_last_init; /* nanouptime */ + + struct rwlock r_keypair_lock; + SLIST_HEAD(,noise_keypair) r_unused_keypairs; + struct noise_keypair *r_next, *r_current, *r_previous; + struct noise_keypair r_keypair[3]; /* 3: next, current, previous. */ + +}; + +struct noise_local { + struct rwlock l_identity_lock; + int l_has_identity; + uint8_t l_public[NOISE_PUBLIC_KEY_LEN]; + uint8_t l_private[NOISE_PUBLIC_KEY_LEN]; + + struct noise_upcall { + void *u_arg; + struct noise_remote * + (*u_remote_get)(void *, uint8_t[NOISE_PUBLIC_KEY_LEN]); + uint32_t + (*u_index_set)(void *, struct noise_remote *); + void (*u_index_drop)(void *, uint32_t); + } l_upcall; +}; + +/* Set/Get noise parameters */ +void noise_local_init(struct noise_local *, struct noise_upcall *); +void noise_local_lock_identity(struct noise_local *); +void noise_local_unlock_identity(struct noise_local *); +int noise_local_set_private(struct noise_local *, + const uint8_t[NOISE_PUBLIC_KEY_LEN]); +int noise_local_keys(struct noise_local *, uint8_t[NOISE_PUBLIC_KEY_LEN], + uint8_t[NOISE_PUBLIC_KEY_LEN]); + +void noise_remote_init(struct noise_remote *, + const uint8_t[NOISE_PUBLIC_KEY_LEN], struct noise_local *); +int noise_remote_set_psk(struct noise_remote *, + const uint8_t[NOISE_SYMMETRIC_KEY_LEN]); +int noise_remote_keys(struct noise_remote *, uint8_t[NOISE_PUBLIC_KEY_LEN], + uint8_t[NOISE_SYMMETRIC_KEY_LEN]); + +/* Should be called anytime noise_local_set_private is called */ +void noise_remote_precompute(struct noise_remote *); + +/* Cryptographic functions */ +int noise_create_initiation( + struct noise_remote *, + uint32_t *s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]); + +int noise_consume_initiation( + struct noise_local *, + struct noise_remote **, + uint32_t s_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN], + uint8_t ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN]); + +int noise_create_response( + struct noise_remote *, + uint32_t *s_idx, + uint32_t *r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]); + +int noise_consume_response( + struct noise_remote *, + uint32_t s_idx, + uint32_t r_idx, + uint8_t ue[NOISE_PUBLIC_KEY_LEN], + uint8_t en[0 + NOISE_AUTHTAG_LEN]); + +int noise_remote_begin_session(struct noise_remote *); +void noise_remote_clear(struct noise_remote *); +void noise_remote_expire_current(struct noise_remote *); + +int noise_remote_ready(struct noise_remote *); + +int noise_remote_encrypt( + struct noise_remote *, + uint32_t *r_idx, + uint64_t *nonce, + uint8_t *buf, + size_t buflen); +int noise_remote_decrypt( + struct noise_remote *, + uint32_t r_idx, + uint64_t nonce, + uint8_t *buf, + size_t buflen); + +#endif /* __NOISE_H__ */ diff --git a/tests/if_wg_test.sh b/tests/if_wg_test.sh new file mode 100755 index 0000000..69e90d9 --- /dev/null +++ b/tests/if_wg_test.sh @@ -0,0 +1,164 @@ +# $FreeBSD$ +# +# SPDX-License-Identifier: BSD-2-Clause-FreeBSD +# +# Copyright (c) 2021 The FreeBSD Foundation + +. 
$(atf_get_srcdir)/../common/vnet.subr
+
+atf_test_case "wg_basic" "cleanup"
+wg_basic_head()
+{
+	atf_set descr 'Create a wg(4) tunnel over an epair and pass traffic between jails'
+	atf_set require.user root
+}
+
+wg_basic_body()
+{
+	local epair pri1 pri2 pub1 pub2 wg1 wg2
+	local endpoint1 endpoint2 tunnel1 tunnel2
+
+	kldload -n if_wg
+
+	pri1=$(openssl rand -base64 32)
+	pri2=$(openssl rand -base64 32)
+
+	endpoint1=192.168.2.1
+	endpoint2=192.168.2.2
+	tunnel1=169.254.0.1
+	tunnel2=169.254.0.2
+
+	epair=$(vnet_mkepair)
+
+	vnet_init
+
+	vnet_mkjail wgtest1 ${epair}a
+	vnet_mkjail wgtest2 ${epair}b
+
+	# Workaround for PR 254212.
+	jexec wgtest1 ifconfig lo0 up
+	jexec wgtest2 ifconfig lo0 up
+
+	jexec wgtest1 ifconfig ${epair}a $endpoint1 up
+	jexec wgtest2 ifconfig ${epair}b $endpoint2 up
+
+	wg1=$(jexec wgtest1 ifconfig wg create listen-port 12345 private-key "$pri1")
+	pub1=$(jexec wgtest1 ifconfig $wg1 | awk '/public-key:/ {print $2}')
+	wg2=$(jexec wgtest2 ifconfig wg create listen-port 12345 private-key "$pri2")
+	pub2=$(jexec wgtest2 ifconfig $wg2 | awk '/public-key:/ {print $2}')
+
+	atf_check -s exit:0 -o ignore \
+	    jexec wgtest1 ifconfig $wg1 peer public-key "$pub2" \
+	    endpoint ${endpoint2}:12345 allowed-ips ${tunnel2}/32
+	atf_check -s exit:0 \
+	    jexec wgtest1 ifconfig $wg1 inet $tunnel1 up
+
+	atf_check -s exit:0 -o ignore \
+	    jexec wgtest2 ifconfig $wg2 peer public-key "$pub1" \
+	    endpoint ${endpoint1}:12345 allowed-ips ${tunnel1}/32
+	atf_check -s exit:0 \
+	    jexec wgtest2 ifconfig $wg2 inet $tunnel2 up
+
+	# Generous timeout since the handshake takes some time.
+	atf_check -s exit:0 -o ignore jexec wgtest1 ping -o -t 5 -i 0.25 $tunnel2
+	atf_check -s exit:0 -o ignore jexec wgtest2 ping -o -t 5 -i 0.25 $tunnel1
+}
+
+wg_basic_cleanup()
+{
+	vnet_cleanup
+}
+
+# The kernel is expected to silently ignore any attempt to add a peer with a
+# public key identical to the host's.
+atf_test_case "wg_key_peerdev_shared" "cleanup"
+wg_key_peerdev_shared_head()
+{
+	atf_set descr 'Create a wg(4) interface with a shared pubkey between device and a peer'
+	atf_set require.user root
+}
+
+wg_key_peerdev_shared_body()
+{
+	local epair pri1 pub1 wg1
+	local endpoint1 tunnel1
+
+	kldload -n if_wg
+
+	pri1=$(openssl rand -base64 32)
+
+	endpoint1=192.168.2.1
+	tunnel1=169.254.0.1
+
+	vnet_mkjail wgtest1
+
+	wg1=$(jexec wgtest1 ifconfig wg create listen-port 12345 private-key "$pri1")
+	pub1=$(jexec wgtest1 ifconfig $wg1 | awk '/public-key:/ {print $2}')
+
+	atf_check -s exit:0 \
+	    jexec wgtest1 ifconfig ${wg1} peer public-key "${pub1}" \
+	    allowed-ips "${tunnel1}/32"
+
+	atf_check -o empty jexec wgtest1 ifconfig ${wg1} peers
+}
+
+wg_key_peerdev_shared_cleanup()
+{
+	vnet_cleanup
+}
+
+# When a wg(4) interface is assigned a private key whose corresponding public
+# key is already present on a peer, the kernel is expected to deconfigure the
+# peer to resolve the conflict.
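+#
+# A rough sketch of the sequence exercised below, using this file's names:
+# wg2 is created with private key $pri2 and given wg1's public key $pub1 as
+# a peer; after `wg set wg2 private-key pri1`, wg2's own public key collides
+# with that peer's, so its peer list is expected to come back empty.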
+atf_test_case "wg_key_peerdev_makeshared" "cleanup" +wg_key_peerdev_makeshared_head() +{ + atf_set descr 'Create a wg(4) interface and assign peer key to device' + atf_set require.progs wg +} + +wg_key_peerdev_makeshared_body() +{ + local epair pri1 pub1 pri2 wg1 wg2 + local endpoint1 tunnel1 + + kldload -n if_wg + + pri1=$(openssl rand -base64 32) + pri2=$(openssl rand -base64 32) + + endpoint1=192.168.2.1 + tunnel1=169.254.0.1 + + vnet_mkjail wgtest1 + + wg1=$(jexec wgtest1 ifconfig wg create listen-port 12345 private-key "$pri1") + pub1=$(jexec wgtest1 ifconfig $wg1 | awk '/public-key:/ {print $2}') + + wg2=$(jexec wgtest1 ifconfig wg create listen-port 12345 private-key "$pri2") + + atf_check -s exit:0 -o ignore \ + jexec wgtest1 ifconfig ${wg2} peer public-key "${pub1}" \ + allowed-ips "${tunnel1}/32" + + atf_check -o not-empty jexec wgtest1 ifconfig ${wg2} peers + + jexec wgtest1 sh -c "echo '${pri1}' > pri1" + + atf_check -s exit:0 \ + jexec wgtest1 wg set ${wg2} private-key pri1 + + atf_check -o empty jexec wgtest1 ifconfig ${wg2} peers +} + +wg_key_peerdev_makeshared_cleanup() +{ + vnet_cleanup +} + +atf_init_test_cases() +{ + atf_add_test_case "wg_basic" + atf_add_test_case "wg_key_peerdev_shared" + atf_add_test_case "wg_key_peerdev_makeshared" +} diff --git a/tests/netns.sh b/tests/netns.sh new file mode 100755 index 0000000..7502871 --- /dev/null +++ b/tests/netns.sh @@ -0,0 +1,643 @@ +#!/usr/bin/env bash +# +# SPDX-License-Identifier: GPL-2.0 +# +# Copyright (C) 2015-2019 Jason A. Donenfeld . All Rights Reserved. +# +# This script tests the below topology: +# +# ┌─────────────────────┐ ┌──────────────────────────────────┐ ┌─────────────────────┐ +# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ +# │ │ │ │ │ │ +# │┌────────┐ │ │ ┌────────┐ │ │ ┌────────┐│ +# ││ wg0 │───────────┼───┼────────────│ lo │────────────┼───┼───────────│ wg0 ││ +# │├────────┴──────────┐│ │ ┌───────┴────────┴────────┐ │ │┌──────────┴────────┤│ +# ││192.168.241.1/24 ││ │ │(ns1) (ns2) │ │ ││192.168.241.2/24 ││ +# ││fd00::1/24 ││ │ │127.0.0.1:1 127.0.0.1:2│ │ ││fd00::2/24 ││ +# │└───────────────────┘│ │ │[::]:1 [::]:2 │ │ │└───────────────────┘│ +# └─────────────────────┘ │ └─────────────────────────┘ │ └─────────────────────┘ +# └──────────────────────────────────┘ +# +# After the topology is prepared we run a series of TCP/UDP iperf3 tests between the +# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg0 +# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further +# details on how this is accomplished. +set -e + +# Needs iperf3 + +exec 3>&1 +export LANG=C +export WG_HIDE_KEYS=never +jail0="wg-test-$$-0" +jail1="wg-test-$$-1" +jail2="wg-test-$$-2" +pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } +pp() { pretty "" "$*"; "$@"; } +maybe_exec() { if [[ $BASHPID -eq $$ ]]; then "$@"; else exec "$@"; fi; } +j0() { pretty 0 "$*"; maybe_exec jexec $jail0 "$@"; } +j1() { pretty 1 "$*"; maybe_exec jexec $jail1 "$@"; } +j2() { pretty 2 "$*"; maybe_exec jexec $jail2 "$@"; } +ifconfig0() { j0 ifconfig "$@"; } +ifconfig1() { j1 ifconfig "$@"; } +ifconfig2() { j2 ifconfig "$@"; } +sleep() { read -t "$1" -N 1 || true; } +#waitiperf() { pretty "${1//*-}" "wait for iperf:${3:-5201} pid $2"; while [[ $(ss -N "$1" -tlpH "sport = ${3:-5201}") != *\"iperf3\",pid=$2,fd=* ]]; do sleep 0.1; done; } +waitiperf() { pretty "${1//*-}" "wait for iperf:${3:-5201} pid $2"; while ! 
sockstat -qj "$1" -ql -P tcp -p "${3:-5201}" | grep -Eq "iperf3[[:space:]]+$2[[:space:]]"; do sleep 0.1; done; }
+waitncatudp() { pretty "${1//*-}" "wait for udp:1111 pid $2"; while [[ $(ss -N "$1" -ulpH 'sport = 1111') != *\"ncat\",pid=$2,fd=* ]]; do sleep 0.1; done; }
+waitiface() { pretty "${1//*-}" "wait for $2 to come up"; jexec "$1" bash -c "while ! ifconfig $2 | grep -qE 'flags.+UP'; do read -t .1 -N 0 || true; done;"; }
+
+cj() { pretty "" "Creating $1"; jail -c path=/ vnet=new name="$1" persist; }
+dj() { pretty "" "Deleting $1"; jail -r "$1" >/dev/null; }
+
+cleanup() {
+	set +e
+	exec 2>/dev/null
+	# printf "$orig_message_cost" > /proc/sys/net/core/message_cost
+	dj $jail0
+	dj $jail1
+	dj $jail2
+
+	for iface in wg1 wg2; do
+		pretty "" "Awaiting return of ${iface}"
+		# Give interfaces a second to return
+		while ! ifconfig ${iface} &> /dev/null; do
+			sleep 0.1
+		done
+		ifconfig ${iface} destroy
+	done
+	exit
+}
+
+trap cleanup EXIT
+
+dj $jail0 || true
+dj $jail1 || true
+dj $jail2 || true
+cj $jail0
+cj $jail1
+cj $jail2
+
+ifconfig wg1 create
+ifconfig wg1 vnet ${jail1}
+ifconfig wg2 create
+ifconfig wg2 vnet ${jail2}
+
+key1="$(pp wg genkey)"
+key2="$(pp wg genkey)"
+key3="$(pp wg genkey)"
+key4="$(pp wg genkey)"
+pub1="$(pp wg pubkey <<<"$key1")"
+pub2="$(pp wg pubkey <<<"$key2")"
+pub3="$(pp wg pubkey <<<"$key3")"
+pub4="$(pp wg pubkey <<<"$key4")"
+psk="$(pp wg genpsk)"
+[[ -n $key1 && -n $key2 && -n $psk ]]
+
+configure_peers() {
+	ifconfig1 wg1 inet 192.168.241.1/24
+	ifconfig1 wg1 inet6 fd00::1/112 up
+
+	ifconfig2 wg2 inet 192.168.241.2/24
+	ifconfig2 wg2 inet6 fd00::2/112 up
+
+	j1 wg set wg1 \
+		private-key <(echo "$key1") \
+		listen-port 1 \
+		peer "$pub2" \
+			preshared-key <(echo "$psk") \
+			allowed-ips 192.168.241.2/32,fd00::2/128
+	j2 wg set wg2 \
+		private-key <(echo "$key2") \
+		listen-port 2 \
+		peer "$pub1" \
+			preshared-key <(echo "$psk") \
+			allowed-ips 192.168.241.1/32,fd00::1/128
+}
+configure_peers
+
+tests() {
+	# Ping over IPv4
+	j2 ping -c 10 -f -W 1 192.168.241.1
+	j1 ping -c 10 -f -W 1 192.168.241.2
+
+	# Ping over IPv6
+	j2 ping6 -c 10 -f -W 1 fd00::1
+	j1 ping6 -c 10 -f -W 1 fd00::2
+
+	# TCP over IPv4
+	j2 iperf3 -s -1 -B 192.168.241.2 &
+	waitiperf $jail2 $!
+	j1 iperf3 -Z -t 3 -c 192.168.241.2
+
+	# TCP over IPv6
+	j1 iperf3 -s -1 -B fd00::1 &
+	waitiperf $jail1 $!
+	j2 iperf3 -Z -t 3 -c fd00::1
+
+	# UDP over IPv4
+	j1 iperf3 -s -1 -B 192.168.241.1 &
+	waitiperf $jail1 $!
+	j2 iperf3 -Z -t 3 -b 0 -u -c 192.168.241.1
+
+	# UDP over IPv6
+	j2 iperf3 -s -1 -B fd00::2 &
+	waitiperf $jail2 $!
+	j1 iperf3 -Z -t 3 -b 0 -u -c fd00::2
+
+	# TCP over IPv4, in parallel
+	for max in 4 5 50; do
+		local pids=( )
+		for ((i=0; i < max; ++i)) do
+			j2 iperf3 -p $(( 5200 + i )) -s -1 -B 192.168.241.2 &
+			pids+=( $! ); waitiperf $jail2 $! $(( 5200 + i ))
+		done
+		for ((i=0; i < max; ++i)) do
+			j1 iperf3 -Z -t 3 -p $(( 5200 + i )) -c 192.168.241.2 &
+		done
+		wait "${pids[@]}"
+	done
+}
+
+[[ $(ifconfig1 wg1) =~ mtu\ ([0-9]+) ]] && orig_mtu="${BASH_REMATCH[1]}"
+#big_mtu=$(( 34816 - 1500 + $orig_mtu ))
+# XXX
+big_mtu=16304
+
+# Test using IPv4 as outer transport
+j1 wg set wg1 peer "$pub2" endpoint 127.0.0.1:2
+j2 wg set wg2 peer "$pub1" endpoint 127.0.0.1:1
+
+# Before calling tests, we first make sure that the stats counters and timestamper are working
+#j2 ping -c 10 -f -W 1 192.168.241.1
+#{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip2 -stats link show dev wg0)
+#(( rx_bytes == 1372 && (tx_bytes == 1428 || tx_bytes == 1460) ))
+#{ read _; read _; read _; read rx_bytes _; read _; read tx_bytes _; } < <(ip1 -stats link show dev wg0)
+#(( tx_bytes == 1372 && (rx_bytes == 1428 || rx_bytes == 1460) ))
+#read _ rx_bytes tx_bytes < <(n2 wg show wg0 transfer)
+#(( rx_bytes == 1372 && (tx_bytes == 1428 || tx_bytes == 1460) ))
+#read _ rx_bytes tx_bytes < <(n1 wg show wg0 transfer)
+#(( tx_bytes == 1372 && (rx_bytes == 1428 || rx_bytes == 1460) ))
+#read _ timestamp < <(n1 wg show wg0 latest-handshakes)
+#(( timestamp != 0 ))
+
+tests
+ifconfig1 wg1 mtu $big_mtu
+ifconfig2 wg2 mtu $big_mtu
+tests
+
+exit 1
+
+# XXX Everything below still uses the Linux iproute2/netns tooling (the
+# ip0/ip1/ip2 and n0/n1/n2 helpers are not defined in this port) and has not
+# been converted to FreeBSD yet, hence the early exit above.
+
+ip1 link set wg0 mtu $orig_mtu
+ip2 link set wg0 mtu $orig_mtu
+
+# Test using IPv6 as outer transport
+n1 wg set wg0 peer "$pub2" endpoint [::1]:2
+n2 wg set wg0 peer "$pub1" endpoint [::1]:1
+tests
+ip1 link set wg0 mtu $big_mtu
+ip2 link set wg0 mtu $big_mtu
+tests
+
+# Test that route MTUs work with the padding
+ip1 link set wg0 mtu 1300
+ip2 link set wg0 mtu 1300
+n1 wg set wg0 peer "$pub2" endpoint 127.0.0.1:2
+n2 wg set wg0 peer "$pub1" endpoint 127.0.0.1:1
+n0 iptables -A INPUT -m length --length 1360 -j DROP
+n1 ip route add 192.168.241.2/32 dev wg0 mtu 1299
+n2 ip route add 192.168.241.1/32 dev wg0 mtu 1299
+n2 ping -c 1 -W 1 -s 1269 192.168.241.1
+n2 ip route delete 192.168.241.1/32 dev wg0 mtu 1299
+n1 ip route delete 192.168.241.2/32 dev wg0 mtu 1299
+n0 iptables -F INPUT
+
+ip1 link set wg0 mtu $orig_mtu
+ip2 link set wg0 mtu $orig_mtu
+
+# Test using IPv4 that roaming works
+ip0 -4 addr del 127.0.0.1/8 dev lo
+ip0 -4 addr add 127.212.121.99/8 dev lo
+n1 wg set wg0 listen-port 9999
+n1 wg set wg0 peer "$pub2" endpoint 127.0.0.1:2
+n1 ping6 -W 1 -c 1 fd00::2
+[[ $(n2 wg show wg0 endpoints) == "$pub1 127.212.121.99:9999" ]]
+
+# Test using IPv6 that roaming works
+n1 wg set wg0 listen-port 9998
+n1 wg set wg0 peer "$pub2" endpoint [::1]:2
+n1 ping -W 1 -c 1 192.168.241.2
+[[ $(n2 wg show wg0 endpoints) == "$pub1 [::1]:9998" ]]
+
+# Test that crypto-RP filter works
+n1 wg set wg0 peer "$pub2" allowed-ips 192.168.241.0/24
+exec 4< <(n1 ncat -l -u -p 1111)
+ncat_pid=$!
+waitncatudp $jail1 $ncat_pid
+n2 ncat -u 192.168.241.1 1111 <<<"X"
+read -r -N 1 -t 1 out <&4 && [[ $out == "X" ]]
+kill $ncat_pid
+more_specific_key="$(pp wg genkey | pp wg pubkey)"
+n1 wg set wg0 peer "$more_specific_key" allowed-ips 192.168.241.2/32
+n2 wg set wg0 listen-port 9997
+exec 4< <(n1 ncat -l -u -p 1111)
+ncat_pid=$!
+waitncatudp $jail1 $ncat_pid
+n2 ncat -u 192.168.241.1 1111 <<<"X"
+! read -r -N 1 -t 1 out <&4 || false
+kill $ncat_pid
+n1 wg set wg0 peer "$more_specific_key" remove
+[[ $(n1 wg show wg0 endpoints) == "$pub2 [::1]:9997" ]]
+
+# Test that we can change private keys and immediately handshake
+n1 wg set wg0 private-key <(echo "$key1") peer "$pub2" preshared-key <(echo "$psk") allowed-ips 192.168.241.2/32 endpoint 127.0.0.1:2
+n2 wg set wg0 private-key <(echo "$key2") listen-port 2 peer "$pub1" preshared-key <(echo "$psk") allowed-ips 192.168.241.1/32
+n1 ping -W 1 -c 1 192.168.241.2
+n1 wg set wg0 private-key <(echo "$key3")
+n2 wg set wg0 peer "$pub3" preshared-key <(echo "$psk") allowed-ips 192.168.241.1/32 peer "$pub1" remove
+n1 ping -W 1 -c 1 192.168.241.2
+n2 wg set wg0 peer "$pub3" remove
+
+# Test that we can route wg through wg
+ip1 addr flush dev wg0
+ip2 addr flush dev wg0
+ip1 addr add fd00::5:1/112 dev wg0
+ip2 addr add fd00::5:2/112 dev wg0
+n1 wg set wg0 private-key <(echo "$key1") peer "$pub2" preshared-key <(echo "$psk") allowed-ips fd00::5:2/128 endpoint 127.0.0.1:2
+n2 wg set wg0 private-key <(echo "$key2") listen-port 2 peer "$pub1" preshared-key <(echo "$psk") allowed-ips fd00::5:1/128 endpoint 127.212.121.99:9998
+ip1 link add wg1 type wireguard
+ip2 link add wg1 type wireguard
+ip1 addr add 192.168.241.1/24 dev wg1
+ip1 addr add fd00::1/112 dev wg1
+ip2 addr add 192.168.241.2/24 dev wg1
+ip2 addr add fd00::2/112 dev wg1
+ip1 link set mtu 1340 up dev wg1
+ip2 link set mtu 1340 up dev wg1
+n1 wg set wg1 listen-port 5 private-key <(echo "$key3") peer "$pub4" allowed-ips 192.168.241.2/32,fd00::2/128 endpoint [fd00::5:2]:5
+n2 wg set wg1 listen-port 5 private-key <(echo "$key4") peer "$pub3" allowed-ips 192.168.241.1/32,fd00::1/128 endpoint [fd00::5:1]:5
+tests
+# Try to set up a routing loop between the two namespaces
+ip1 link set netns $jail0 dev wg1
+ip0 addr add 192.168.241.1/24 dev wg1
+ip0 link set up dev wg1
+n0 ping -W 1 -c 1 192.168.241.2
+n1 wg set wg0 peer "$pub2" endpoint 192.168.241.2:7
+ip2 link del wg0
+ip2 link del wg1
+! n0 ping -W 1 -c 10 -f 192.168.241.2 || false # Should not crash kernel
+
+ip0 link del wg1
+ip1 link del wg0
+
+# Test using NAT. 
We now change the topology to this: +# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────────────┐ ┌────────────────────────────────────────┐ +# │ $ns1 namespace │ │ $ns0 namespace │ │ $ns2 namespace │ +# │ │ │ │ │ │ +# │ ┌─────┐ ┌─────┐ │ │ ┌──────┐ ┌──────┐ │ │ ┌─────┐ ┌─────┐ │ +# │ │ wg0 │─────────────│vethc│───────────┼────┼────│vethrc│ │vethrs│──────────────┼─────┼──│veths│────────────│ wg0 │ │ +# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├──────┴─────────┐ ├──────┴────────────┐ │ │ ├─────┴──────────┐ ├─────┴──────────┐ │ +# │ │192.168.241.1/24│ │192.168.1.100/24││ │ │192.168.1.1/24 │ │10.0.0.1/24 │ │ │ │10.0.0.100/24 │ │192.168.241.2/24│ │ +# │ │fd00::1/24 │ │ ││ │ │ │ │SNAT:192.168.1.0/24│ │ │ │ │ │fd00::2/24 │ │ +# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └───────────────────┘ │ │ └────────────────┘ └────────────────┘ │ +# └────────────────────────────────────────┘ └────────────────────────────────────────────────┘ └────────────────────────────────────────┘ + +ip1 link add dev wg0 type wireguard +ip2 link add dev wg0 type wireguard +configure_peers + +ip0 link add vethrc type veth peer name vethc +ip0 link add vethrs type veth peer name veths +ip0 link set vethc netns $jail1 +ip0 link set veths netns $jail2 +ip0 link set vethrc up +ip0 link set vethrs up +ip0 addr add 192.168.1.1/24 dev vethrc +ip0 addr add 10.0.0.1/24 dev vethrs +ip1 addr add 192.168.1.100/24 dev vethc +ip1 link set vethc up +ip1 route add default via 192.168.1.1 +ip2 addr add 10.0.0.100/24 dev veths +ip2 link set veths up +waitiface $jail0 vethrc +waitiface $jail0 vethrs +waitiface $jail1 vethc +waitiface $jail2 veths + +n0 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' +n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout' +n0 bash -c 'printf 2 > /proc/sys/net/netfilter/nf_conntrack_udp_timeout_stream' +n0 iptables -t nat -A POSTROUTING -s 192.168.1.0/24 -d 10.0.0.0/24 -j SNAT --to 10.0.0.1 + +n1 wg set wg0 peer "$pub2" endpoint 10.0.0.100:2 persistent-keepalive 1 +n1 ping -W 1 -c 1 192.168.241.2 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.0.0.1:1" ]] +# Demonstrate n2 can still send packets to n1, since persistent-keepalive will prevent connection tracking entry from expiring (to see entries: `n0 conntrack -L`). +pp sleep 3 +n2 ping -W 1 -c 1 192.168.241.1 +n1 wg set wg0 peer "$pub2" persistent-keepalive 0 + +# Test that sk_bound_dev_if works +n1 ping -I wg0 -c 1 -W 1 192.168.241.2 +# What about when the mark changes and the packet must be rerouted? +n1 iptables -t mangle -I OUTPUT -j MARK --set-xmark 1 +n1 ping -c 1 -W 1 192.168.241.2 # First the boring case +n1 ping -I wg0 -c 1 -W 1 192.168.241.2 # Then the sk_bound_dev_if case +n1 iptables -t mangle -D OUTPUT -j MARK --set-xmark 1 + +# Test that onion routing works, even when it loops +n1 wg set wg0 peer "$pub3" allowed-ips 192.168.242.2/32 endpoint 192.168.241.2:5 +ip1 addr add 192.168.242.1/24 dev wg0 +ip2 link add wg1 type wireguard +ip2 addr add 192.168.242.2/24 dev wg1 +n2 wg set wg1 private-key <(echo "$key3") listen-port 5 peer "$pub1" allowed-ips 192.168.242.1/32 +ip2 link set wg1 up +n1 ping -W 1 -c 1 192.168.242.2 +ip2 link del wg1 +n1 wg set wg0 peer "$pub3" endpoint 192.168.242.2:5 +! 
n1 ping -W 1 -c 1 192.168.242.2 || false # Should not crash kernel +n1 wg set wg0 peer "$pub3" remove +ip1 addr del 192.168.242.1/24 dev wg0 + +# Do a wg-quick(8)-style policy routing for the default route, making sure vethc has a v6 address to tease out bugs. +ip1 -6 addr add fc00::9/96 dev vethc +ip1 -6 route add default via fc00::1 +ip2 -4 addr add 192.168.99.7/32 dev wg0 +ip2 -6 addr add abab::1111/128 dev wg0 +n1 wg set wg0 fwmark 51820 peer "$pub2" allowed-ips 192.168.99.7,abab::1111 +ip1 -6 route add default dev wg0 table 51820 +ip1 -6 rule add not fwmark 51820 table 51820 +ip1 -6 rule add table main suppress_prefixlength 0 +ip1 -4 route add default dev wg0 table 51820 +ip1 -4 rule add not fwmark 51820 table 51820 +ip1 -4 rule add table main suppress_prefixlength 0 +# Flood the pings instead of sending just one, to trigger routing table reference counting bugs. +n1 ping -W 1 -c 100 -f 192.168.99.7 +n1 ping -W 1 -c 100 -f abab::1111 + +# Have ns2 NAT into wg0 packets from ns0, but return an icmp error along the right route. +n2 iptables -t nat -A POSTROUTING -s 10.0.0.0/24 -d 192.168.241.0/24 -j SNAT --to 192.168.241.2 +n0 iptables -t filter -A INPUT \! -s 10.0.0.0/24 -i vethrs -j DROP # Manual rpfilter just to be explicit. +n2 bash -c 'printf 1 > /proc/sys/net/ipv4/ip_forward' +ip0 -4 route add 192.168.241.1 via 10.0.0.100 +n2 wg set wg0 peer "$pub1" remove +[[ $(! n0 ping -W 1 -c 1 192.168.241.1 || false) == *"From 10.0.0.100 icmp_seq=1 Destination Host Unreachable"* ]] + +n0 iptables -t nat -F +n0 iptables -t filter -F +n2 iptables -t nat -F +ip0 link del vethrc +ip0 link del vethrs +ip1 link del wg0 +ip2 link del wg0 + +# Test that saddr routing is sticky but not too sticky, changing to this topology: +# ┌────────────────────────────────────────┐ ┌────────────────────────────────────────┐ +# │ $ns1 namespace │ │ $ns2 namespace │ +# │ │ │ │ +# │ ┌─────┐ ┌─────┐ │ │ ┌─────┐ ┌─────┐ │ +# │ │ wg0 │─────────────│veth1│───────────┼────┼──│veth2│────────────│ wg0 │ │ +# │ ├─────┴──────────┐ ├─────┴──────────┐│ │ ├─────┴──────────┐ ├─────┴──────────┐ │ +# │ │192.168.241.1/24│ │10.0.0.1/24 ││ │ │10.0.0.2/24 │ │192.168.241.2/24│ │ +# │ │fd00::1/24 │ │fd00:aa::1/96 ││ │ │fd00:aa::2/96 │ │fd00::2/24 │ │ +# │ └────────────────┘ └────────────────┘│ │ └────────────────┘ └────────────────┘ │ +# └────────────────────────────────────────┘ └────────────────────────────────────────┘ + +ip1 link add dev wg0 type wireguard +ip2 link add dev wg0 type wireguard +configure_peers +ip1 link add veth1 type veth peer name veth2 +ip1 link set veth2 netns $jail2 +n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/all/accept_dad' +n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/all/accept_dad' +n1 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth1/accept_dad' +n2 bash -c 'printf 0 > /proc/sys/net/ipv6/conf/veth2/accept_dad' +n1 bash -c 'printf 1 > /proc/sys/net/ipv4/conf/veth1/promote_secondaries' + +# First we check that we aren't overly sticky and can fall over to new IPs when old ones are removed +ip1 addr add 10.0.0.1/24 dev veth1 +ip1 addr add fd00:aa::1/96 dev veth1 +ip2 addr add 10.0.0.2/24 dev veth2 +ip2 addr add fd00:aa::2/96 dev veth2 +ip1 link set veth1 up +ip2 link set veth2 up +waitiface $jail1 veth1 +waitiface $jail2 veth2 +n1 wg set wg0 peer "$pub2" endpoint 10.0.0.2:2 +n1 ping -W 1 -c 1 192.168.241.2 +ip1 addr add 10.0.0.10/24 dev veth1 +ip1 addr del 10.0.0.1/24 dev veth1 +n1 ping -W 1 -c 1 192.168.241.2 +n1 wg set wg0 peer "$pub2" endpoint [fd00:aa::2]:2 +n1 ping -W 1 -c 1 192.168.241.2 +ip1 addr add 
fd00:aa::10/96 dev veth1 +ip1 addr del fd00:aa::1/96 dev veth1 +n1 ping -W 1 -c 1 192.168.241.2 + +# Now we show that we can successfully do reply to sender routing +ip1 link set veth1 down +ip2 link set veth2 down +ip1 addr flush dev veth1 +ip2 addr flush dev veth2 +ip1 addr add 10.0.0.1/24 dev veth1 +ip1 addr add 10.0.0.2/24 dev veth1 +ip1 addr add fd00:aa::1/96 dev veth1 +ip1 addr add fd00:aa::2/96 dev veth1 +ip2 addr add 10.0.0.3/24 dev veth2 +ip2 addr add fd00:aa::3/96 dev veth2 +ip1 link set veth1 up +ip2 link set veth2 up +waitiface $jail1 veth1 +waitiface $jail2 veth2 +n2 wg set wg0 peer "$pub1" endpoint 10.0.0.1:1 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.0.0.1:1" ]] +n2 wg set wg0 peer "$pub1" endpoint [fd00:aa::1]:1 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 [fd00:aa::1]:1" ]] +n2 wg set wg0 peer "$pub1" endpoint 10.0.0.2:1 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.0.0.2:1" ]] +n2 wg set wg0 peer "$pub1" endpoint [fd00:aa::2]:1 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 [fd00:aa::2]:1" ]] + +# What happens if the inbound destination address belongs to a different interface as the default route? +ip1 link add dummy0 type dummy +ip1 addr add 10.50.0.1/24 dev dummy0 +ip1 link set dummy0 up +ip2 route add 10.50.0.0/24 dev veth2 +n2 wg set wg0 peer "$pub1" endpoint 10.50.0.1:1 +n2 ping -W 1 -c 1 192.168.241.1 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.50.0.1:1" ]] + +ip1 link del dummy0 +ip1 addr flush dev veth1 +ip2 addr flush dev veth2 +ip1 route flush dev veth1 +ip2 route flush dev veth2 + +# Now we see what happens if another interface route takes precedence over an ongoing one +ip1 link add veth3 type veth peer name veth4 +ip1 link set veth4 netns $jail2 +ip1 addr add 10.0.0.1/24 dev veth1 +ip2 addr add 10.0.0.2/24 dev veth2 +ip1 addr add 10.0.0.3/24 dev veth3 +ip1 link set veth1 up +ip2 link set veth2 up +ip1 link set veth3 up +ip2 link set veth4 up +waitiface $jail1 veth1 +waitiface $jail2 veth2 +waitiface $jail1 veth3 +waitiface $jail2 veth4 +ip1 route flush dev veth1 +ip1 route flush dev veth3 +ip1 route add 10.0.0.0/24 dev veth1 src 10.0.0.1 metric 2 +n1 wg set wg0 peer "$pub2" endpoint 10.0.0.2:2 +n1 ping -W 1 -c 1 192.168.241.2 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.0.0.1:1" ]] +ip1 route add 10.0.0.0/24 dev veth3 src 10.0.0.3 metric 1 +n1 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/veth1/rp_filter' +n2 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/veth4/rp_filter' +n1 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/all/rp_filter' +n2 bash -c 'printf 0 > /proc/sys/net/ipv4/conf/all/rp_filter' +n1 ping -W 1 -c 1 192.168.241.2 +[[ $(n2 wg show wg0 endpoints) == "$pub1 10.0.0.3:1" ]] + +ip1 link del veth1 +ip1 link del veth3 +ip1 link del wg0 +ip2 link del wg0 + +# We test that Netlink/IPC is working properly by doing things that usually cause split responses +ip0 link add dev wg0 type wireguard +config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) +for a in {1..255}; do + for b in {0..255}; do + config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +for ip in $(n0 wg show wg0 allowed-ips); do + ((++i)) +done +((i == 255*256*2+1)) +ip0 link del wg0 +ip0 link add dev wg0 type wireguard +config=( "[Interface]" "PrivateKey=$(wg genkey)" ) +for a in {1..40}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) + for b in {1..52}; do + config+=( 
"AllowedIPs=$a.$b.0.0/16" ) + done +done +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +i=0 +while read -r line; do + j=0 + for ip in $line; do + ((++j)) + done + ((j == 53)) + ((++i)) +done < <(n0 wg show wg0 allowed-ips) +((i == 40)) +ip0 link del wg0 +ip0 link add wg0 type wireguard +config=( ) +for i in {1..29}; do + config+=( "[Peer]" "PublicKey=$(wg genkey)" ) +done +config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" ) +n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") +n0 wg showconf wg0 > /dev/null +ip0 link del wg0 + +allowedips=( ) +for i in {1..197}; do + allowedips+=( abcd::$i ) +done +saved_ifs="$IFS" +IFS=, +allowedips="${allowedips[*]}" +IFS="$saved_ifs" +ip0 link add wg0 type wireguard +n0 wg set wg0 peer "$pub1" +n0 wg set wg0 peer "$pub2" allowed-ips "$allowedips" +{ + read -r pub allowedips + [[ $pub == "$pub1" && $allowedips == "(none)" ]] + read -r pub allowedips + [[ $pub == "$pub2" ]] + i=0 + for _ in $allowedips; do + ((++i)) + done + ((i == 197)) +} < <(n0 wg show wg0 allowed-ips) +ip0 link del wg0 + +! n0 wg show doesnotexist || false + +ip0 link add wg0 type wireguard +n0 wg set wg0 private-key <(echo "$key1") peer "$pub2" preshared-key <(echo "$psk") +[[ $(n0 wg show wg0 private-key) == "$key1" ]] +[[ $(n0 wg show wg0 preshared-keys) == "$pub2 $psk" ]] +n0 wg set wg0 private-key /dev/null peer "$pub2" preshared-key /dev/null +[[ $(n0 wg show wg0 private-key) == "(none)" ]] +[[ $(n0 wg show wg0 preshared-keys) == "$pub2 (none)" ]] +n0 wg set wg0 peer "$pub2" +n0 wg set wg0 private-key <(echo "$key2") +[[ $(n0 wg show wg0 public-key) == "$pub2" ]] +[[ -z $(n0 wg show wg0 peers) ]] +n0 wg set wg0 peer "$pub2" +[[ -z $(n0 wg show wg0 peers) ]] +n0 wg set wg0 private-key <(echo "$key1") +n0 wg set wg0 peer "$pub2" +[[ $(n0 wg show wg0 peers) == "$pub2" ]] +n0 wg set wg0 private-key <(echo "/${key1:1}") +[[ $(n0 wg show wg0 private-key) == "+${key1:1}" ]] +n0 wg set wg0 peer "$pub2" allowed-ips 0.0.0.0/0,10.0.0.0/8,100.0.0.0/10,172.16.0.0/12,192.168.0.0/16 +n0 wg set wg0 peer "$pub2" allowed-ips 0.0.0.0/0 +n0 wg set wg0 peer "$pub2" allowed-ips ::/0,1700::/111,5000::/4,e000::/37,9000::/75 +n0 wg set wg0 peer "$pub2" allowed-ips ::/0 +n0 wg set wg0 peer "$pub2" remove +for low_order_point in AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= AQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= 4Ot6fDtBuK4WVuP68Z/EatoJjeucMrH9hmIFFl9JuAA= X5yVvKNQjCSx0LFVnIPvWwREXMRYHI6G2CJO3dCfEVc= 7P///////////////////////////////////////38= 7f///////////////////////////////////////38= 7v///////////////////////////////////////38=; do + n0 wg set wg0 peer "$low_order_point" persistent-keepalive 1 endpoint 127.0.0.1:1111 +done +[[ -n $(n0 wg show wg0 peers) ]] +exec 4< <(n0 ncat -l -u -p 1111) +ncat_pid=$! +waitncatudp $jail0 $ncat_pid +ip0 link set wg0 up +! read -r -n 1 -t 2 <&4 || false +kill $ncat_pid +ip0 link del wg0 + +# Ensure there aren't circular reference loops +ip1 link add wg1 type wireguard +ip2 link add wg2 type wireguard +ip1 link set wg1 netns $jail2 +ip2 link set wg2 netns $jail1 +pp ip netns delete $jail1 +pp ip netns delete $jail2 +pp ip netns add $jail1 +pp ip netns add $jail2 + +sleep 2 # Wait for cleanup and grace periods +declare -A objects +while read -t 0.1 -r line 2>/dev/null || [[ $? 
-ne 142 ]]; do
+	[[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ ?[0-9]*)\ .*(created|destroyed).* ]] || continue
+	objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
+done < /dev/kmsg
+alldeleted=1
+for object in "${!objects[@]}"; do
+	if [[ ${objects["$object"]} != *createddestroyed ]]; then
+		echo "Error: $object: merely ${objects["$object"]}" >&3
+		alldeleted=0
+	fi
+done
+[[ $alldeleted -eq 1 ]]
+pretty "" "Objects that were created were also destroyed."