summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-10 21:42:21 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-10 21:42:21 +0200
commit02d9bdcc96c955b654a45d3629b1ee515902078d (patch)
treed0231989ccca424d26f0dcded640acac079aa9de
parentBegin work on full router interaction unittest (diff)
downloadwireguard-rs-02d9bdcc96c955b654a45d3629b1ee515902078d.tar.xz
wireguard-rs-02d9bdcc96c955b654a45d3629b1ee515902078d.zip
Full inbound/outbound router test
-rw-r--r--Cargo.lock28
-rw-r--r--Cargo.toml1
-rw-r--r--src/handshake/messages.rs1
-rw-r--r--src/main.rs129
-rw-r--r--src/router/device.rs45
-rw-r--r--src/router/messages.rs4
-rw-r--r--src/router/mod.rs1
-rw-r--r--src/router/peer.rs129
-rw-r--r--src/router/tests.rs221
-rw-r--r--src/router/types.rs8
-rw-r--r--src/router/workers.rs72
11 files changed, 392 insertions, 247 deletions
diff --git a/Cargo.lock b/Cargo.lock
index cd73aa4..8d04e33 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -288,6 +288,11 @@ version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
+name = "fs_extra"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+
+[[package]]
name = "fuchsia-cprng"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -408,6 +413,25 @@ dependencies = [
]
[[package]]
+name = "jemalloc-sys"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "cc 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)",
+ "fs_extra 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
+ "libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
+name = "jemallocator"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+dependencies = [
+ "jemalloc-sys 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
+ "libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)",
+]
+
+[[package]]
name = "js-sys"
version = "0.3.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1567,6 +1591,7 @@ dependencies = [
"hex 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"hjul 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
+ "jemallocator 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"num_cpus 1.10.1 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -1663,6 +1688,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum failure 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "795bd83d3abeb9220f257e597aa0080a508b27533824adf336529648f6abf7e2"
"checksum failure_derive 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)" = "ea1063915fd7ef4309e222a5a07cf9c319fb9c7836b1f89b85458672dbb127e1"
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
+"checksum fs_extra 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5f2a4a2034423744d2cc7ca2068453168dcdb82c438419e639a26bd87839c674"
"checksum fuchsia-cprng 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
@@ -1679,6 +1705,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum humantime 1.2.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3ca7e5f2e110db35f93b837c81797f3714500b81d517bf20c431b16d3ca4f114"
"checksum iovec 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "dbe6e417e7d0975db6512b90796e8ce223145ac4e33c377e4a42882a0e88bb08"
"checksum ipnetwork 0.14.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b3d862c86f7867f19b693ec86765e0252d82e53d4240b9b629815675a0714ad1"
+"checksum jemalloc-sys 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0d3b9f3f5c9b31aa0f5ed3260385ac205db665baa41d49bb8338008ae94ede45"
+"checksum jemallocator 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "43ae63fcfc45e99ab3d1b29a46782ad679e98436c3169d15a167a1108a724b69"
"checksum js-sys 0.3.27 (registry+https://github.com/rust-lang/crates.io-index)" = "1efc4f2a556c58e79c5500912e221dd826bec64ff4aabd8ce71ccef6da02d7d4"
"checksum kernel32-sys 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7507624b29483431c0ba2d82aece8ca6cdba9382bff4ddd0f7490560c056098d"
"checksum lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
diff --git a/Cargo.toml b/Cargo.toml
index 7ea8b67..caa6d20 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,6 +31,7 @@ parking_lot = "^0.9"
futures-channel = "^0.2"
env_logger = "0.6"
num_cpus = "^1.10"
+jemallocator = "0.3.0"
[dependencies.x25519-dalek]
version = "^0.5"
diff --git a/src/handshake/messages.rs b/src/handshake/messages.rs
index 6dca413..07c2b1a 100644
--- a/src/handshake/messages.rs
+++ b/src/handshake/messages.rs
@@ -8,7 +8,6 @@ use byteorder::LittleEndian;
use zerocopy::byteorder::U32;
use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
-use super::timestamp;
use super::types::*;
const SIZE_MAC: usize = 16;
diff --git a/src/main.rs b/src/main.rs
index 8d92048..53b2a51 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,130 +1,13 @@
#![feature(test)]
+extern crate jemallocator;
+
+#[global_allocator]
+static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
+
mod constants;
mod handshake;
mod router;
mod types;
-use hjul::*;
-
-use std::error::Error;
-use std::fmt;
-use std::net::SocketAddr;
-use std::time::Duration;
-
-use types::{Bind, Tun};
-
-#[derive(Debug)]
-enum TunError {}
-
-impl Error for TunError {
- fn description(&self) -> &str {
- "Generic Tun Error"
- }
-
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- None
- }
-}
-
-impl fmt::Display for TunError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "Not Possible")
- }
-}
-
-struct TunTest {}
-
-impl Tun for TunTest {
- type Error = TunError;
-
- fn mtu(&self) -> usize {
- 1500
- }
-
- fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
- Ok(0)
- }
-
- fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
- Ok(())
- }
-}
-
-struct BindTest {}
-
-impl Bind for BindTest {
- type Error = BindError;
- type Endpoint = SocketAddr;
-
- fn new() -> BindTest {
- BindTest {}
- }
-
- fn set_port(&self, port: u16) -> Result<(), Self::Error> {
- Ok(())
- }
-
- fn get_port(&self) -> Option<u16> {
- None
- }
-
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
- Ok((0, "127.0.0.1:8080".parse().unwrap()))
- }
-
- fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
- Ok(())
- }
-}
-
-#[derive(Debug)]
-enum BindError {}
-
-impl Error for BindError {
- fn description(&self) -> &str {
- "Generic Bind Error"
- }
-
- fn source(&self) -> Option<&(dyn Error + 'static)> {
- None
- }
-}
-
-impl fmt::Display for BindError {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "Not Possible")
- }
-}
-
-#[derive(Debug, Clone)]
-struct PeerTimer {
- a: Timer,
- b: Timer,
-}
-
-fn main() {
- let runner = Runner::new(Duration::from_millis(100), 1000, 1024);
-
- {
- let router = router::Device::new(
- 4,
- TunTest {},
- BindTest {},
- |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)),
- |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)),
- |t: &PeerTimer| println!("new key requested"),
- );
-
- let pt = PeerTimer {
- a: runner.timer(|| println!("timer-a fired for peer")),
- b: runner.timer(|| println!("timer-b fired for peer")),
- };
-
- let peer = router.new_peer(pt.clone());
-
- println!("{:?}", pt);
- }
-
- println!("joined");
-}
+fn main() {}
diff --git a/src/router/device.rs b/src/router/device.rs
index 73678cb..703fa55 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -16,8 +16,7 @@ use super::anti_replay::AntiReplay;
use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
-use super::peer;
-use super::peer::{Peer, PeerInner};
+use super::peer::{new_peer, Peer, PeerInner};
use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
@@ -36,6 +35,10 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
+
+ // work queues
+ pub queue_next: AtomicUsize, // next round-robin index
+ pub queues: Mutex<Vec<SyncSender<JobParallel>>>, // work queues (1 per thread)
}
pub struct EncryptionState {
@@ -56,14 +59,15 @@ pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
state: Arc<DeviceInner<C, T, B>>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
- queue_next: AtomicUsize, // next round-robin index
- queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
}
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
fn drop(&mut self) {
// drop all queues
- while self.queues.pop().is_some() {}
+ {
+ let mut queues = self.state.queues.lock();
+ while queues.pop().is_some() {}
+ }
// join all worker threads
while match self.handles.pop() {
@@ -91,32 +95,31 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
call_need_key: K,
) -> Device<PhantomCallbacks<O, R, S, K>, T, B> {
// allocate shared device state
- let inner = Arc::new(DeviceInner {
+ let mut inner = DeviceInner {
tun,
bind,
call_recv,
call_send,
+ queues: Mutex::new(Vec::with_capacity(num_workers)),
+ queue_next: AtomicUsize::new(0),
call_need_key,
recv: RwLock::new(HashMap::new()),
ipv4: RwLock::new(IpLookupTable::new()),
ipv6: RwLock::new(IpLookupTable::new()),
- });
+ };
// start worker threads
- let mut queues = Vec::with_capacity(num_workers);
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- queues.push(Mutex::new(tx));
+ inner.queues.lock().push(tx);
threads.push(thread::spawn(move || worker_parallel(rx)));
}
// return exported device handle
Device {
- state: inner,
+ state: Arc::new(inner),
handles: threads,
- queue_next: AtomicUsize::new(0),
- queues: queues,
}
}
}
@@ -168,7 +171,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
///
/// A atomic ref. counted peer (with liftime matching the device)
pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> {
- peer::new_peer(self.state.clone(), opaque)
+ new_peer(self.state.clone(), opaque)
}
/// Cryptkey routes and sends a plaintext message (IP packet)
@@ -189,11 +192,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
debug_assert_eq!(job.1.op, Operation::Encryption);
// add job to worker queue
- let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
- self.queues[idx % self.queues.len()]
- .lock()
- .send(job)
- .unwrap();
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
}
Ok(())
@@ -234,11 +235,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
debug_assert_eq!(job.1.op, Operation::Decryption);
// add job to worker queue
- let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
- self.queues[idx % self.queues.len()]
- .lock()
- .send(job)
- .unwrap();
+ let idx = self.state.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = self.state.queues.lock();
+ queues[idx % queues.len()].send(job).unwrap();
}
Ok(())
diff --git a/src/router/messages.rs b/src/router/messages.rs
index e7b592b..bf4d13b 100644
--- a/src/router/messages.rs
+++ b/src/router/messages.rs
@@ -1,8 +1,8 @@
use byteorder::LittleEndian;
use zerocopy::byteorder::{U32, U64};
-use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
+use zerocopy::{AsBytes, FromBytes};
-pub const TYPE_TRANSPORT: u8 = 4;
+pub const TYPE_TRANSPORT: u32 = 4;
#[repr(packed)]
#[derive(Copy, Clone, FromBytes, AsBytes)]
diff --git a/src/router/mod.rs b/src/router/mod.rs
index 883c875..8cd0d3b 100644
--- a/src/router/mod.rs
+++ b/src/router/mod.rs
@@ -14,6 +14,5 @@ use messages::TransportHeader;
use std::mem;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
-
pub use device::Device;
pub use peer::Peer;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 0cd588d..9ad5d2f 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -1,19 +1,17 @@
use std::mem;
use std::net::{IpAddr, SocketAddr};
-use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
+use std::sync::atomic::AtomicBool;
+use std::sync::atomic::Ordering;
use std::sync::mpsc::{sync_channel, SyncSender};
-use std::sync::{Arc, Weak};
+use std::sync::Arc;
use std::thread;
+use arraydeque::{ArrayDeque, Wrapping};
use log::debug;
-
use spin::Mutex;
-
-use arraydeque::{ArrayDeque, Saturating, Wrapping};
-use zerocopy::{AsBytes, LayoutVerified};
-
use treebitmap::address::Address;
use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
use super::super::constants::*;
use super::super::types::{Bind, KeyPair, Tun};
@@ -29,9 +27,10 @@ use futures::*;
use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
+use super::SIZE_MESSAGE_PREFIX;
use super::constants::*;
-use super::types::Callbacks;
+use super::types::{Callbacks, RouterError};
pub struct KeyWheel {
next: Option<Arc<KeyPair>>, // next key state (unconfirmed)
@@ -45,11 +44,9 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>,
pub inbound: Mutex<SyncSender<JobInbound<C, T, B>>>,
- pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
- pub rx_bytes: AtomicU64, // received bytes
- pub tx_bytes: AtomicU64, // transmitted bytes
- pub keys: Mutex<KeyWheel>, // key-wheel
- pub ekey: Mutex<Option<EncryptionState>>, // encryption state
+ pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
+ pub keys: Mutex<KeyWheel>,
+ pub ekey: Mutex<Option<EncryptionState>>,
pub endpoint: Mutex<Option<B::Endpoint>>,
}
@@ -193,8 +190,6 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
previous: None,
retired: None,
}),
- rx_bytes: AtomicU64::new(0),
- tx_bytes: AtomicU64::new(0),
staged_packets: spin::Mutex::new(ArrayDeque::new()),
})
};
@@ -254,7 +249,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
mut msg: Vec<u8>,
) -> Option<JobParallel> {
let (tx, rx) = oneshot();
- let key = dec.keypair.send.key;
+ let key = dec.keypair.recv.key;
match self.inbound.lock().try_send((dec, src, rx)) {
Ok(_) => Some((
tx,
@@ -270,7 +265,11 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
}
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
- debug_assert!(msg.len() >= mem::size_of::<TransportHeader>());
+ debug_assert!(
+ msg.len() >= mem::size_of::<TransportHeader>(),
+ "received message with size: {:}",
+ msg.len()
+ );
// parse / cast
let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap();
@@ -318,6 +317,16 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
}
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
+ /// Set the endpoint of the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `endpoint`, socket address converted to bind endpoint
+ ///
+ /// # Note
+ ///
+ /// This API still permits support for the "sticky socket" behavior,
+ /// as sockets should be "unsticked" when manually updating the endpoint
pub fn set_endpoint(&self, endpoint: SocketAddr) {
*self.state.endpoint.lock() = Some(endpoint.into());
}
@@ -372,18 +381,67 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
);
}
+ // schedule confirmation
+ if new.initiator {
+ // attempt to confirm with staged packets
+ let mut staged = self.state.staged_packets.lock();
+ let keepalive = staged.len() == 0;
+ loop {
+ match staged.pop_front() {
+ Some(msg) => {
+ debug!("send staged packet to confirm key-pair");
+ self.send_raw(msg);
+ }
+ None => break,
+ }
+ }
+
+ // fall back to keepalive packet
+ if keepalive {
+ let ok = self.keepalive();
+ debug!("keepalive for confirmation, sent = {}", ok);
+ }
+ }
+
// return the released id (for handshake state machine)
release
}
- pub fn rx_bytes(&self) -> u64 {
- self.state.rx_bytes.load(Ordering::Relaxed)
+ fn send_raw(&self, msg: Vec<u8>) -> bool {
+ match self.state.send_job(msg) {
+ Some(job) => {
+ debug!("send_raw: got obtained send_job");
+ let device = &self.state.device;
+ let index = device.queue_next.fetch_add(1, Ordering::SeqCst);
+ let queues = device.queues.lock();
+ match queues[index % queues.len()].send(job) {
+ Ok(_) => true,
+ Err(_) => false,
+ }
+ }
+ None => false,
+ }
}
- pub fn tx_bytes(&self) -> u64 {
- self.state.tx_bytes.load(Ordering::Relaxed)
+ pub fn keepalive(&self) -> bool {
+ debug!("send keepalive");
+ self.send_raw(vec![0u8; SIZE_MESSAGE_PREFIX])
}
+ /// Map a subnet to the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `ip`, the mask of the subnet
+ /// - `masklen`, the length of the mask
+ ///
+ /// # Note
+ ///
+ /// The `ip` must not have any bits set right of `masklen`.
+ /// e.g. `192.168.1.0/24` is valid, while `192.168.1.128/24` is not.
+ ///
+ /// If an identical value already exists as part of a prior peer,
+ /// the allowed IP entry will be removed from that peer and added to this peer.
pub fn add_subnet(&self, ip: IpAddr, masklen: u32) {
match ip {
IpAddr::V4(v4) => {
@@ -403,6 +461,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
};
}
+ /// List subnets mapped to the peer
+ ///
+ /// # Returns
+ ///
+ /// A vector of subnets, represented by as mask/size
pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> {
let mut res = Vec::new();
res.append(&mut treebit_list(
@@ -418,10 +481,32 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
res
}
+ /// Clear subnets mapped to the peer.
+ /// After the call, no subnets will be cryptkey routed to the peer.
+ /// Used for the UAPI command "replace_allowed_ips=true"
pub fn remove_subnets(&self) {
treebit_remove(self, &self.state.device.ipv4);
treebit_remove(self, &self.state.device.ipv6);
}
- fn send(&self, msg: Vec<u8>) {}
+ /// Send a raw message to the peer (used for handshake messages)
+ ///
+ /// # Arguments
+ ///
+ /// - `msg`, message body to send to peer
+ ///
+ /// # Returns
+ ///
+ /// Unit if packet was sent, or an error indicating why sending failed
+ pub fn send(&self, msg: &[u8]) -> Result<(), RouterError> {
+ let inner = &self.state;
+ match inner.endpoint.lock().as_ref() {
+ Some(endpoint) => inner
+ .device
+ .bind
+ .send(msg, endpoint)
+ .map_err(|_| RouterError::SendError),
+ None => Err(RouterError::NoEndpoint),
+ }
+ }
}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index ea5e05f..de3799f 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -109,7 +109,7 @@ impl Bind for VoidBind {
VoidBind {}
}
- fn set_port(&self, port: u16) -> Result<(), Self::Error> {
+ fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
@@ -117,18 +117,19 @@ impl Bind for VoidBind {
None
}
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+ fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
- fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
+ fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
Ok(())
}
}
+#[derive(Clone)]
struct PairBind {
- send: Mutex<SyncSender<Vec<u8>>>,
- recv: Mutex<Receiver<Vec<u8>>>,
+ send: Arc<Mutex<SyncSender<Vec<u8>>>>,
+ recv: Arc<Mutex<Receiver<Vec<u8>>>>,
}
impl Bind for PairBind {
@@ -137,12 +138,12 @@ impl Bind for PairBind {
fn new() -> PairBind {
PairBind {
- send: Mutex::new(sync_channel(0).0),
- recv: Mutex::new(sync_channel(0).1),
+ send: Arc::new(Mutex::new(sync_channel(0).0)),
+ recv: Arc::new(Mutex::new(sync_channel(0).1)),
}
}
- fn set_port(&self, port: u16) -> Result<(), Self::Error> {
+ fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
Ok(())
}
@@ -157,26 +158,31 @@ impl Bind for PairBind {
.unwrap()
.recv()
.map_err(|_| BindError::Disconnected)?;
- buf.copy_from_slice(&vec[..]);
+ let len = vec.len();
+ buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
- Ok(())
+ let owned = buf.to_owned();
+ match self.send.lock().unwrap().send(owned) {
+ Err(_) => Err(BindError::Disconnected),
+ Ok(_) => Ok(()),
+ }
}
}
fn bind_pair() -> (PairBind, PairBind) {
- let (tx1, rx1) = sync_channel(0);
- let (tx2, rx2) = sync_channel(0);
+ let (tx1, rx1) = sync_channel(128);
+ let (tx2, rx2) = sync_channel(128);
(
PairBind {
- send: Mutex::new(tx1),
- recv: Mutex::new(rx2),
+ send: Arc::new(Mutex::new(tx1)),
+ recv: Arc::new(Mutex::new(rx2)),
},
PairBind {
- send: Mutex::new(tx2),
- recv: Mutex::new(rx1),
+ send: Arc::new(Mutex::new(tx2)),
+ recv: Arc::new(Mutex::new(rx1)),
},
)
}
@@ -276,10 +282,10 @@ mod tests {
num_cpus::get(),
TunTest {},
VoidBind::new(),
- |t: &Opaque, _data: bool, _sent: bool| {
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
t.fetch_add(1, Ordering::SeqCst);
},
- |_t: &Opaque, _data: bool, _sent: bool| {},
+ |_t: &Opaque, _size: usize, _data: bool, _sent: bool| {},
|_t: &Opaque| {},
);
@@ -321,8 +327,12 @@ mod tests {
1,
TunTest {},
VoidBind::new(),
- |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
- |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.send.store(true, Ordering::SeqCst)
+ },
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.recv.store(true, Ordering::SeqCst)
+ },
|t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
);
@@ -397,8 +407,14 @@ mod tests {
}
}
+ fn wait() {
+ thread::sleep(Duration::from_millis(10));
+ }
+
#[test]
fn test_outbound_inbound() {
+ init();
+
// type for tracking events inside the router module
struct Flags {
@@ -408,48 +424,143 @@ mod tests {
}
type Opaque = Arc<Flags>;
- let (bind1, bind2) = bind_pair();
+ fn reset(opaq: &Opaque) {
+ opaq.send.store(false, Ordering::SeqCst);
+ opaq.recv.store(false, Ordering::SeqCst);
+ opaq.need_key.store(false, Ordering::SeqCst);
+ }
- // create matching devices
+ fn test(opaq: &Opaque, send: bool, recv: bool, need_key: bool) {
+ assert_eq!(
+ opaq.send.load(Ordering::Acquire),
+ send,
+ "send did not match"
+ );
+ assert_eq!(
+ opaq.recv.load(Ordering::Acquire),
+ recv,
+ "recv did not match"
+ );
+ assert_eq!(
+ opaq.need_key.load(Ordering::Acquire),
+ need_key,
+ "need_key did not match"
+ );
+ }
- let router1 = Device::new(
- 1,
- TunTest {},
- bind1,
- |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
- |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
- |t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
- );
+ let tests = [(
+ false,
+ ("192.168.1.0", 24, "192.168.1.20", true),
+ ("172.133.133.133", 32, "172.133.133.133", true),
+ )];
+
+ for (num, (stage, p1, p2)) in tests.iter().enumerate() {
+ let (bind1, bind2) = bind_pair();
+
+ // create matching devices
+
+ let router1 = Device::new(
+ 1,
+ TunTest {},
+ bind1.clone(),
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.send.store(true, Ordering::SeqCst)
+ },
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.recv.store(true, Ordering::SeqCst)
+ },
+ |t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
+ );
+
+ let router2 = Device::new(
+ 1,
+ TunTest {},
+ bind2.clone(),
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.send.store(true, Ordering::SeqCst)
+ },
+ |t: &Opaque, _size: usize, _data: bool, _sent: bool| {
+ t.recv.store(true, Ordering::SeqCst)
+ },
+ |t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
+ );
+
+ // prepare opaque values for tracing callbacks
+
+ let opaq1 = Arc::new(Flags {
+ send: AtomicBool::new(false),
+ recv: AtomicBool::new(false),
+ need_key: AtomicBool::new(false),
+ });
+
+ let opaq2 = Arc::new(Flags {
+ send: AtomicBool::new(false),
+ recv: AtomicBool::new(false),
+ need_key: AtomicBool::new(false),
+ });
+
+ // create peers with matching keypairs and assign subnets
+
+ let (mask, len, _ip, _okay) = p1;
+ let peer1 = router1.new_peer(opaq1.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer1.add_subnet(mask, *len);
+ peer1.set_endpoint("127.0.0.1:8080".parse().unwrap());
+
+ peer1.add_keypair(dummy_keypair(false));
+
+ let (mask, len, _ip, _okay) = p2;
+ let peer2 = router2.new_peer(opaq2.clone());
+ let mask: IpAddr = mask.parse().unwrap();
+ peer2.add_subnet(mask, *len);
+ peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
+
+ if *stage {
+ // stage a packet which can be used for confirmation (in place of a keepalive)
+ let (_mask, _len, ip, _okay) = p2;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router2.send(msg).expect("failed to sent staged packet");
+ wait();
+ test(&opaq2, false, false, true);
+ reset(&opaq2);
+ }
- let router2 = Device::new(
- 1,
- TunTest {},
- bind2,
- |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst),
- |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst),
- |t: &Opaque| t.need_key.store(true, Ordering::SeqCst),
- );
+ // this should cause a key-confirmation packet (keepalive or staged packet)
+ peer2.add_keypair(dummy_keypair(true));
- // create peers with matching keypairs
+ wait();
+ test(&opaq2, true, false, false);
- let opaq1 = Arc::new(Flags {
- send: AtomicBool::new(false),
- recv: AtomicBool::new(false),
- need_key: AtomicBool::new(false),
- });
+ // read confirming message received by the other end ("across the internet")
+ let mut buf = vec![0u8; 1024];
+ let (len, from) = bind1.recv(&mut buf).unwrap();
+ buf.truncate(len);
+ router1.recv(from, buf).unwrap();
- let opaq2 = Arc::new(Flags {
- send: AtomicBool::new(false),
- recv: AtomicBool::new(false),
- need_key: AtomicBool::new(false),
- });
+ wait();
+ test(&opaq1, false, true, false);
+
+ // start crypt-key routing packets
- let peer1 = router1.new_peer(opaq1.clone());
- peer1.set_endpoint("127.0.0.1:8080".parse().unwrap());
- peer1.add_keypair(dummy_keypair(false));
+ for _ in 0..10 {
+ reset(&opaq1);
+ reset(&opaq2);
- let peer2 = router2.new_peer(opaq2.clone());
- peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
- peer2.add_keypair(dummy_keypair(true)); // this should cause an empty key-confirmation packet
+ // pass IP packet to router
+ let (_mask, _len, ip, _okay) = p1;
+ let msg = make_packet(1024, ip.parse().unwrap());
+ router1.send(msg).unwrap();
+ wait();
+ test(&opaq1, true, false, false);
+
+ // receive ("across the internet") on the other end
+ let mut buf = vec![0u8; 2048];
+ let (len, from) = bind2.recv(&mut buf).unwrap();
+ buf.truncate(len);
+ router2.recv(from, buf).unwrap();
+ wait();
+ test(&opaq2, false, true, false);
+ }
+ }
}
}
diff --git a/src/router/types.rs b/src/router/types.rs
index 7706997..61f1fe7 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -11,9 +11,9 @@ impl<T> Opaque for T where T: Send + Sync + 'static {}
/// * `0`, a reference to the opaque value assigned to the peer
/// * `1`, a bool indicating whether the message contained data (not just keepalive)
/// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?)
-pub trait Callback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
+pub trait Callback<T>: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
-impl<T, F> Callback<T> for F where F: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
+impl<T, F> Callback<T> for F where F: Fn(&T, usize, bool, bool) -> () + Sync + Send + 'static {}
/// A key callback takes 1 argument
///
@@ -58,6 +58,8 @@ pub enum RouterError {
MalformedIPHeader,
MalformedTransportMessage,
UnkownReceiverId,
+ NoEndpoint,
+ SendError,
}
impl fmt::Display for RouterError {
@@ -69,6 +71,8 @@ impl fmt::Display for RouterError {
RouterError::UnkownReceiverId => {
write!(f, "No decryption state associated with receiver id")
}
+ RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
+ RouterError::SendError => write!(f, "Failed to send packet on bind"),
}
}
}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 85cf22a..b038a20 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -13,13 +13,14 @@ use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified};
use super::device::{DecryptionState, DeviceInner};
-use super::messages::TransportHeader;
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
+use super::super::types::{Bind, Tun};
use super::ip::*;
-use super::super::types::{Bind, Tun};
+const SIZE_TAG: usize = 16;
#[derive(PartialEq, Debug)]
pub enum Operation {
@@ -105,32 +106,37 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
return;
}
};
+ debug!("inbound worker: obtained job");
// wait for job to complete
let _ = rx
.map(|buf| {
+ debug!("inbound worker: job complete");
if buf.okay {
// cast transport header
let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
match LayoutVerified::new_from_prefix(&buf.msg[..]) {
Some(v) => v,
None => {
+ debug!("inbound worker: failed to parse message");
return;
}
};
debug_assert!(
packet.len() >= CHACHA20_POLY1305.tag_len(),
- "this should be checked earlier in the pipeline"
+ "this should be checked earlier in the pipeline (decryption should fail)"
);
// check for replay
if !state.protector.lock().update(header.f_counter.get()) {
+ debug!("inbound worker: replay detected");
return;
}
// check for confirms key
if !state.confirmed.swap(true, Ordering::SeqCst) {
+ debug!("inbound worker: message confirms key");
peer.confirm_key(&state.keypair);
}
@@ -138,7 +144,8 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
*peer.endpoint.lock() = Some(endpoint);
// calculate length of IP packet + padding
- let length = packet.len() - CHACHA20_POLY1305.nonce_len();
+ let length = packet.len() - SIZE_TAG;
+ debug!("inbound worker: plaintext length = {}", length);
// check if should be written to TUN
let mut sent = false;
@@ -155,10 +162,14 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
}
}
+ } else {
+ debug!("inbound worker: received keepalive")
}
// trigger callback
- (device.call_recv)(&peer.opaque, length == 0, sent);
+ (device.call_recv)(&peer.opaque, buf.msg.len(), length == 0, sent);
+ } else {
+ debug!("inbound worker: authentication failure")
}
})
.wait();
@@ -178,10 +189,12 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
return;
}
};
+ debug!("outbound worker: obtained job");
// wait for job to complete
let _ = rx
.map(|buf| {
+ debug!("outbound worker: job complete");
if buf.okay {
// write to UDP bind
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
@@ -199,6 +212,7 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
// trigger callback
(device.call_send)(
&peer.opaque,
+ buf.msg.len(),
buf.msg.len()
> CHACHA20_POLY1305.nonce_len() + mem::size_of::<TransportHeader>(),
xmit,
@@ -218,17 +232,26 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
}
Ok(val) => val,
};
+ debug!("parallel worker: obtained job");
+
+ // make space for tag (TODO: consider moving this out)
+ if buf.op == Operation::Encryption {
+ buf.msg.extend([0u8; SIZE_TAG].iter());
+ }
// cast and check size of packet
- let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) =
- match LayoutVerified::new_from_prefix(&buf.msg[..]) {
+ let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) =
+ match LayoutVerified::new_from_prefix(&mut buf.msg[..]) {
Some(v) => v,
- None => continue,
+ None => {
+ debug_assert!(
+ false,
+ "parallel worker: failed to parse message (insufficient size)"
+ );
+ continue;
+ }
};
-
- if packet.len() < CHACHA20_POLY1305.nonce_len() {
- continue;
- }
+ debug_assert!(packet.len() >= CHACHA20_POLY1305.tag_len());
// do the weird ring AEAD dance
let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap());
@@ -241,18 +264,27 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
match buf.op {
Operation::Encryption => {
- debug!("worker, process encryption");
+ debug!("parallel worker: process encryption");
- // note: extends the vector to accommodate the tag
- key.seal_in_place_append_tag(nonce, Aad::empty(), &mut buf.msg)
+ // set the type field
+ header.f_type.set(TYPE_TRANSPORT);
+
+ // encrypt content of transport message in-place
+ let end = packet.len() - SIZE_TAG;
+ let tag = key
+ .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..end])
.unwrap();
+
+ // append tag
+ packet[end..].copy_from_slice(tag.as_ref());
+
buf.okay = true;
}
Operation::Decryption => {
- debug!("worker, process decryption");
+ debug!("parallel worker: process decryption");
// opening failure is signaled by fault state
- buf.okay = match key.open_in_place(nonce, Aad::empty(), &mut buf.msg) {
+ buf.okay = match key.open_in_place(nonce, Aad::empty(), packet) {
Ok(_) => true,
Err(_) => false,
};
@@ -260,6 +292,10 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) {
}
// pass ownership to consumer
- let _ = tx.send(buf);
+ let okay = tx.send(buf);
+ debug!(
+ "parallel worker: passing ownership to sequential worker: {}",
+ okay.is_ok()
+ );
}
}