aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock139
-rw-r--r--Cargo.toml16
-rw-r--r--src/configuration/config.rs19
-rw-r--r--src/configuration/uapi/get.rs3
-rw-r--r--src/configuration/uapi/mod.rs1
-rw-r--r--src/main.rs2
-rw-r--r--src/platform/dummy/tun.rs3
-rw-r--r--src/platform/dummy/udp.rs13
-rw-r--r--src/platform/linux/tun.rs2
-rw-r--r--src/platform/linux/udp.rs667
-rw-r--r--src/platform/udp.rs4
-rw-r--r--src/wireguard/handshake/device.rs198
-rw-r--r--src/wireguard/handshake/macs.rs6
-rw-r--r--src/wireguard/handshake/noise.rs46
-rw-r--r--src/wireguard/handshake/peer.rs26
-rw-r--r--src/wireguard/handshake/tests.rs62
-rw-r--r--src/wireguard/handshake/types.rs14
-rw-r--r--src/wireguard/peer.rs2
-rw-r--r--src/wireguard/router/device.rs2
-rw-r--r--src/wireguard/router/peer.rs2
-rw-r--r--src/wireguard/wireguard.rs85
-rw-r--r--src/wireguard/workers.rs80
22 files changed, 1019 insertions, 373 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 1b1feb6..10bfca5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2,7 +2,7 @@
# It is not intended for manual editing.
[[package]]
name = "aead"
-version = "0.1.1"
+version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -130,23 +130,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "chacha20"
-version = "0.2.1"
+version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "salsa20-core 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"stream-cipher 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "chacha20poly1305"
-version = "0.1.0"
+version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "aead 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "chacha20 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "poly1305 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "zeroize 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "aead 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "chacha20 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "poly1305 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -204,14 +203,14 @@ dependencies = [
[[package]]
name = "curve25519-dalek"
-version = "1.2.3"
+version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "subtle 2.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -339,7 +338,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"mio 0.6.19 (registry+https://github.com/rust-lang/crates.io-index)",
"mio-extras 2.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
- "spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -596,11 +595,10 @@ dependencies = [
[[package]]
name = "poly1305"
-version = "0.2.0"
+version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "crypto-mac 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "universal-hash 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -690,7 +688,7 @@ dependencies = [
"getrandom 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -709,7 +707,7 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -727,7 +725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "rand_core"
-version = "0.5.0"
+version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"getrandom 0.1.11 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -746,7 +744,7 @@ name = "rand_hc"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -843,7 +841,7 @@ dependencies = [
"cc 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
- "spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
"untrusted 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"web-sys 0.3.27 (registry+https://github.com/rust-lang/crates.io-index)",
"winapi 0.3.7 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -871,15 +869,6 @@ dependencies = [
]
[[package]]
-name = "salsa20-core"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-dependencies = [
- "stream-cipher 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "zeroize 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)",
-]
-
-[[package]]
name = "serde"
version = "1.0.99"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -896,7 +885,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "spin"
-version = "0.5.1"
+version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
@@ -914,7 +903,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "subtle"
-version = "2.1.1"
+version = "2.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
@@ -949,6 +938,17 @@ dependencies = [
]
[[package]]
+name = "synstructure"
+version = "0.12.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "proc-macro2 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
+ "unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "syntex"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1062,6 +1062,15 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
+name = "universal-hash"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
+ "subtle 2.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "untrusted"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1221,13 +1230,13 @@ dependencies = [
[[package]]
name = "wireguard-rs"
-version = "0.1.0"
+version = "0.1.1"
dependencies = [
- "aead 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "aead 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"arraydeque 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)",
"blake2 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
- "chacha20poly1305 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "chacha20poly1305 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"cpuprofiler 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)",
"crossbeam-channel 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1243,14 +1252,14 @@ dependencies = [
"num_cpus 1.10.1 (registry+https://github.com/rust-lang/crates.io-index)",
"pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)",
"proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand 0.6.5 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)",
"rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
"ring 0.16.7 (registry+https://github.com/rust-lang/crates.io-index)",
- "spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
- "subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "subtle 2.2.2 (registry+https://github.com/rust-lang/crates.io-index)",
"treebitmap 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
- "x25519-dalek 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "x25519-dalek 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)",
"zerocopy 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -1265,12 +1274,12 @@ dependencies = [
[[package]]
name = "x25519-dalek"
-version = "0.5.2"
+version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
- "clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
- "curve25519-dalek 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
- "rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "curve25519-dalek 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
@@ -1294,11 +1303,25 @@ dependencies = [
[[package]]
name = "zeroize"
-version = "0.9.3"
+version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "zeroize_derive 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "zeroize_derive"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "proc-macro2 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
+ "synstructure 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
+]
[metadata]
-"checksum aead 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "529ae27769da55d955d190396e67896f49b440aff94a5b2f50900e091d168b77"
+"checksum aead 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4cf01b9b56e767bb57b94ebf91a58b338002963785cdd7013e21c0d4679471e4"
"checksum aho-corasick 0.6.10 (registry+https://github.com/rust-lang/crates.io-index)" = "81ce3d38065e618af2d7b77e10c5ad9a069859b4be3c2250f674af3840d9c8a5"
"checksum arraydeque 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "f0ffd3d69bd89910509a5d31d1f1353f38ccffdd116dd0099bbd6627f7bd8ad8"
"checksum atty 0.2.13 (registry+https://github.com/rust-lang/crates.io-index)" = "1803c647a3ec87095e7ae7acfca019e98de5ec9a7d01343f611cf3152ed71a90"
@@ -1317,15 +1340,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum c2-chacha 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7d64d04786e0f528460fc884753cf8dddcc466be308f6026f8e355c41a0e4101"
"checksum cc 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "b548a4ee81fccb95919d4e22cfea83c7693ebfd78f0495493178db20b3139da7"
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
-"checksum chacha20 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "9ce602601e1450409cfe3a6dea32a5de678e08c43368e860c2afa2eec58ce3dc"
-"checksum chacha20poly1305 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "40cd3ddeae0b0ea7fe848a06e4fbf3f02463648b9395bd1139368ce42b44543e"
+"checksum chacha20 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "166651768ffa1f1fa7024d0164fea4e71d84ea5df4ee94796cadb83878faba84"
+"checksum chacha20poly1305 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "4d7c7def9c4e6f11a5b7525585853135689865907ca3c4c34e0a4b252fd50dd0"
"checksum clear_on_drop 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "97276801e127ffb46b66ce23f35cc96bd454fa311294bced4bbace7baa8b1d17"
"checksum cloudabi 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
"checksum cpuprofiler 0.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "43f8479dbcfd2bbaa0c0c26779b913052b375981cdf533091f2127ea3d42e52b"
"checksum crossbeam-channel 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "acec9a3b0b3559f15aee4f90746c4e5e293b701c0f7d3925d24e01645267b68c"
"checksum crossbeam-utils 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ce446db02cdc3165b94ae73111e570793400d0794e46125cc4056c81cbb039f4"
"checksum crypto-mac 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4434400df11d95d556bac068ddfedd482915eb18fe8bea89bc80b6e4b1c179e5"
-"checksum curve25519-dalek 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8b7dcd30ba50cdf88b55b033456138b7c0ac4afdc436d82e1b79f370f24cc66d"
+"checksum curve25519-dalek 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "26778518a7f6cffa1d25a44b602b62b979bd88adb9e99ffec546998cf3404839"
"checksum daemonize 0.4.1 (registry+https://github.com/rust-lang/crates.io-index)" = "70c24513e34f53b640819f0ac9f705b673fcf4006d7aab8778bee72ebfc89815"
"checksum digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f3d0c8c8752312f9713efd397ff63acb9f85585afbf179282e720e7704954dd5"
"checksum env_logger 0.6.2 (registry+https://github.com/rust-lang/crates.io-index)" = "aafcde04e90a5226a6443b7aabdb016ba2f8307c847d524724bd9b346dd1a2d3"
@@ -1371,7 +1394,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum pnet_packet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)" = "08a6cdcdaddc5174f18286298842a4e31cd3cc018933d42af51434b1fa07dcbe"
"checksum pnet_sys 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)" = "682b2eca84cc440bce8336813f78eb6d3cb0fed89fe0e87ae22acfca8363f176"
"checksum pnet_transport 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5faa55dcf725487a699adcff88dfea8f17ea34fa2640528866d9acbb4e3a104f"
-"checksum poly1305 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fb59dfc6d8dd49677e39bf8fdf4c62235a8d84dbe2ef2913e139d3f62bb65f70"
+"checksum poly1305 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "b5829f50f48e9ddb79f3f7c3097029d0caee30f8286accb241416df603b080b8"
"checksum ppv-lite86 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)" = "e3cbf9f658cdb5000fcf6f362b8ea2ba154b9f146a61c7a20d647034c6b6561b"
"checksum proc-macro2 0.4.30 (registry+https://github.com/rust-lang/crates.io-index)" = "cf3d2011ab5c909338f7887f4fc896d35932e29146c12c8d01da6b22a80ba759"
"checksum proc-macro2 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "4c5c2380ae88876faae57698be9e9775e3544decad214599c3a6266cca6ac802"
@@ -1385,7 +1408,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rand_chacha 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "03a2a90da8c7523f554344f921aa97283eadf6ac484a6d2a7d0212fa7f8d6853"
"checksum rand_core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
"checksum rand_core 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
-"checksum rand_core 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "615e683324e75af5d43d8f7a39ffe3ee4a9dc42c5c701167a71dc59c3a493aca"
+"checksum rand_core 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
"checksum rand_hc 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7b40677c7be09ae76218dc623efbf7b18e34bced3f38883af07bb75630a21bc4"
"checksum rand_hc 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
"checksum rand_isaac 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "ded997c9d5f13925be2a6fd7e66bf1872597f759fd9dd93513dd7e92e5a5ee08"
@@ -1402,17 +1425,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum rustc-demangle 0.1.16 (registry+https://github.com/rust-lang/crates.io-index)" = "4c691c0e608126e00913e33f0ccf3727d5fc84573623b8d65b2df340b5201783"
"checksum rustc-serialize 0.3.24 (registry+https://github.com/rust-lang/crates.io-index)" = "dcf128d1287d2ea9d80910b5f1120d0b8eede3fbf1abe91c40d39ea7d51e6fda"
"checksum rusty-fork 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "3dd93264e10c577503e926bd1430193eeb5d21b059148910082245309b424fae"
-"checksum salsa20-core 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "c7fd325cb25b420aab2c035b5b76966d9f91b88fb54084ce6c0cd072a1ae5cda"
"checksum serde 1.0.99 (registry+https://github.com/rust-lang/crates.io-index)" = "fec2851eb56d010dc9a21b89ca53ee75e6528bab60c11e89d38390904982da9f"
"checksum slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8"
"checksum sourcefile 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "4bf77cb82ba8453b42b6ae1d692e4cdc92f9a47beaf89a847c8be83f4e328ad3"
-"checksum spin 0.5.1 (registry+https://github.com/rust-lang/crates.io-index)" = "cbdb51a221842709c2dd65b62ad4b78289fc3e706a02c17a26104528b6aa7837"
+"checksum spin 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
"checksum stream-cipher 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "8131256a5896cabcf5eb04f4d6dacbe1aefda854b0d9896e09cb58829ec5638c"
"checksum subtle 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "2d67a5a62ba6e01cb2192ff309324cb4875d0c451d55fe2319433abe7a05a8ee"
-"checksum subtle 2.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "01f40907d9ffc762709e4ff3eb4a6f6b41b650375a3f09ac92b641942b7fb082"
+"checksum subtle 2.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7c65d530b10ccaeac294f349038a597e435b18fb456aadd0840a623f83b9e941"
"checksum syn 0.15.44 (registry+https://github.com/rust-lang/crates.io-index)" = "9ca4b3b69a77cbe1ffc9e198781b7acb0c7365a883670e8f1c1bc66fba79a5c5"
"checksum syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "66850e97125af79138385e9b88339cbcd037e3f28ceab8c5ad98e64f0f1f80bf"
"checksum synstructure 0.10.2 (registry+https://github.com/rust-lang/crates.io-index)" = "02353edf96d6e4dc81aea2d8490a7e9db177bf8acb0e951c24940bf866cb313f"
+"checksum synstructure 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)" = "67656ea1dc1b41b1451851562ea232ec2e5a80242139f7e679ceccfb5d61f545"
"checksum syntex 0.42.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0a30b08a6b383a22e5f6edc127d169670d48f905bb00ca79a00ea3e442ebe317"
"checksum syntex_errors 0.42.0 (registry+https://github.com/rust-lang/crates.io-index)" = "04c48f32867b6114449155b2a82114b86d4b09e1bddb21c47ff104ab9172b646"
"checksum syntex_pos 0.42.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3fd49988e52451813c61fecbe9abb5cfd4e1b7bb6cdbb980a6fbcbab859171a6"
@@ -1427,6 +1450,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum unicode-xid 0.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "36dff09cafb4ec7c8cf0023eb0b686cb6ce65499116a12201c9e11840ca01beb"
"checksum unicode-xid 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "fc72304796d0818e357ead4e000d19c9c174ab23dc11093ac919054d20a6a7fc"
"checksum unicode-xid 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "826e7639553986605ec5979c7dd957c7895e93eabed50ab2ffa7f6128a75097c"
+"checksum universal-hash 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "df0c900f2f9b4116803415878ff48b63da9edb268668e08cf9292d7503114a01"
"checksum untrusted 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)" = "60369ef7a31de49bcb3f6ca728d4ba7300d9a1658f94c727d4cab8c8d9f4aece"
"checksum utf8-ranges 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b4ae116fef2b7fea257ed6440d3cfcff7f190865f170cdad00bb6465bf18ecba"
"checksum version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "914b1a6776c4c929a602fafd8bc742e06365d4bcbe48c30f9cca5824f70dc9dd"
@@ -1448,7 +1472,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum winapi-x86_64-pc-windows-gnu 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
"checksum wincolor 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "96f5016b18804d24db43cebf3c77269e7569b8954a8464501c216cc5e070eaa9"
"checksum ws2_32-sys 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "d59cefebd0c892fa2dd6de581e937301d8552cb44489cdff035c6187cb63fa5e"
-"checksum x25519-dalek 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7ee1585dc1484373cbc1cee7aafda26634665cf449436fd6e24bfd1fad230538"
+"checksum x25519-dalek 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)" = "637ff90c9540fa3073bb577e65033069e4bae7c79d49d74aa3ffdf5342a53217"
"checksum zerocopy 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "992b9b31f80fd4a167f903f879b8ca43d6716cc368ea01df90538baa2dd34056"
"checksum zerocopy-derive 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "b090467ecd0624026e8a6405d343ac7382592530d54881330b3fc8e400280fa5"
-"checksum zeroize 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)" = "45af6a010d13e4cf5b54c94ba5a2b2eba5596b9e46bf5875612d332a1f2b3f86"
+"checksum zeroize 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3cbac2ed2ba24cc90f5e06485ac8c7c1e5449fe8911aef4d8877218af021a5b8"
+"checksum zeroize_derive 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "de251eec69fc7c1bc3923403d18ececb929380e016afe103da75f396704f8ca2"
diff --git a/Cargo.toml b/Cargo.toml
index 1298a28..91cac08 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,14 +1,12 @@
[package]
name = "wireguard-rs"
-version = "0.1.0"
+version = "0.1.1"
authors = ["Mathias Hall-Andersen <mathias@hall-andersen.dk>"]
edition = "2018"
-license = "MIT"
[dependencies]
hex = "0.3"
-spin = "0.5.0"
-rand = "0.6.5"
+spin = "0.5.2"
blake2 = "0.8"
log = { version = "0.4", features = ["max_level_trace", "release_max_level_info"] }
hmac = "0.7.1"
@@ -20,8 +18,10 @@ arraydeque = "0.4.5"
treebitmap = "^0.4"
hjul = "0.2.1"
ring = "0.16.7"
-chacha20poly1305 = "^0.1"
-aead = "^0.1.1"
+rand = "^0.7"
+rand_core = "^0.5"
+chacha20poly1305 = "^0.3"
+aead = "^0.2"
clear_on_drop = "0.2.3"
env_logger = "0.6"
num_cpus = "^1.10"
@@ -33,7 +33,7 @@ cpuprofiler = { version = "*", optional = true }
libc = "0.2"
[dependencies.x25519-dalek]
-version = "^0.5"
+version = "^0.6"
[dependencies.subtle]
version = "2.1"
@@ -47,4 +47,4 @@ start_up = []
pnet = "^0.22"
proptest = "0.9.4"
rand_chacha = "0.2.1"
-rand_core = "0.5"
+
diff --git a/src/configuration/config.rs b/src/configuration/config.rs
index aec943f..59cef4a 100644
--- a/src/configuration/config.rs
+++ b/src/configuration/config.rs
@@ -1,3 +1,4 @@
+use std::mem;
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -205,7 +206,7 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
}
fn get_fwmark(&self) -> Option<u32> {
- self.lock().bind.as_ref().and_then(|own| own.get_fwmark())
+ self.lock().fwmark
}
fn set_private_key(&self, sk: Option<StaticSecret>) {
@@ -266,24 +267,22 @@ impl<T: tun::Tun, B: udp::PlatformUDP> Configuration for WireGuardConfig<T, B> {
fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> {
log::trace!("Config, Set listen port: {:?}", port);
- // update port
- let listen: bool = {
+ // update port and take old bind
+ let old: Option<B::Owner> = {
let mut cfg = self.lock();
+ let old = mem::replace(&mut cfg.bind, None);
cfg.port = port;
- if cfg.bind.is_some() {
- cfg.bind = None;
- true
- } else {
- false
- }
+ old
};
// restart listener if bound
- if listen {
+ if old.is_some() {
self.start_listener()
} else {
Ok(())
}
+
+ // old bind is dropped, causing the file-descriptors to be released
}
fn set_fwmark(&self, mark: Option<u32>) -> Result<(), ConfigError> {
diff --git a/src/configuration/uapi/get.rs b/src/configuration/uapi/get.rs
index 9e6ab36..00048cd 100644
--- a/src/configuration/uapi/get.rs
+++ b/src/configuration/uapi/get.rs
@@ -2,7 +2,6 @@ use log;
use std::io;
use super::Configuration;
-use super::Endpoint;
pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) -> io::Result<()> {
let mut write = |key: &'static str, value: String| {
@@ -46,7 +45,7 @@ pub fn serialize<C: Configuration, W: io::Write>(writer: &mut W, config: &C) ->
}
if let Some(endpoint) = p.endpoint {
- write("endpoint", endpoint.into_address().to_string())?;
+ write("endpoint", endpoint.to_string())?;
}
for (ip, cidr) in p.allowed_ips {
diff --git a/src/configuration/uapi/mod.rs b/src/configuration/uapi/mod.rs
index 4f0b741..9f54775 100644
--- a/src/configuration/uapi/mod.rs
+++ b/src/configuration/uapi/mod.rs
@@ -4,7 +4,6 @@ mod set;
use log;
use std::io::{Read, Write};
-use super::Endpoint;
use super::{ConfigError, Configuration};
use get::serialize;
diff --git a/src/main.rs b/src/main.rs
index a0f4a23..a8e4ad2 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -112,6 +112,8 @@ fn main() {
.try_init()
.expect("Failed to initialize event logger");
+ log::info!("starting {} wireguard device", name);
+
// drop privileges
if drop_privileges {}
diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs
index 9836b48..1955884 100644
--- a/src/platform/dummy/tun.rs
+++ b/src/platform/dummy/tun.rs
@@ -165,8 +165,7 @@ impl TunTest {
sync_channel(1)
};
- let mut rng = OsRng::new().unwrap();
- let id: u32 = rng.gen();
+ let id: u32 = OsRng.gen();
let fake = TunFakeIO {
id,
diff --git a/src/platform/dummy/udp.rs b/src/platform/dummy/udp.rs
index 35c905d..88630af 100644
--- a/src/platform/dummy/udp.rs
+++ b/src/platform/dummy/udp.rs
@@ -54,7 +54,7 @@ impl Reader<UnitEndpoint> for VoidBind {
impl Writer<UnitEndpoint> for VoidBind {
type Error = BindError;
- fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> {
Ok(())
}
}
@@ -105,7 +105,7 @@ impl Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
impl Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
type Error = BindError;
- fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> {
debug!(
"dummy({}): write ({}, {})",
self.id,
@@ -135,9 +135,8 @@ impl PairBind {
(PairReader<E>, PairWriter<E>),
(PairReader<E>, PairWriter<E>),
) {
- let mut rng = OsRng::new().unwrap();
- let id1: u32 = rng.gen();
- let id2: u32 = rng.gen();
+ let id1: u32 = OsRng.gen();
+ let id2: u32 = OsRng.gen();
let (tx1, rx1) = sync_channel(128);
let (tx2, rx2) = sync_channel(128);
@@ -187,10 +186,6 @@ impl Owner for VoidOwner {
fn get_port(&self) -> u16 {
0
}
-
- fn get_fwmark(&self) -> Option<u32> {
- None
- }
}
impl PlatformUDP for PairBind {
diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs
index c282a4b..15ca1ec 100644
--- a/src/platform/linux/tun.rs
+++ b/src/platform/linux/tun.rs
@@ -199,7 +199,7 @@ impl Status for LinuxTunStatus {
// cut buffer to size
let size: usize = size as usize;
let mut remain = &buf[..size];
- log::debug!("netlink, recieved message ({} bytes)", size);
+ log::debug!("netlink, received message ({} bytes)", size);
// handle messages
while remain.len() >= HDR_SIZE {
diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs
index f871bce..9815ab1 100644
--- a/src/platform/linux/udp.rs
+++ b/src/platform/linux/udp.rs
@@ -1,84 +1,683 @@
use super::super::udp::*;
use super::super::Endpoint;
+use log;
+
+use std::convert::TryInto;
use std::io;
-use std::net::{SocketAddr, UdpSocket};
-use std::sync::Arc;
+use std::mem;
+use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};
+use std::os::unix::io::RawFd;
+use std::ptr;
+
+fn errno() -> libc::c_int {
+ unsafe {
+ let ptr = libc::__errno_location();
+ if ptr.is_null() {
+ 0
+ } else {
+ *ptr
+ }
+ }
+}
+
+#[repr(C, align(1))]
+struct ControlHeaderV4 {
+ hdr: libc::cmsghdr,
+ info: libc::in_pktinfo,
+}
+
+#[repr(C, align(1))]
+struct ControlHeaderV6 {
+ hdr: libc::cmsghdr,
+ info: libc::in6_pktinfo,
+}
+
+pub struct EndpointV4 {
+ dst: libc::sockaddr_in, // destination IP
+ info: libc::in_pktinfo, // src & ifindex
+}
+
+pub struct EndpointV6 {
+ dst: libc::sockaddr_in6, // destination IP
+ info: libc::in6_pktinfo, // src & zone id
+}
+
+pub struct LinuxUDP();
+
+pub struct LinuxOwner {
+ port: u16,
+ sock4: Option<RawFd>,
+ sock6: Option<RawFd>,
+}
+
+pub enum LinuxUDPReader {
+ V4(RawFd),
+ V6(RawFd),
+}
#[derive(Clone)]
-pub struct LinuxUDP(Arc<UdpSocket>);
+pub struct LinuxUDPWriter {
+ sock4: RawFd,
+ sock6: RawFd,
+}
-pub struct LinuxOwner(Arc<UdpSocket>);
+pub enum LinuxEndpoint {
+ V4(EndpointV4),
+ V6(EndpointV6),
+}
-impl Endpoint for SocketAddr {
- fn clear_src(&mut self) {}
+impl Endpoint for LinuxEndpoint {
+ fn clear_src(&mut self) {
+ match self {
+ LinuxEndpoint::V4(EndpointV4 { ref mut info, .. }) => {
+ info.ipi_ifindex = 0;
+ info.ipi_spec_dst = libc::in_addr { s_addr: 0 };
+ }
+ LinuxEndpoint::V6(EndpointV6 { ref mut info, .. }) => {
+ info.ipi6_addr = libc::in6_addr { s6_addr: [0; 16] };
+ info.ipi6_ifindex = 0;
+ }
+ };
+ }
fn from_address(addr: SocketAddr) -> Self {
- addr
+ match addr {
+ SocketAddr::V4(addr) => LinuxEndpoint::V4(EndpointV4 {
+ dst: libc::sockaddr_in {
+ sin_family: libc::AF_INET as libc::sa_family_t,
+ sin_port: addr.port().to_be(),
+ sin_addr: libc::in_addr {
+ s_addr: u32::from(*addr.ip()).to_be(),
+ },
+ sin_zero: [0; 8],
+ },
+ info: libc::in_pktinfo {
+ ipi_ifindex: 0, // interface (0 is via routing table)
+ ipi_spec_dst: libc::in_addr { s_addr: 0 }, // src IP (dst of incoming packet)
+ ipi_addr: libc::in_addr { s_addr: 0 },
+ },
+ }),
+ SocketAddr::V6(addr) => LinuxEndpoint::V6(EndpointV6 {
+ dst: libc::sockaddr_in6 {
+ sin6_family: libc::AF_INET6 as libc::sa_family_t,
+ sin6_port: addr.port().to_be(),
+ sin6_flowinfo: addr.flowinfo(),
+ sin6_addr: libc::in6_addr {
+ s6_addr: addr.ip().octets(),
+ },
+ sin6_scope_id: addr.scope_id(),
+ },
+ info: libc::in6_pktinfo {
+ ipi6_addr: libc::in6_addr { s6_addr: [0; 16] }, // src IP
+ ipi6_ifindex: 0, // zone id
+ },
+ }),
+ }
}
fn into_address(&self) -> SocketAddr {
- *self
+ match self {
+ LinuxEndpoint::V4(EndpointV4 { ref dst, .. }) => {
+ SocketAddr::V4(SocketAddrV4::new(
+ u32::from_be(dst.sin_addr.s_addr).into(), // IPv4 addr
+ u16::from_be(dst.sin_port), // convert back to native byte-order
+ ))
+ }
+ LinuxEndpoint::V6(EndpointV6 { ref dst, .. }) => SocketAddr::V6(SocketAddrV6::new(
+ u128::from_ne_bytes(dst.sin6_addr.s6_addr).into(), // IPv6 addr
+ u16::from_be(dst.sin6_port), // convert back to native byte-order
+ dst.sin6_flowinfo,
+ dst.sin6_scope_id,
+ )),
+ }
}
}
-impl Reader<SocketAddr> for LinuxUDP {
- type Error = io::Error;
+fn setsockopt<V: Sized>(
+ fd: RawFd,
+ level: libc::c_int,
+ name: libc::c_int,
+ value: &V,
+) -> Result<(), io::Error> {
+ let res = unsafe {
+ libc::setsockopt(
+ fd,
+ level,
+ name,
+ mem::transmute(value),
+ mem::size_of_val(value).try_into().unwrap(),
+ )
+ };
+ if res == 0 {
+ Ok(())
+ } else {
+ Err(io::Error::new(
+ io::ErrorKind::Other,
+ format!("Failed to set sockopt (res = {}, errno = {})", res, errno()),
+ ))
+ }
+}
+
+#[inline(always)]
+fn setsockopt_int(
+ fd: RawFd,
+ level: libc::c_int,
+ name: libc::c_int,
+ value: libc::c_int,
+) -> Result<(), io::Error> {
+ setsockopt(fd, level, name, &value)
+}
- fn read(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
- self.0.recv_from(buf)
+#[allow(non_snake_case)]
+const fn CMSG_ALIGN(len: usize) -> usize {
+ (((len) + mem::size_of::<u32>() - 1) & !(mem::size_of::<u32>() - 1))
+}
+
+#[allow(non_snake_case)]
+const fn CMSG_LEN(len: usize) -> usize {
+ CMSG_ALIGN(len + mem::size_of::<libc::cmsghdr>())
+}
+
+#[inline(always)]
+fn safe_cast<T, D>(v: &mut T) -> *mut D {
+ (v as *mut T) as *mut D
+}
+
+impl LinuxUDPReader {
+ fn read6(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> {
+ log::trace!(
+ "receive IPv6 packet (block), (fd {}, max-len {})",
+ fd,
+ buf.len()
+ );
+
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+ let mut src: libc::sockaddr_in6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut control: ControlHeaderV6 = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut src),
+ msg_namelen: mem::size_of::<libc::sockaddr_in6> as u32,
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of::<ControlHeaderV6>(),
+ msg_flags: 0,
+ };
+
+ debug_assert!(
+ hdr.msg_controllen
+ >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in6_pktinfo>(),
+ );
+
+ let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
+
+ if len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to receive",
+ ));
+ }
+
+ Ok((
+ len.try_into().unwrap(),
+ LinuxEndpoint::V6(EndpointV6 {
+ info: control.info, // save pktinfo (sticky source)
+ dst: src, // our future destination is the source address
+ }),
+ ))
+ }
+
+ fn read4(fd: RawFd, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), io::Error> {
+ log::trace!(
+ "receive IPv4 packet (block), (fd {}, max-len {})",
+ fd,
+ buf.len()
+ );
+
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_mut_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+ let mut src: libc::sockaddr_in = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut control: ControlHeaderV4 = unsafe { mem::MaybeUninit::uninit().assume_init() };
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut src),
+ msg_namelen: mem::size_of::<libc::sockaddr_in> as u32,
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of::<ControlHeaderV4>(),
+ msg_flags: 0,
+ };
+
+ debug_assert!(
+ hdr.msg_controllen
+ >= mem::size_of::<libc::cmsghdr>() + mem::size_of::<libc::in_pktinfo>(),
+ );
+
+ let len = unsafe { libc::recvmsg(fd, &mut hdr as *mut libc::msghdr, 0) };
+
+ if len < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to receive",
+ ));
+ }
+
+ Ok((
+ len.try_into().unwrap(),
+ LinuxEndpoint::V4(EndpointV4 {
+ info: control.info, // save pktinfo (sticky source)
+ dst: src, // our future destination is the source address
+ }),
+ ))
}
}
-impl Writer<SocketAddr> for LinuxUDP {
+impl Reader<LinuxEndpoint> for LinuxUDPReader {
type Error = io::Error;
- fn write(&self, buf: &[u8], dst: &SocketAddr) -> Result<(), Self::Error> {
- self.0.send_to(buf, dst)?;
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, LinuxEndpoint), Self::Error> {
+ match self {
+ Self::V4(fd) => Self::read4(*fd, buf),
+ Self::V6(fd) => Self::read6(*fd, buf),
+ }
+ }
+}
+
+impl LinuxUDPWriter {
+ fn write6(fd: RawFd, buf: &[u8], dst: &mut EndpointV6) -> Result<(), io::Error> {
+ log::debug!("sending IPv6 packet ({} fd, {} bytes)", fd, buf.len());
+
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+
+ let mut control = ControlHeaderV6 {
+ hdr: libc::cmsghdr {
+ cmsg_len: CMSG_LEN(mem::size_of::<libc::in6_pktinfo>()),
+ cmsg_level: libc::IPPROTO_IPV6,
+ cmsg_type: libc::IPV6_PKTINFO,
+ },
+ info: dst.info,
+ };
+
+ debug_assert_eq!(
+ control.hdr.cmsg_len % mem::size_of::<u32>(),
+ 0,
+ "cmsg_len must be aligned to a long"
+ );
+
+ debug_assert_eq!(
+ dst.dst.sin6_family,
+ libc::AF_INET6 as libc::sa_family_t,
+ "this method only handles IPv6 destinations"
+ );
+
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut dst.dst),
+ msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of_val(&control),
+ msg_flags: 0,
+ };
+
+ let ret = unsafe { libc::sendmsg(fd, &hdr, 0) };
+
+ if ret < 0 {
+ if errno() == libc::EINVAL {
+ log::trace!("clear source and retry");
+ hdr.msg_control = ptr::null_mut();
+ hdr.msg_controllen = 0;
+ dst.info = unsafe { mem::zeroed() };
+ if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv6 packet",
+ ));
+ } else {
+ return Ok(());
+ }
+ }
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv6 packet",
+ ));
+ }
+
+ Ok(())
+ }
+
+ fn write4(fd: RawFd, buf: &[u8], dst: &mut EndpointV4) -> Result<(), io::Error> {
+ log::debug!("sending IPv4 packet ({} fd, {} bytes)", fd, buf.len());
+
+ let mut iovs: [libc::iovec; 1] = [libc::iovec {
+ iov_base: buf.as_ptr() as *mut core::ffi::c_void,
+ iov_len: buf.len(),
+ }];
+
+ let mut control = ControlHeaderV4 {
+ hdr: libc::cmsghdr {
+ cmsg_len: CMSG_LEN(mem::size_of::<libc::in_pktinfo>()),
+ cmsg_level: libc::IPPROTO_IP,
+ cmsg_type: libc::IP_PKTINFO,
+ },
+ info: dst.info,
+ };
+
+ debug_assert_eq!(
+ control.hdr.cmsg_len % mem::size_of::<u32>(),
+ 0,
+ "cmsg_len must be aligned to a long"
+ );
+
+ debug_assert_eq!(
+ dst.dst.sin_family,
+ libc::AF_INET as libc::sa_family_t,
+ "this method only handles IPv4 destinations"
+ );
+
+ let mut hdr = libc::msghdr {
+ msg_name: safe_cast(&mut dst.dst),
+ msg_namelen: mem::size_of_val(&dst.dst).try_into().unwrap(),
+ msg_iov: iovs.as_mut_ptr(),
+ msg_iovlen: iovs.len(),
+ msg_control: safe_cast(&mut control),
+ msg_controllen: mem::size_of_val(&control),
+ msg_flags: 0,
+ };
+
+ let ret = unsafe { libc::sendmsg(fd, &hdr, 0) };
+
+ if ret < 0 {
+ if errno() == libc::EINVAL {
+ log::trace!("clear source and retry");
+ hdr.msg_control = ptr::null_mut();
+ hdr.msg_controllen = 0;
+ dst.info = unsafe { mem::zeroed() };
+ if unsafe { libc::sendmsg(fd, &hdr, 0) } < 0 {
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv4 packet",
+ ));
+ } else {
+ return Ok(());
+ }
+ }
+ return Err(io::Error::new(
+ io::ErrorKind::NotConnected,
+ "failed to send IPv4 packet",
+ ));
+ }
+
Ok(())
}
}
-impl Owner for LinuxOwner {
+impl Writer<LinuxEndpoint> for LinuxUDPWriter {
type Error = io::Error;
- fn get_port(&self) -> u16 {
- self.0.local_addr().unwrap().port() // todo handle
+ fn write(&self, buf: &[u8], dst: &mut LinuxEndpoint) -> Result<(), Self::Error> {
+ match dst {
+ LinuxEndpoint::V4(ref mut end) => Self::write4(self.sock4, buf, end),
+ LinuxEndpoint::V6(ref mut end) => Self::write6(self.sock6, buf, end),
+ }
}
+}
- fn get_fwmark(&self) -> Option<u32> {
- None
+impl Owner for LinuxOwner {
+ type Error = io::Error;
+
+ fn get_port(&self) -> u16 {
+ self.port
}
- fn set_fwmark(&mut self, _value: Option<u32>) -> Result<(), Self::Error> {
- Ok(())
+ fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error> {
+ fn set_mark(fd: Option<RawFd>, value: u32) -> Result<(), io::Error> {
+ if let Some(fd) = fd {
+ setsockopt(fd, libc::SOL_SOCKET, libc::SO_MARK, &value)
+ } else {
+ Ok(())
+ }
+ }
+ let value = value.unwrap_or(0);
+ set_mark(self.sock6, value)?;
+ set_mark(self.sock4, value)
}
}
impl Drop for LinuxOwner {
fn drop(&mut self) {
- // TODO: close udp bind
+ log::trace!("closing the bind (port {})", self.port);
+ self.sock4.map(|fd| unsafe {
+ libc::shutdown(fd, libc::SHUT_RDWR);
+ libc::close(fd)
+ });
+ self.sock6.map(|fd| unsafe {
+ libc::shutdown(fd, libc::SHUT_RDWR);
+ libc::close(fd)
+ });
}
}
impl UDP for LinuxUDP {
type Error = io::Error;
- type Endpoint = SocketAddr;
- type Reader = Self;
- type Writer = Self;
+ type Endpoint = LinuxEndpoint;
+ type Reader = LinuxUDPReader;
+ type Writer = LinuxUDPWriter;
+}
+
+impl LinuxUDP {
+ /* Bind on all IPv6 interfaces
+ *
+ * Arguments:
+ *
+ * - 'port', port to bind to (0 = any)
+ *
+ * Returns:
+ *
+ * Returns a tuple of the resulting port and socket.
+ */
+ fn bind6(port: u16) -> Result<(u16, RawFd), io::Error> {
+ log::trace!("attempting to bind on IPv6 (port {})", port);
+
+ // create socket fd
+ let fd: RawFd = unsafe { libc::socket(libc::AF_INET6, libc::SOCK_DGRAM, 0) };
+ if fd < 0 {
+ log::debug!("failed to create IPv6 socket");
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?;
+ setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_RECVPKTINFO, 1)?;
+ setsockopt_int(fd, libc::IPPROTO_IPV6, libc::IPV6_V6ONLY, 1)?;
+
+ // bind
+ let mut sockaddr = libc::sockaddr_in6 {
+ sin6_addr: libc::in6_addr { s6_addr: [0; 16] },
+ sin6_family: libc::AF_INET6 as libc::sa_family_t,
+ sin6_port: port.to_be(), // convert to network (big-endian) byteorder
+ sin6_scope_id: 0,
+ sin6_flowinfo: 0,
+ };
+
+ let err = unsafe {
+ libc::bind(
+ fd,
+ mem::transmute(&sockaddr as *const libc::sockaddr_in6),
+ mem::size_of_val(&sockaddr).try_into().unwrap(),
+ )
+ };
+
+ if err != 0 {
+ log::debug!("failed to bind IPv6 socket");
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ // get the assigned port
+ let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap();
+ let err = unsafe {
+ libc::getsockname(
+ fd,
+ mem::transmute(&mut sockaddr as *mut libc::sockaddr_in6),
+ &mut socklen as *mut libc::socklen_t,
+ )
+ };
+ if err != 0 {
+ log::debug!("failed to get port of IPv6 socket");
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ // basic sanity checks
+ let new_port = u16::from_be(sockaddr.sin6_port);
+ debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in6>() as u32);
+ debug_assert_eq!(sockaddr.sin6_family, libc::AF_INET6 as libc::sa_family_t);
+ debug_assert_eq!(new_port, if port != 0 { port } else { new_port });
+ log::trace!("bound IPv6 socket (port {}, fd {})", new_port, fd);
+ return Ok((new_port, fd));
+ }
+
+ /* Bind on all IPv4 interfaces.
+ *
+ * Arguments:
+ *
+ * - 'port', port to bind to (0 = any)
+ *
+ * Returns:
+ *
+ * Returns a tuple of the resulting port and socket.
+ */
+ fn bind4(port: u16) -> Result<(u16, RawFd), io::Error> {
+ log::trace!("attempting to bind on IPv4 (port {})", port);
+
+ // create socket fd
+ let fd: RawFd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) };
+ if fd < 0 {
+ log::trace!("failed to create IPv4 socket (errno = {})", errno());
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ setsockopt_int(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, 1)?;
+ setsockopt_int(fd, libc::IPPROTO_IP, libc::IP_PKTINFO, 1)?;
+
+ const INADDR_ANY: libc::in_addr = libc::in_addr { s_addr: 0 };
+
+ // bind
+ let mut sockaddr = libc::sockaddr_in {
+ sin_addr: INADDR_ANY,
+ sin_family: libc::AF_INET as libc::sa_family_t,
+ sin_port: port.to_be(),
+ sin_zero: [0; 8],
+ };
+
+ let err = unsafe {
+ libc::bind(
+ fd,
+ mem::transmute(&sockaddr as *const libc::sockaddr_in),
+ mem::size_of_val(&sockaddr).try_into().unwrap(),
+ )
+ };
+
+ if err != 0 {
+ log::trace!("failed to bind IPv4 socket (errno = {})", errno());
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ // get the assigned port
+ let mut socklen: libc::socklen_t = mem::size_of_val(&sockaddr).try_into().unwrap();
+ let err = unsafe {
+ libc::getsockname(
+ fd,
+ mem::transmute(&mut sockaddr as *mut libc::sockaddr_in),
+ &mut socklen as *mut libc::socklen_t,
+ )
+ };
+ if err != 0 {
+ log::trace!("failed to get port of IPv4 socket (errno = {})", errno());
+ return Err(io::Error::new(
+ io::ErrorKind::Other,
+ "failed to create socket",
+ ));
+ }
+
+ // basic sanity checks
+ let new_port = u16::from_be(sockaddr.sin_port);
+ debug_assert_eq!(socklen, mem::size_of::<libc::sockaddr_in>() as u32);
+ debug_assert_eq!(sockaddr.sin_family, libc::AF_INET as libc::sa_family_t);
+ debug_assert_eq!(new_port, if port != 0 { port } else { new_port });
+ log::trace!("bound IPv4 socket (port {}, fd {})", new_port, fd);
+ return Ok((new_port, fd));
+ }
}
impl PlatformUDP for LinuxUDP {
type Owner = LinuxOwner;
- fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
- let socket = UdpSocket::bind(format!("0.0.0.0:{}", port))?;
- let socket = Arc::new(socket);
+ fn bind(mut port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> {
+ log::debug!("bind to port {}", port);
- Ok((
- vec![LinuxUDP(socket.clone())],
- LinuxUDP(socket.clone()),
- LinuxOwner(socket),
- ))
+ // attempt to bind on ipv6
+ let bind6 = Self::bind6(port);
+ if let Ok((new_port, _)) = bind6 {
+ port = new_port;
+ }
+
+ // attempt to bind on ipv4 on the same port
+ let bind4 = Self::bind4(port);
+ if let Ok((new_port, _)) = bind4 {
+ port = new_port;
+ }
+
+ // check if failed to bind on both
+ if bind4.is_err() && bind6.is_err() {
+ log::trace!("failed to bind for either IP version");
+ return Err(bind6.unwrap_err());
+ }
+
+ let sock6 = bind6.ok().map(|(_, fd)| fd);
+ let sock4 = bind4.ok().map(|(_, fd)| fd);
+
+ // create owner
+ let owner = LinuxOwner {
+ port,
+ sock6: sock6,
+ sock4: sock4,
+ };
+
+ // create readers
+ let mut readers: Vec<Self::Reader> = Vec::with_capacity(2);
+ sock6.map(|sock| readers.push(LinuxUDPReader::V6(sock)));
+ sock4.map(|sock| readers.push(LinuxUDPReader::V4(sock)));
+ debug_assert!(readers.len() > 0);
+
+ // create writer
+ let writer = LinuxUDPWriter {
+ sock4: sock4.unwrap_or(-1),
+ sock6: sock6.unwrap_or(-1),
+ };
+
+ Ok((readers, writer, owner))
}
}
diff --git a/src/platform/udp.rs b/src/platform/udp.rs
index 3671229..e1180fb 100644
--- a/src/platform/udp.rs
+++ b/src/platform/udp.rs
@@ -10,7 +10,7 @@ pub trait Reader<E: Endpoint>: Send + Sync {
pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
type Error: Error;
- fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
+ fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>;
}
pub trait UDP: Send + Sync + 'static {
@@ -30,8 +30,6 @@ pub trait Owner: Send {
fn get_port(&self) -> u16;
- fn get_fwmark(&self) -> Option<u32>;
-
fn set_fwmark(&mut self, value: Option<u32>) -> Result<(), Self::Error>;
}
diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs
index edd1a07..4b5d8f6 100644
--- a/src/wireguard/handshake/device.rs
+++ b/src/wireguard/handshake/device.rs
@@ -1,4 +1,5 @@
use spin::RwLock;
+use std::collections::hash_map;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Mutex;
@@ -6,7 +7,10 @@ use zerocopy::AsBytes;
use byteorder::{ByteOrder, LittleEndian};
-use rand::prelude::*;
+use rand::Rng;
+use rand_core::{CryptoRng, RngCore};
+
+use clear_on_drop::clear::Clear;
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
@@ -22,42 +26,101 @@ use super::types::*;
const MAX_PEER_PER_DEVICE: usize = 1 << 20;
pub struct KeyState {
- pub sk: StaticSecret, // static secret key
- pub pk: PublicKey, // static public key
- macs: macs::Validator, // validator for the mac fields
+ pub(super) sk: StaticSecret, // static secret key
+ pub(super) pk: PublicKey, // static public key
+ macs: macs::Validator, // validator for the mac fields
}
-pub struct Device {
- keyst: Option<KeyState>, // secret/public key
- pk_map: HashMap<[u8; 32], Peer>, // public key -> peer state
- id_map: RwLock<HashMap<u32, [u8; 32]>>, // receiver ids -> public key
+/// The device is generic over an "opaque" type
+/// which can be used to associate the public key with this value.
+/// (the instance is a Peer object in the parent module)
+pub struct Device<O> {
+ keyst: Option<KeyState>,
+ id_map: RwLock<HashMap<u32, [u8; 32]>>,
+ pk_map: HashMap<[u8; 32], Peer<O>>,
limiter: Mutex<RateLimiter>,
}
+pub struct Iter<'a, O> {
+ iter: hash_map::Iter<'a, [u8; 32], Peer<O>>,
+}
+
+impl<'a, O> Iterator for Iter<'a, O> {
+ type Item = (PublicKey, &'a O);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ self.iter
+ .next()
+ .map(|(pk, peer)| (PublicKey::from(*pk), &peer.opaque))
+ }
+}
+
+/* These methods enable the Device to act as a map
+ * from public keys to the set of contained opaque values.
+ *
+ * It also abstracts away the problem of PublicKey not being hashable.
+ */
+impl<O> Device<O> {
+ pub fn clear(&mut self) {
+ self.id_map.write().clear();
+ self.pk_map.clear();
+ }
+
+ pub fn len(&self) -> usize {
+ self.pk_map.len()
+ }
+
+ /// Enables enumeration of (public key, opaque) pairs
+ /// without exposing internal peer type.
+ pub fn iter(&self) -> Iter<O> {
+ Iter {
+ iter: self.pk_map.iter(),
+ }
+ }
+
+ /// Enables lookup by public key without exposing internal peer type.
+ pub fn get(&self, pk: &PublicKey) -> Option<&O> {
+ self.pk_map.get(pk.as_bytes()).map(|peer| &peer.opaque)
+ }
+
+ pub fn contains_key(&self, pk: &PublicKey) -> bool {
+ self.pk_map.contains_key(pk.as_bytes())
+ }
+}
+
/* A mutable reference to the device needs to be held during configuration.
* Wrapping the device in a RwLock enables peer config after "configuration time"
*/
-impl Device {
+impl<O> Device<O> {
/// Initialize a new handshake state machine
- pub fn new() -> Device {
+ pub fn new() -> Device<O> {
Device {
keyst: None,
- pk_map: HashMap::new(),
id_map: RwLock::new(HashMap::new()),
+ pk_map: HashMap::new(),
limiter: Mutex::new(RateLimiter::new()),
}
}
- fn update_ss(&self, peer: &mut Peer) -> Option<PublicKey> {
- if let Some(key) = self.keyst.as_ref() {
- if *peer.pk.as_bytes() == *key.pk.as_bytes() {
- return Some(peer.pk);
+ fn update_ss(&mut self) -> (Vec<u32>, Option<PublicKey>) {
+ let mut same = None;
+ let mut ids = Vec::with_capacity(self.pk_map.len());
+ for (pk, peer) in self.pk_map.iter_mut() {
+ if let Some(key) = self.keyst.as_ref() {
+ if key.pk.as_bytes() == pk {
+ same = Some(PublicKey::from(*pk));
+ peer.ss.clear()
+ } else {
+ let pk = PublicKey::from(*pk);
+ peer.ss = *key.sk.diffie_hellman(&pk).as_bytes();
+ }
+ } else {
+ peer.ss.clear();
}
- peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
- } else {
- peer.ss = [0u8; 32];
- };
- None
+ peer.reset_state().map(|id| ids.push(id));
+ }
+
+ (ids, same)
}
/// Update the secret key of the device
@@ -74,29 +137,15 @@ impl Device {
});
// recalculate / erase the shared secrets for every peer
- let mut ids = vec![];
- let mut same = None;
- for mut peer in self.pk_map.values_mut() {
- // clear any existing handshake state
- peer.reset_state().map(|id| ids.push(id));
-
- // update precomputed shared secret
- if let Some(key) = self.keyst.as_ref() {
- peer.ss = *key.sk.diffie_hellman(&peer.pk).as_bytes();
- if *peer.pk.as_bytes() == *key.pk.as_bytes() {
- same = Some(peer.pk)
- }
- } else {
- peer.ss = [0u8; 32];
- };
- }
+ let (ids, same) = self.update_ss();
// release ids from aborted handshakes
for id in ids {
self.release(id)
}
- // if we found a peer matching the device public key, remove it.
+ // if we found a peer matching the device public key
+ // remove it and return its value to the caller
same.map(|pk| {
self.pk_map.remove(pk.as_bytes());
pk
@@ -119,29 +168,32 @@ impl Device {
///
/// * `pk` - The public key to add
/// * `identifier` - Associated identifier which can be used to distinguish the peers
- pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> {
+ pub fn add(&mut self, pk: PublicKey, opaque: O) -> Result<(), ConfigError> {
// ensure less than 2^20 peers
if self.pk_map.len() > MAX_PEER_PER_DEVICE {
return Err(ConfigError::new("Too many peers for device"));
}
- // create peer and precompute static secret
- let mut peer = Peer::new(
- pk,
- self.keyst
- .as_ref()
- .map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
- .unwrap_or([0u8; 32]),
- );
-
- // add peer to device
- match self.update_ss(&mut peer) {
- Some(_) => Err(ConfigError::new("Public key of peer matches the device")),
- None => {
- self.pk_map.insert(*pk.as_bytes(), peer);
- Ok(())
+ // error if public key matches device
+ if let Some(key) = self.keyst.as_ref() {
+ if pk.as_bytes() == key.pk.as_bytes() {
+ return Err(ConfigError::new("Public key of peer matches the device"));
}
}
+
+ // pre-compute shared secret and add to pk_map
+ self.pk_map.insert(
+ *pk.as_bytes(),
+ Peer::new(
+ pk,
+ self.keyst
+ .as_ref()
+ .map(|key| *key.sk.diffie_hellman(&pk).as_bytes())
+ .unwrap_or([0u8; 32]),
+ opaque,
+ ),
+ );
+ Ok(())
}
/// Remove a peer by public key
@@ -163,7 +215,7 @@ impl Device {
.remove(pk.as_bytes())
.ok_or(ConfigError::new("Public key not in device"))?;
- // pruge the id map (linear scan)
+ // purge the id map (linear scan)
id_map.retain(|_, v| v != pk.as_bytes());
Ok(())
}
@@ -231,11 +283,11 @@ impl Device {
(_, None) => Err(HandshakeError::UnknownPublicKey),
(None, _) => Err(HandshakeError::UnknownPublicKey),
(Some(keyst), Some(peer)) => {
- let local = self.allocate(rng, peer);
+ let local = self.allocate(rng, pk);
let mut msg = Initiation::default();
// create noise part of initation
- noise::create_initiation(rng, keyst, peer, local, &mut msg.noise)?;
+ noise::create_initiation(rng, keyst, peer, pk, local, &mut msg.noise)?;
// add macs to initation
peer.macs
@@ -253,11 +305,11 @@ impl Device {
///
/// * `msg` - Byte slice containing the message (untrusted input)
pub fn process<'a, R: RngCore + CryptoRng>(
- &self,
- rng: &mut R, // rng instance to sample randomness from
- msg: &[u8], // message buffer
+ &'a self,
+ rng: &mut R, // rng instance to sample randomness from
+ msg: &[u8], // message buffer
src: Option<SocketAddr>, // optional source endpoint, set when "under load"
- ) -> Result<Output, HandshakeError> {
+ ) -> Result<Output<'a, O>, HandshakeError> {
// ensure type read in-range
if msg.len() < 4 {
return Err(HandshakeError::InvalidMessageFormat);
@@ -303,17 +355,17 @@ impl Device {
}
// consume the initiation
- let (peer, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
+ let (peer, pk, st) = noise::consume_initiation(self, keyst, &msg.noise)?;
// allocate new index for response
- let local = self.allocate(rng, peer);
+ let local = self.allocate(rng, &pk);
// prepare memory for response, TODO: take slice for zero allocation
let mut resp = Response::default();
// create response (release id on error)
- let keys =
- noise::create_response(rng, peer, local, st, &mut resp.noise).map_err(|e| {
+ let keys = noise::create_response(rng, peer, &pk, local, st, &mut resp.noise)
+ .map_err(|e| {
self.release(local);
e
})?;
@@ -324,7 +376,11 @@ impl Device {
.generate(resp.noise.as_bytes(), &mut resp.macs);
// return unconfirmed keypair and the response as vector
- Ok((Some(peer.pk), Some(resp.as_bytes().to_owned()), Some(keys)))
+ Ok((
+ Some(&peer.opaque),
+ Some(resp.as_bytes().to_owned()),
+ Some(keys),
+ ))
}
TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
@@ -363,7 +419,7 @@ impl Device {
let msg = CookieReply::parse(msg)?;
// lookup peer
- let peer = self.lookup_id(msg.f_receiver.get())?;
+ let (peer, _) = self.lookup_id(msg.f_receiver.get())?;
// validate cookie reply
peer.macs.lock().process(&msg)?;
@@ -379,7 +435,7 @@ impl Device {
// Internal function
//
// Return the peer associated with the public key
- pub(crate) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer, HandshakeError> {
+ pub(super) fn lookup_pk(&self, pk: &PublicKey) -> Result<&Peer<O>, HandshakeError> {
self.pk_map
.get(pk.as_bytes())
.ok_or(HandshakeError::UnknownPublicKey)
@@ -388,11 +444,11 @@ impl Device {
// Internal function
//
// Return the peer currently associated with the receiver identifier
- pub(crate) fn lookup_id(&self, id: u32) -> Result<&Peer, HandshakeError> {
+ pub(super) fn lookup_id(&self, id: u32) -> Result<(&Peer<O>, PublicKey), HandshakeError> {
let im = self.id_map.read();
let pk = im.get(&id).ok_or(HandshakeError::UnknownReceiverId)?;
match self.pk_map.get(pk) {
- Some(peer) => Ok(peer),
+ Some(peer) => Ok((peer, PublicKey::from(*pk))),
_ => unreachable!(), // if the id-lookup succeeded, the peer should exist
}
}
@@ -400,7 +456,7 @@ impl Device {
// Internal function
//
// Allocated a new receiver identifier for the peer
- fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, peer: &Peer) -> u32 {
+ fn allocate<R: RngCore + CryptoRng>(&self, rng: &mut R, pk: &PublicKey) -> u32 {
loop {
let id = rng.gen();
@@ -412,7 +468,7 @@ impl Device {
// take write lock and add index
let mut m = self.id_map.write();
if !m.contains_key(&id) {
- m.insert(id, *peer.pk.as_bytes());
+ m.insert(id, *pk.as_bytes());
return id;
}
}
diff --git a/src/wireguard/handshake/macs.rs b/src/wireguard/handshake/macs.rs
index 689826b..cb5d7d4 100644
--- a/src/wireguard/handshake/macs.rs
+++ b/src/wireguard/handshake/macs.rs
@@ -286,8 +286,7 @@ mod tests {
use x25519_dalek::StaticSecret;
fn new_validator_generator() -> (Validator, Generator) {
- let mut rng = OsRng::new().unwrap();
- let sk = StaticSecret::new(&mut rng);
+ let sk = StaticSecret::new(&mut OsRng);
let pk = PublicKey::from(&sk);
(Validator::new(pk), Generator::new(pk))
}
@@ -296,7 +295,6 @@ mod tests {
#[test]
fn test_cookie_reply(inner1 : Vec<u8>, inner2 : Vec<u8>, receiver : u32) {
let mut msg = CookieReply::default();
- let mut rng = OsRng::new().expect("failed to create rng");
let mut macs = MacsFooter::default();
let src = "192.0.2.16:8080".parse().unwrap();
let (validator, mut generator) = new_validator_generator();
@@ -309,7 +307,7 @@ mod tests {
// check validity of mac1
validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate");
assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate");
- validator.create_cookie_reply(&mut rng, receiver, &src, &macs, &mut msg);
+ validator.create_cookie_reply(&mut OsRng, receiver, &src, &macs, &mut msg);
// consume cookie reply
generator.process(&msg).expect("failed to process CookieReply");
diff --git a/src/wireguard/handshake/noise.rs b/src/wireguard/handshake/noise.rs
index 072ac13..9e431cf 100644
--- a/src/wireguard/handshake/noise.rs
+++ b/src/wireguard/handshake/noise.rs
@@ -10,7 +10,7 @@ use hmac::Hmac;
use aead::{Aead, NewAead, Payload};
use chacha20poly1305::ChaCha20Poly1305;
-use rand::{CryptoRng, RngCore};
+use rand_core::{CryptoRng, RngCore};
use log::debug;
@@ -215,20 +215,21 @@ mod tests {
}
}
-pub fn create_initiation<R: RngCore + CryptoRng>(
+pub(super) fn create_initiation<R: RngCore + CryptoRng, O>(
rng: &mut R,
keyst: &KeyState,
- peer: &Peer,
+ peer: &Peer<O>,
+ pk: &PublicKey,
local: u32,
msg: &mut NoiseInitiation,
) -> Result<(), HandshakeError> {
- debug!("create initation");
+ debug!("create initiation");
clear_stack_on_return(CLEAR_PAGES, || {
// initialize state
let ck = INITIAL_CK;
let hs = INITIAL_HS;
- let hs = HASH!(&hs, peer.pk.as_bytes());
+ let hs = HASH!(&hs, pk.as_bytes());
msg.f_type.set(TYPE_INITIATION as u32);
msg.f_sender.set(local); // from us
@@ -252,7 +253,7 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
// (C, k) := Kdf2(C, DH(E_priv, S_pub))
- let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+ let (ck, key) = KDF2!(&ck, eph_sk.diffie_hellman(&pk).as_bytes());
// msg.static := Aead(k, 0, S_pub, H)
@@ -297,12 +298,12 @@ pub fn create_initiation<R: RngCore + CryptoRng>(
})
}
-pub fn consume_initiation<'a>(
- device: &'a Device,
+pub(super) fn consume_initiation<'a, O>(
+ device: &'a Device<O>,
keyst: &KeyState,
msg: &NoiseInitiation,
-) -> Result<(&'a Peer, TemporaryState), HandshakeError> {
- debug!("consume initation");
+) -> Result<(&'a Peer<O>, PublicKey, TemporaryState), HandshakeError> {
+ debug!("consume initiation");
clear_stack_on_return(CLEAR_PAGES, || {
// initialize new state
@@ -369,13 +370,18 @@ pub fn consume_initiation<'a>(
// return state (to create response)
- Ok((peer, (msg.f_sender.get(), eph_r_pk, hs, ck)))
+ Ok((
+ peer,
+ PublicKey::from(pk),
+ (msg.f_sender.get(), eph_r_pk, hs, ck),
+ ))
})
}
-pub fn create_response<R: RngCore + CryptoRng>(
+pub(super) fn create_response<R: RngCore + CryptoRng, O>(
rng: &mut R,
- peer: &Peer,
+ peer: &Peer<O>,
+ pk: &PublicKey,
local: u32, // sending identifier
state: TemporaryState, // state from "consume_initiation"
msg: &mut NoiseResponse, // resulting response
@@ -388,7 +394,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
msg.f_type.set(TYPE_RESPONSE as u32);
msg.f_sender.set(local); // from us
- msg.f_receiver.set(receiver); // to the sender of the initation
+ msg.f_receiver.set(receiver); // to the sender of the initiation
// (E_priv, E_pub) := DH-Generate()
@@ -413,7 +419,7 @@ pub fn create_response<R: RngCore + CryptoRng>(
// C := Kdf1(C, DH(E_priv, S_pub))
- let ck = KDF1!(&ck, eph_sk.diffie_hellman(&peer.pk).as_bytes());
+ let ck = KDF1!(&ck, eph_sk.diffie_hellman(&pk).as_bytes());
// (C, tau, k) := Kdf3(C, Q)
@@ -460,15 +466,15 @@ pub fn create_response<R: RngCore + CryptoRng>(
* allow concurrent processing of potential responses to the initiation,
* in order to better mitigate DoS from malformed response messages.
*/
-pub fn consume_response(
- device: &Device,
+pub(super) fn consume_response<'a, O>(
+ device: &'a Device<O>,
keyst: &KeyState,
msg: &NoiseResponse,
-) -> Result<Output, HandshakeError> {
+) -> Result<Output<'a, O>, HandshakeError> {
debug!("consume response");
clear_stack_on_return(CLEAR_PAGES, || {
// retrieve peer and copy initiation state
- let peer = device.lookup_id(msg.f_receiver.get())?;
+ let (peer, _) = device.lookup_id(msg.f_receiver.get())?;
let (hs, ck, local, eph_sk) = match *peer.state.lock() {
State::InitiationSent {
@@ -537,7 +543,7 @@ pub fn consume_response(
// return confirmed key-pair
Ok((
- Some(peer.pk),
+ Some(&peer.opaque),
None,
Some(KeyPair {
birth,
diff --git a/src/wireguard/handshake/peer.rs b/src/wireguard/handshake/peer.rs
index a4df560..f4d15fc 100644
--- a/src/wireguard/handshake/peer.rs
+++ b/src/wireguard/handshake/peer.rs
@@ -22,19 +22,21 @@ const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20);
*
* This type is only for internal use and not exposed.
*/
-pub struct Peer {
+pub(super) struct Peer<O> {
+ // opaque type which identifies a peer
+ pub opaque: O,
+
// mutable state
- pub(crate) state: Mutex<State>,
- pub(crate) timestamp: Mutex<Option<timestamp::TAI64N>>,
- pub(crate) last_initiation_consumption: Mutex<Option<Instant>>,
+ pub state: Mutex<State>,
+ pub timestamp: Mutex<Option<timestamp::TAI64N>>,
+ pub last_initiation_consumption: Mutex<Option<Instant>>,
// state related to DoS mitigation fields
- pub(crate) macs: Mutex<macs::Generator>,
+ pub macs: Mutex<macs::Generator>,
// constant state
- pub(crate) pk: PublicKey, // public key of peer
- pub(crate) ss: [u8; 32], // precomputed DH(static, static)
- pub(crate) psk: Psk, // psk of peer
+ pub ss: [u8; 32], // precomputed DH(static, static)
+ pub psk: Psk, // psk of peer
}
pub enum State {
@@ -60,14 +62,14 @@ impl Drop for State {
}
}
-impl Peer {
- pub fn new(pk: PublicKey, ss: [u8; 32]) -> Self {
+impl<O> Peer<O> {
+ pub fn new(pk: PublicKey, ss: [u8; 32], opaque: O) -> Self {
Self {
+ opaque,
macs: Mutex::new(macs::Generator::new(pk)),
state: Mutex::new(State::Reset),
timestamp: Mutex::new(None),
last_initiation_consumption: Mutex::new(None),
- pk,
ss,
psk: [0u8; 32],
}
@@ -88,7 +90,7 @@ impl Peer {
/// * ts_new - The associated timestamp
pub fn check_replay_flood(
&self,
- device: &Device,
+ device: &Device<O>,
timestamp_new: &timestamp::TAI64N,
) -> Result<(), HandshakeError> {
let mut state = self.state.lock();
diff --git a/src/wireguard/handshake/tests.rs b/src/wireguard/handshake/tests.rs
index ff27b3e..bfdc5ab 100644
--- a/src/wireguard/handshake/tests.rs
+++ b/src/wireguard/handshake/tests.rs
@@ -12,8 +12,10 @@ use x25519_dalek::StaticSecret;
use super::messages::{Initiation, Response};
-fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, PublicKey, Device) {
- // generate new keypairs
+fn setup_devices<R: RngCore + CryptoRng, O: Default>(
+ rng: &mut R,
+) -> (PublicKey, Device<O>, PublicKey, Device<O>) {
+ // generate new key pairs
let sk1 = StaticSecret::new(rng);
let pk1 = PublicKey::from(&sk1);
@@ -26,7 +28,7 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub
let mut psk = [0u8; 32];
rng.fill_bytes(&mut psk[..]);
- // intialize devices on both ends
+ // initialize devices on both ends
let mut dev1 = Device::new();
let mut dev2 = Device::new();
@@ -34,8 +36,8 @@ fn setup_devices<R: RngCore + CryptoRng>(rng: &mut R) -> (PublicKey, Device, Pub
dev1.set_sk(Some(sk1));
dev2.set_sk(Some(sk2));
- dev1.add(pk2).unwrap();
- dev2.add(pk1).unwrap();
+ dev1.add(pk2, O::default()).unwrap();
+ dev2.add(pk1, O::default()).unwrap();
dev1.set_psk(pk2, psk).unwrap();
dev2.set_psk(pk1, psk).unwrap();
@@ -49,45 +51,44 @@ fn wait() {
/* Test longest possible handshake interaction (7 messages):
*
- * 1. I -> R (initation)
+ * 1. I -> R (initiation)
* 2. I <- R (cookie reply)
- * 3. I -> R (initation)
+ * 3. I -> R (initiation)
* 4. I <- R (response)
* 5. I -> R (cookie reply)
- * 6. I -> R (initation)
+ * 6. I -> R (initiation)
* 7. I <- R (response)
*/
#[test]
fn handshake_under_load() {
- let mut rng = OsRng::new().unwrap();
- let (_pk1, dev1, pk2, dev2) = setup_devices(&mut rng);
+ let (_pk1, dev1, pk2, dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng);
let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap();
let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap();
- // 1. device-1 : create first initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+ // 1. device-1 : create first initiation
+ let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap();
// 2. device-2 : responds with CookieReply
- let msg_cookie = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
+ let msg_cookie = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-1 : processes CookieReply (no response)
- match dev1.process(&mut rng, &msg_cookie, Some(src2)).unwrap() {
+ match dev1.process(&mut OsRng, &msg_cookie, Some(src2)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
- // avoid initation flood detection
+ // avoid initiation flood detection
wait();
- // 3. device-1 : create second initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+ // 3. device-1 : create second initiation
+ let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap();
// 4. device-2 : responds with noise response
- let msg_response = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
+ let msg_response = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
msg
@@ -96,25 +97,25 @@ fn handshake_under_load() {
};
// 5. device-1 : responds with CookieReply
- let msg_cookie = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
+ let msg_cookie = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() {
(None, Some(msg), None) => msg,
_ => panic!("unexpected response"),
};
// device-2 : processes CookieReply (no response)
- match dev2.process(&mut rng, &msg_cookie, Some(src1)).unwrap() {
+ match dev2.process(&mut OsRng, &msg_cookie, Some(src1)).unwrap() {
(None, None, None) => (),
_ => panic!("unexpected response"),
}
- // avoid initation flood detection
+ // avoid initiation flood detection
wait();
- // 6. device-1 : create third initation
- let msg_init = dev1.begin(&mut rng, &pk2).unwrap();
+ // 6. device-1 : create third initiation
+ let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap();
// 7. device-2 : responds with noise response
- let (msg_response, kp1) = match dev2.process(&mut rng, &msg_init, Some(src1)).unwrap() {
+ let (msg_response, kp1) = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() {
(Some(_), Some(msg), Some(kp)) => {
assert_eq!(kp.initiator, false);
(msg, kp)
@@ -123,7 +124,7 @@ fn handshake_under_load() {
};
// device-1 : process noise response
- let kp2 = match dev1.process(&mut rng, &msg_response, Some(src2)).unwrap() {
+ let kp2 = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() {
(Some(_), None, Some(kp)) => {
assert_eq!(kp.initiator, true);
kp
@@ -137,8 +138,7 @@ fn handshake_under_load() {
#[test]
fn handshake_no_load() {
- let mut rng = OsRng::new().unwrap();
- let (pk1, mut dev1, pk2, mut dev2) = setup_devices(&mut rng);
+ let (pk1, mut dev1, pk2, mut dev2): (_, Device<usize>, _, _) = setup_devices(&mut OsRng);
// do a few handshakes (every handshake should succeed)
@@ -147,7 +147,7 @@ fn handshake_no_load() {
// create initiation
- let msg1 = dev1.begin(&mut rng, &pk2).unwrap();
+ let msg1 = dev1.begin(&mut OsRng, &pk2).unwrap();
println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len());
println!(
@@ -158,7 +158,7 @@ fn handshake_no_load() {
// process initiation and create response
let (_, msg2, ks_r) = dev2
- .process(&mut rng, &msg1, None)
+ .process(&mut OsRng, &msg1, None)
.expect("failed to process initiation");
let ks_r = ks_r.unwrap();
@@ -175,7 +175,7 @@ fn handshake_no_load() {
// process response and obtain confirmed key-pair
let (_, msg3, ks_i) = dev1
- .process(&mut rng, &msg2, None)
+ .process(&mut OsRng, &msg2, None)
.expect("failed to process response");
let ks_i = ks_i.unwrap();
@@ -188,7 +188,7 @@ fn handshake_no_load() {
dev1.release(ks_i.local_id());
dev2.release(ks_r.local_id());
- // avoid initation flood detection
+ // avoid initiation flood detection
wait();
}
diff --git a/src/wireguard/handshake/types.rs b/src/wireguard/handshake/types.rs
index 5f984cc..ed2fcbb 100644
--- a/src/wireguard/handshake/types.rs
+++ b/src/wireguard/handshake/types.rs
@@ -1,10 +1,8 @@
+use super::super::types::KeyPair;
+
use std::error::Error;
use std::fmt;
-use x25519_dalek::PublicKey;
-
-use super::super::types::KeyPair;
-
/* Internal types for the noise IKpsk2 implementation */
// config error
@@ -79,10 +77,10 @@ impl Error for HandshakeError {
}
}
-pub type Output = (
- Option<PublicKey>, // external identifier associated with peer
- Option<Vec<u8>>, // message to send
- Option<KeyPair>, // resulting key-pair of successful handshake
+pub type Output<'a, O> = (
+ Option<&'a O>, // external identifier associated with peer
+ Option<Vec<u8>>, // message to send
+ Option<KeyPair>, // resulting key-pair of successful handshake
);
// preshared key
diff --git a/src/wireguard/peer.rs b/src/wireguard/peer.rs
index 1af4df3..b3656fe 100644
--- a/src/wireguard/peer.rs
+++ b/src/wireguard/peer.rs
@@ -31,7 +31,7 @@ pub struct PeerInner<T: Tun, B: UDP> {
pub handshake_queued: AtomicBool, // is a handshake job currently queued for the peer?
// stats and configuration
- pub pk: PublicKey, // public key, DISCUSS: avoid this. TODO: remove
+ pub pk: PublicKey, // public key
pub rx_bytes: AtomicU64, // received bytes
pub tx_bytes: AtomicU64, // transmitted bytes
diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs
index f903a8e..6c59491 100644
--- a/src/wireguard/router/device.rs
+++ b/src/wireguard/router/device.rs
@@ -142,7 +142,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> DeviceHandle<
};
// start worker threads
- let mut threads = Vec::with_capacity(num_workers);
+ let mut threads = Vec::with_capacity(4 * num_workers);
// inbound/decryption workers
for _ in 0..num_workers {
diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs
index b8110f0..8fe2e1c 100644
--- a/src/wireguard/router/peer.rs
+++ b/src/wireguard/router/peer.rs
@@ -204,7 +204,7 @@ impl<E: Endpoint, C: Callbacks, T: tun::Writer, B: udp::Writer<E>> PeerInner<E,
debug!("peer.send");
// send to endpoint (if known)
- match self.endpoint.lock().as_ref() {
+ match self.endpoint.lock().as_mut() {
Some(endpoint) => {
let outbound = self.device.outbound.read();
if outbound.0 {
diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs
index bf550ef..ecbb9c1 100644
--- a/src/wireguard/wireguard.rs
+++ b/src/wireguard/wireguard.rs
@@ -21,9 +21,6 @@ use std::sync::Mutex as StdMutex;
use std::thread;
use std::time::Instant;
-use std::collections::hash_map::Entry;
-use std::collections::HashMap;
-
use hjul::Runner;
use rand::rngs::OsRng;
use rand::Rng;
@@ -50,14 +47,13 @@ pub struct WireguardInner<T: Tun, B: UDP> {
// outbound writer
pub send: RwLock<Option<B::Writer>>,
- // identity and configuration map
- pub peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
+ // peer map
+ pub peers: RwLock<handshake::Device<Peer<T, B>>>,
// cryptokey router
pub router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state
- pub handshake: RwLock<handshake::Device>,
pub last_under_load: Mutex<Instant>,
pub pending: AtomicUsize, // number of pending handshake packets in queue
pub queue: ParallelQueue<HandshakeJob<B::Endpoint>>,
@@ -142,7 +138,7 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
self.router.down();
// set all peers down (stops timers)
- for peer in self.peers.write().values() {
+ for (_, peer) in self.peers.write().iter() {
peer.down();
}
@@ -163,11 +159,11 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
return;
}
- // enable tranmission from router
+ // enable transmission from router
self.router.up();
// set all peers up (restarts timers)
- for peer in self.peers.write().values() {
+ for (_, peer) in self.peers.write().iter() {
peer.up();
}
@@ -179,54 +175,51 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
}
pub fn remove_peer(&self, pk: &PublicKey) {
- if self.handshake.write().remove(pk).is_ok() {
- self.peers.write().remove(pk.as_bytes());
- }
+ let _ = self.peers.write().remove(pk);
}
pub fn lookup_peer(&self, pk: &PublicKey) -> Option<Peer<T, B>> {
- self.peers.read().get(pk.as_bytes()).map(|p| p.clone())
+ self.peers.read().get(pk).map(|p| p.clone())
}
pub fn list_peers(&self) -> Vec<Peer<T, B>> {
let peers = self.peers.read();
let mut list = Vec::with_capacity(peers.len());
for (k, v) in peers.iter() {
- debug_assert!(k == v.pk.as_bytes());
+ debug_assert!(k.as_bytes() == v.pk.as_bytes());
list.push(v.clone());
}
list
}
pub fn set_key(&self, sk: Option<StaticSecret>) {
- let mut handshake = self.handshake.write();
- handshake.set_sk(sk);
+ let mut peers = self.peers.write();
+ peers.set_sk(sk);
self.router.clear_sending_keys();
- // handshake lock is released and new handshakes can be initated
}
pub fn get_sk(&self) -> Option<StaticSecret> {
- self.handshake
+ self.peers
.read()
.get_sk()
.map(|sk| StaticSecret::from(sk.to_bytes()))
}
pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool {
- self.handshake.write().set_psk(pk, psk).is_ok()
+ self.peers.write().set_psk(pk, psk).is_ok()
}
pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> {
- self.handshake.read().get_psk(pk).ok()
+ self.peers.read().get_psk(pk).ok()
}
pub fn add_peer(&self, pk: PublicKey) -> bool {
- if self.peers.read().contains_key(pk.as_bytes()) {
+ let mut peers = self.peers.write();
+ if peers.contains_key(&pk) {
return false;
}
- let mut rng = OsRng::new().unwrap();
let state = Arc::new(PeerInner {
- id: rng.gen(),
+ id: OsRng.gen(),
pk,
wg: self.clone(),
walltime_last_handshake: Mutex::new(None),
@@ -243,33 +236,19 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// form WireGuard peer
let peer = Peer { router, state };
+ // prevent up/down while inserting
+ let enabled = self.enabled.read();
+
+ /* The need for dummy timers arises from the chicken-egg
+ * problem of the timer callbacks being able to set timers themselves.
+ *
+ * This is in fact the only place where the write lock is ever taken.
+ * TODO: Consider the ease of using atomic pointers instead.
+ */
+ *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone());
+
// finally, add the peer to the wireguard device
- let mut peers = self.peers.write();
- match peers.entry(*pk.as_bytes()) {
- Entry::Occupied(_) => false,
- Entry::Vacant(vacancy) => {
- // check that the public key does not cause conflict with the private key of the device
- let ok_pk = self.handshake.write().add(pk).is_ok();
- if !ok_pk {
- return false;
- }
-
- // prevent up/down while inserting
- let enabled = self.enabled.read();
-
- /* The need for dummy timers arises from the chicken-egg
- * problem of the timer callbacks being able to set timers themselves.
- *
- * This is in fact the only place where the write lock is ever taken.
- * TODO: Consider the ease of using atomic pointers instead.
- */
- *peer.timers.write() = Timers::new(&*self.runner.lock(), *enabled, peer.clone());
-
- // insert into peer map (takes ownership and ensures that the peer is not dropped)
- vacancy.insert(peer);
- true
- }
- }
+ peers.add(pk, peer).is_ok()
}
/// Begin consuming messages from the reader.
@@ -311,9 +290,6 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
// workers equal to number of physical cores
let cpus = num_cpus::get();
- // create device state
- let mut rng = OsRng::new().unwrap();
-
// create handshake queue
let (tx, mut rxs) = ParallelQueue::new(cpus, 128);
@@ -322,14 +298,13 @@ impl<T: Tun, B: UDP> WireGuard<T, B> {
inner: Arc::new(WireguardInner {
enabled: RwLock::new(false),
tun_readers: WaitCounter::new(),
- id: rng.gen(),
+ id: OsRng.gen(),
mtu: AtomicUsize::new(0),
- peers: RwLock::new(HashMap::new()),
last_under_load: Mutex::new(Instant::now() - TIME_HORIZON),
send: RwLock::new(None),
router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
- handshake: RwLock::new(handshake::Device::new()),
+ peers: RwLock::new(handshake::Device::new()),
runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)),
queue: tx,
}),
diff --git a/src/wireguard/workers.rs b/src/wireguard/workers.rs
index e1d3899..c1a2af7 100644
--- a/src/wireguard/workers.rs
+++ b/src/wireguard/workers.rs
@@ -152,9 +152,6 @@ pub fn handshake_worker<T: Tun, B: UDP>(
) {
debug!("{} : handshake worker, started", wg);
- // prepare OsRng instance for this thread
- let mut rng = OsRng::new().expect("Unable to obtain a CSPRNG");
-
// process elements from the handshake queue
for job in rx {
// check if under load
@@ -181,11 +178,11 @@ pub fn handshake_worker<T: Tun, B: UDP>(
// de-multiplex staged handshake jobs and handshake messages
match job {
- HandshakeJob::Message(msg, src) => {
+ HandshakeJob::Message(msg, mut src) => {
// process message
- let device = wg.handshake.read();
+ let device = wg.peers.read();
match device.process(
- &mut rng,
+ &mut OsRng,
&msg[..],
if under_load {
Some(src.into_address())
@@ -193,7 +190,7 @@ pub fn handshake_worker<T: Tun, B: UDP>(
None
},
) {
- Ok((pk, resp, keypair)) => {
+ Ok((peer, resp, keypair)) => {
// send response (might be cookie reply or handshake response)
let mut resp_len: u64 = 0;
if let Some(msg) = resp {
@@ -204,7 +201,7 @@ pub fn handshake_worker<T: Tun, B: UDP>(
"{} : handshake worker, send response ({} bytes)",
wg, resp_len
);
- let _ = writer.write(&msg[..], &src).map_err(|e| {
+ let _ = writer.write(&msg[..], &mut src).map_err(|e| {
debug!(
"{} : handshake worker, failed to send response, error = {}",
wg,
@@ -215,56 +212,55 @@ pub fn handshake_worker<T: Tun, B: UDP>(
}
// update peer state
- if let Some(pk) = pk {
+ if let Some(peer) = peer {
// authenticated handshake packet received
- if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- // add to rx_bytes and tx_bytes
- let req_len = msg.len() as u64;
- peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
- peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
- // update endpoint
- peer.router.set_endpoint(src);
+ // add to rx_bytes and tx_bytes
+ let req_len = msg.len() as u64;
+ peer.rx_bytes.fetch_add(req_len, Ordering::Relaxed);
+ peer.tx_bytes.fetch_add(resp_len, Ordering::Relaxed);
- if resp_len > 0 {
- // update timers after sending handshake response
- debug!("{} : handshake worker, handshake response sent", wg);
- peer.state.sent_handshake_response();
- } else {
- // update timers after receiving handshake response
- debug!(
- "{} : handshake worker, handshake response was received",
- wg
- );
- peer.state.timers_handshake_complete();
- }
+ // update endpoint
+ peer.router.set_endpoint(src);
+
+ if resp_len > 0 {
+ // update timers after sending handshake response
+ debug!("{} : handshake worker, handshake response sent", wg);
+ peer.state.sent_handshake_response();
+ } else {
+ // update timers after receiving handshake response
+ debug!(
+ "{} : handshake worker, handshake response was received",
+ wg
+ );
+ peer.state.timers_handshake_complete();
+ }
- // add any new keypair to peer
- keypair.map(|kp| {
- debug!("{} : handshake worker, new keypair for {}", wg, peer);
+ // add any new keypair to peer
+ keypair.map(|kp| {
+ debug!("{} : handshake worker, new keypair for {}", wg, peer);
- // this means that a handshake response was processed or sent
- peer.timers_session_derived();
+ // this means that a handshake response was processed or sent
+ peer.timers_session_derived();
- // free any unused ids
- for id in peer.router.add_keypair(kp) {
- device.release(id);
- }
- });
- }
+ // free any unused ids
+ for id in peer.router.add_keypair(kp) {
+ device.release(id);
+ }
+ });
}
}
Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e),
}
}
HandshakeJob::New(pk) => {
- if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
+ if let Some(peer) = wg.peers.read().get(&pk) {
debug!(
"{} : handshake worker, new handshake requested for {}",
wg, peer
);
- let device = wg.handshake.read();
- let _ = device.begin(&mut rng, &peer.pk).map(|msg| {
+ let device = wg.peers.read();
+ let _ = device.begin(&mut OsRng, &peer.pk).map(|msg| {
let _ = peer.router.send(&msg[..]).map_err(|e| {
debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e)
});