diff options
-rw-r--r-- | Cargo.lock | 1 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/router/device.rs | 24 | ||||
-rw-r--r-- | src/router/tests.rs | 245 | ||||
-rw-r--r-- | src/router/workers.rs | 59 |
5 files changed, 239 insertions, 91 deletions
@@ -1569,6 +1569,7 @@ dependencies = [ "hmac 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)", + "num_cpus 1.10.1 (registry+https://github.com/rust-lang/crates.io-index)", "parking_lot 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", "pnet 0.22.0 (registry+https://github.com/rust-lang/crates.io-index)", "proptest 0.9.4 (registry+https://github.com/rust-lang/crates.io-index)", @@ -30,6 +30,7 @@ clear_on_drop = "0.2.3" parking_lot = "^0.9" futures-channel = "^0.2" env_logger = "0.6" +num_cpus = "^1.10" [dependencies.x25519-dalek] version = "^0.5" diff --git a/src/router/device.rs b/src/router/device.rs index 1d10244..73678cb 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -1,32 +1,28 @@ -use std::cmp; use std::collections::HashMap; use std::net::{Ipv4Addr, Ipv6Addr}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::mpsc::sync_channel; use std::sync::mpsc::SyncSender; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use std::thread; use std::time::Instant; use log::debug; - use spin::{Mutex, RwLock}; use treebitmap::IpLookupTable; use zerocopy::LayoutVerified; -use super::super::types::{Bind, KeyPair, Tun}; - use super::anti_replay::AntiReplay; -use super::peer; -use super::peer::{Peer, PeerInner}; -use super::SIZE_MESSAGE_PREFIX; - use super::constants::*; use super::ip::*; - use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::peer; +use super::peer::{Peer, PeerInner}; use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; +use super::SIZE_MESSAGE_PREFIX; + +use super::super::types::{Bind, KeyPair, Tun}; pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { // IO & timer callbacks @@ -139,8 +135,8 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>( match packet[0] >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header - let (header, _) = LayoutVerified::new_from_prefix(packet)?; - let header: LayoutVerified<&[u8], IPv4Header> = header; + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; // lookup destination address device @@ -151,8 +147,8 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>( } VERSION_IP6 => { // check length and cast to IPv6 header - let (header, packet) = LayoutVerified::new_from_prefix(packet)?; - let header: LayoutVerified<&[u8], IPv6Header> = header; + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; // lookup destination address device diff --git a/src/router/tests.rs b/src/router/tests.rs index f574096..ea5e05f 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -2,10 +2,13 @@ use std::error::Error; use std::fmt; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; +use std::sync::Mutex; use std::thread; use std::time::{Duration, Instant}; +use num_cpus; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; @@ -14,6 +17,33 @@ use super::{Device, SIZE_MESSAGE_PREFIX}; extern crate test; +/* Error implementation */ + +#[derive(Debug)] +enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +/* TUN implementation */ + #[derive(Debug)] enum TunError {} @@ -33,6 +63,22 @@ impl fmt::Display for TunError { } } +/* Endpoint implementation */ + +struct UnitEndpoint {} + +impl From<SocketAddr> for UnitEndpoint { + fn from(addr: SocketAddr) -> UnitEndpoint { + UnitEndpoint {} + } +} + +impl Into<SocketAddr> for UnitEndpoint { + fn into(self) -> SocketAddr { + "127.0.0.1:8080".parse().unwrap() + } +} + struct TunTest {} impl Tun for TunTest { @@ -51,14 +97,16 @@ impl Tun for TunTest { } } -struct BindTest {} +/* Bind implemenentations */ + +struct VoidBind {} -impl Bind for BindTest { +impl Bind for VoidBind { type Error = BindError; - type Endpoint = SocketAddr; + type Endpoint = UnitEndpoint; - fn new() -> BindTest { - BindTest {} + fn new() -> VoidBind { + VoidBind {} } fn set_port(&self, port: u16) -> Result<(), Self::Error> { @@ -70,7 +118,7 @@ impl Bind for BindTest { } fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { - Ok((0, "127.0.0.1:8080".parse().unwrap())) + Ok((0, UnitEndpoint {})) } fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> { @@ -78,25 +126,61 @@ impl Bind for BindTest { } } -#[derive(Debug)] -enum BindError {} +struct PairBind { + send: Mutex<SyncSender<Vec<u8>>>, + recv: Mutex<Receiver<Vec<u8>>>, +} -impl Error for BindError { - fn description(&self) -> &str { - "Generic Bind Error" +impl Bind for PairBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + fn new() -> PairBind { + PairBind { + send: Mutex::new(sync_channel(0).0), + recv: Mutex::new(sync_channel(0).1), + } } - fn source(&self) -> Option<&(dyn Error + 'static)> { + fn set_port(&self, port: u16) -> Result<(), Self::Error> { + Ok(()) + } + + fn get_port(&self) -> Option<u16> { None } -} -impl fmt::Display for BindError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Not Possible") + fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> { + let vec = self + .recv + .lock() + .unwrap() + .recv() + .map_err(|_| BindError::Disconnected)?; + buf.copy_from_slice(&vec[..]); + Ok((vec.len(), UnitEndpoint {})) + } + + fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> { + Ok(()) } } +fn bind_pair() -> (PairBind, PairBind) { + let (tx1, rx1) = sync_channel(0); + let (tx2, rx2) = sync_channel(0); + ( + PairBind { + send: Mutex::new(tx1), + recv: Mutex::new(rx2), + }, + PairBind { + send: Mutex::new(tx2), + recv: Mutex::new(rx1), + }, + ) +} + fn dummy_keypair(initiator: bool) -> KeyPair { let k1 = Key { key: [0x53u8; 32], @@ -131,6 +215,32 @@ mod tests { use std::sync::atomic::AtomicU64; use test::Bencher; + fn get_tests() -> Vec<(&'static str, u32, &'static str, bool)> { + vec![ + ("192.168.1.0", 24, "192.168.1.20", true), + ("172.133.133.133", 32, "172.133.133.133", true), + ("172.133.133.133", 32, "172.133.133.132", false), + ( + "2001:db8::ff00:42:0000", + 112, + "2001:db8::ff00:42:3242", + true, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:0660", + false, + ), + ( + "2001:db8::ff00:42:8000", + 113, + "2001:db8::ff00:42:ffff", + true, + ), + ] + } + fn init() { let _ = env_logger::builder().is_test(true).try_init(); } @@ -162,16 +272,15 @@ mod tests { type Opaque = Arc<AtomicU64>; // create device - let workers = 4; let router = Device::new( - workers, + num_cpus::get(), TunTest {}, - BindTest {}, + VoidBind::new(), |t: &Opaque, _data: bool, _sent: bool| { t.fetch_add(1, Ordering::SeqCst); }, - |t: &Opaque, _data: bool, _sent: bool| {}, - |t: &Opaque| {}, + |_t: &Opaque, _data: bool, _sent: bool| {}, + |_t: &Opaque| {}, ); // add new peer @@ -185,16 +294,10 @@ mod tests { let ip: IpAddr = ip.parse().unwrap(); peer.add_subnet(mask, len); - for _ in 0..1024 { - let msg = make_packet(1024, ip); - router.send(msg).unwrap(); - } - + // every iteration sends 10 MB b.iter(|| { opaque.store(0, Ordering::SeqCst); - // wait till 10 MB while opaque.load(Ordering::Acquire) < 10 * 1024 { - // create "IP packet" let msg = make_packet(1024, ip); router.send(msg).unwrap(); } @@ -214,40 +317,16 @@ mod tests { type Opaque = Arc<Flags>; // create device - let workers = 4; let router = Device::new( - workers, + 1, TunTest {}, - BindTest {}, + VoidBind::new(), |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst), |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst), |t: &Opaque| t.need_key.store(true, Ordering::SeqCst), ); - let tests = vec![ - ("192.168.1.0", 24, "192.168.1.20", true), - ("172.133.133.133", 32, "172.133.133.133", true), - ("172.133.133.133", 32, "172.133.133.132", false), - ( - "2001:db8::ff00:42:0000", - 112, - "2001:db8::ff00:42:3242", - true, - ), - ( - "2001:db8::ff00:42:8000", - 113, - "2001:db8::ff00:42:0660", - false, - ), - ( - "2001:db8::ff00:42:8000", - 113, - "2001:db8::ff00:42:ffff", - true, - ), - ]; - + let tests = get_tests(); for (num, (mask, len, ip, okay)) in tests.iter().enumerate() { for set_key in vec![true, false] { debug!("index = {}, set_key = {}", num, set_key); @@ -317,4 +396,60 @@ mod tests { } } } + + #[test] + fn test_outbound_inbound() { + // type for tracking events inside the router module + + struct Flags { + send: AtomicBool, + recv: AtomicBool, + need_key: AtomicBool, + } + type Opaque = Arc<Flags>; + + let (bind1, bind2) = bind_pair(); + + // create matching devices + + let router1 = Device::new( + 1, + TunTest {}, + bind1, + |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst), + |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst), + |t: &Opaque| t.need_key.store(true, Ordering::SeqCst), + ); + + let router2 = Device::new( + 1, + TunTest {}, + bind2, + |t: &Opaque, _data: bool, _sent: bool| t.send.store(true, Ordering::SeqCst), + |t: &Opaque, _data: bool, _sent: bool| t.recv.store(true, Ordering::SeqCst), + |t: &Opaque| t.need_key.store(true, Ordering::SeqCst), + ); + + // create peers with matching keypairs + + let opaq1 = Arc::new(Flags { + send: AtomicBool::new(false), + recv: AtomicBool::new(false), + need_key: AtomicBool::new(false), + }); + + let opaq2 = Arc::new(Flags { + send: AtomicBool::new(false), + recv: AtomicBool::new(false), + need_key: AtomicBool::new(false), + }); + + let peer1 = router1.new_peer(opaq1.clone()); + peer1.set_endpoint("127.0.0.1:8080".parse().unwrap()); + peer1.add_keypair(dummy_keypair(false)); + + let peer2 = router2.new_peer(opaq2.clone()); + peer2.set_endpoint("127.0.0.1:8080".parse().unwrap()); + peer2.add_keypair(dummy_keypair(true)); // this should cause an empty key-confirmation packet + } } diff --git a/src/router/workers.rs b/src/router/workers.rs index fb22280..85cf22a 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -54,8 +54,8 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>( match packet[0] >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header - let (header, _) = LayoutVerified::new_from_prefix(packet)?; - let header: LayoutVerified<&[u8], IPv4Header> = header; + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; // check IPv4 source address device @@ -72,8 +72,8 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>( } VERSION_IP6 => { // check length and cast to IPv6 header - let (header, _) = LayoutVerified::new_from_prefix(packet)?; - let header: LayoutVerified<&[u8], IPv6Header> = header; + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; // check IPv6 source address device @@ -110,14 +110,15 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( let _ = rx .map(|buf| { if buf.okay { - // parse / cast - let (header, packet) = match LayoutVerified::new_from_prefix(&buf.msg[..]) { - Some(v) => v, - None => { - return; - } - }; - let header: LayoutVerified<&[u8], TransportHeader> = header; + // cast transport header + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&buf.msg[..]) { + Some(v) => v, + None => { + return; + } + }; + debug_assert!( packet.len() >= CHACHA20_POLY1305.tag_len(), "this should be checked earlier in the pipeline" @@ -145,8 +146,13 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { debug_assert!(inner_len <= length, "should be validated"); if inner_len <= length { - sent = true; - let _ = device.tun.write(&packet[..inner_len]); + sent = match device.tun.write(&packet[..inner_len]) { + Err(e) => { + debug!("failed to write inbound packet to TUN: {:?}", e); + false + } + Ok(_) => true, + } } } } @@ -177,8 +183,18 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>( let _ = rx .map(|buf| { if buf.okay { - // write to UDP device, TODO - let xmit = false; + // write to UDP bind + let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { + match device.bind.send(&buf.msg[..], dst) { + Err(e) => { + debug!("failed to send outbound packet: {:?}", e); + false + } + Ok(_) => true, + } + } else { + false + }; // trigger callback (device.call_send)( @@ -204,17 +220,16 @@ pub fn worker_parallel(receiver: Receiver<JobParallel>) { }; // cast and check size of packet - let (header, packet) = match LayoutVerified::new_from_prefix(&buf.msg[..]) { - Some(v) => v, - None => continue, - }; + let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = + match LayoutVerified::new_from_prefix(&buf.msg[..]) { + Some(v) => v, + None => continue, + }; if packet.len() < CHACHA20_POLY1305.nonce_len() { continue; } - let header: LayoutVerified<&[u8], TransportHeader> = header; - // do the weird ring AEAD dance let key = LessSafeKey::new(UnboundKey::new(&CHACHA20_POLY1305, &buf.key[..]).unwrap()); |