diff options
Diffstat (limited to 'src/router')
-rw-r--r-- | src/router/device.rs | 55 | ||||
-rw-r--r-- | src/router/peer.rs | 65 | ||||
-rw-r--r-- | src/router/tests.rs | 46 | ||||
-rw-r--r-- | src/router/types.rs | 2 | ||||
-rw-r--r-- | src/router/workers.rs | 45 |
5 files changed, 119 insertions, 94 deletions
diff --git a/src/router/device.rs b/src/router/device.rs index d126959..989c2c2 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -17,21 +17,23 @@ use super::constants::*; use super::ip::*; use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerInner}; -use super::types::{Callbacks, Opaque, RouterError}; +use super::types::{Callbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; -use super::super::types::{Bind, KeyPair, Tun}; +use super::super::types::{KeyPair, Endpoint, bind, tun}; -pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { - // IO & timer callbacks - pub tun: T, - pub bind: B, +pub struct DeviceInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + // inbound writer (TUN) + pub inbound: T, + + // outbound writer (Bind) + pub outbound: RwLock<Option<B>>, // routing - pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state - pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing - pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing + pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state + pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing + pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing // work queues pub queue_next: AtomicUsize, // next round-robin index @@ -45,20 +47,20 @@ pub struct EncryptionState { pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> { +pub struct DecryptionState<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { pub keypair: Arc<KeyPair>, pub confirmed: AtomicBool, pub protector: Mutex<AntiReplay>, - pub peer: Arc<PeerInner<C, T, B>>, + pub peer: Arc<PeerInner<E, C, T, B>>, pub death: Instant, // time when the key can no longer be used for decryption } -pub struct Device<C: Callbacks, T: Tun, B: Bind> { - state: Arc<DeviceInner<C, T, B>>, // reference to device state +pub struct Device<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + state: Arc<DeviceInner<E, C, T, B>>, // reference to device state handles: Vec<thread::JoinHandle<()>>, // join handles for workers } -impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> { fn drop(&mut self) { debug!("router: dropping device"); @@ -83,10 +85,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { } #[inline(always)] -fn get_route<C: Callbacks, T: Tun, B: Bind>( - device: &Arc<DeviceInner<C, T, B>>, +fn get_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: &Arc<DeviceInner<E, C, T, B>>, packet: &[u8], -) -> Option<Arc<PeerInner<C, T, B>>> { +) -> Option<Arc<PeerInner<E, C, T, B>>> { // ensure version access within bounds if packet.len() < 1 { return None; @@ -122,12 +124,12 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>( } } -impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { - pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> { + pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> { // allocate shared device state let mut inner = DeviceInner { - tun, - bind, + inbound: tun, + outbound: RwLock::new(None), queues: Mutex::new(Vec::with_capacity(num_workers)), queue_next: AtomicUsize::new(0), recv: RwLock::new(HashMap::new()), @@ -159,7 +161,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> { + pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> { new_peer(self.state.clone(), opaque) } @@ -199,7 +201,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// # Returns /// /// - pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> { + pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> { // parse / cast let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { Some(v) => v, @@ -231,4 +233,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { Ok(()) } + + /// Set outbound writer + /// + /// + pub fn set_outbound_writer(&self, new : B) { + *self.state.outbound.write() = Some(new); + } } diff --git a/src/router/peer.rs b/src/router/peer.rs index 86723bb..189904c 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -14,7 +14,7 @@ use treebitmap::IpLookupTable; use zerocopy::LayoutVerified; use super::super::constants::*; -use super::super::types::{Bind, Endpoint, KeyPair, Tun}; +use super::super::types::{Endpoint, KeyPair, bind, tun}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -39,28 +39,28 @@ pub struct KeyWheel { retired: Vec<u32>, // retired ids } -pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> { - pub device: Arc<DeviceInner<C, T, B>>, +pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + pub device: Arc<DeviceInner<E, C, T, B>>, pub opaque: C::Opaque, pub outbound: Mutex<SyncSender<JobOutbound>>, - pub inbound: Mutex<SyncSender<JobInbound<C, T, B>>>, + pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>, pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, pub keys: Mutex<KeyWheel>, pub ekey: Mutex<Option<EncryptionState>>, - pub endpoint: Mutex<Option<B::Endpoint>>, + pub endpoint: Mutex<Option<E>>, } -pub struct Peer<C: Callbacks, T: Tun, B: Bind> { - state: Arc<PeerInner<C, T, B>>, +pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> { + state: Arc<PeerInner<E, C, T, B>>, thread_outbound: Option<thread::JoinHandle<()>>, thread_inbound: Option<thread::JoinHandle<()>>, } -fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>( - peer: &Arc<PeerInner<C, T, B>>, - table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, - callback: Box<dyn Fn(A, u32) -> E>, -) -> Vec<E> +fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + peer: &Arc<PeerInner<E, C, T, B>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, + callback: Box<dyn Fn(A, u32) -> R>, +) -> Vec<R> where A: Address, { @@ -74,9 +74,9 @@ where res } -fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( - peer: &Peer<C, T, B>, - table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, +fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + peer: &Peer<E, C, T, B>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>, ) { let mut m = table.write(); @@ -107,8 +107,8 @@ impl EncryptionState { } } -impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> { - fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> { + fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), @@ -119,7 +119,7 @@ impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> { } } -impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> { fn drop(&mut self) { let peer = &self.state; @@ -167,10 +167,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { } } -pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( - device: Arc<DeviceInner<C, T, B>>, +pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, opaque: C::Opaque, -) -> Peer<C, T, B> { +) -> Peer<E, C, T, B> { let (out_tx, out_rx) = sync_channel(128); let (in_tx, in_rx) = sync_channel(128); @@ -215,7 +215,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( } } -impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> { fn send_staged(&self) -> bool { debug!("peer.send_staged"); let mut sent = false; @@ -286,8 +286,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { pub fn recv_job( &self, - src: B::Endpoint, - dec: Arc<DecryptionState<C, T, B>>, + src: E, + dec: Arc<DecryptionState<E, C, T, B>>, mut msg: Vec<u8>, ) -> Option<JobParallel> { let (tx, rx) = oneshot(); @@ -370,7 +370,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { } } -impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { +impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> { /// Set the endpoint of the peer /// /// # Arguments @@ -381,9 +381,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { /// /// This API still permits support for the "sticky socket" behavior, /// as sockets should be "unsticked" when manually updating the endpoint - pub fn set_endpoint(&self, address: SocketAddr) { + pub fn set_endpoint(&self, endpoint: E) { debug!("peer.set_endpoint"); - *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address)); + *self.state.endpoint.lock() = Some(endpoint); } /// Returns the current endpoint of the peer (for configuration) @@ -591,11 +591,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { debug!("peer.send"); let inner = &self.state; match inner.endpoint.lock().as_ref() { - Some(endpoint) => inner - .device - .bind - .send(msg, endpoint) - .map_err(|_| RouterError::SendError), + Some(endpoint) => inner.device + .outbound + .read() + .as_ref() + .ok_or(RouterError::SendError) + .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ), None => Err(RouterError::NoEndpoint), } } diff --git a/src/router/tests.rs b/src/router/tests.rs index f42e1f6..3b6b941 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -1,18 +1,18 @@ -use std::error::Error; -use std::fmt; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use std::sync::atomic::Ordering; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::thread; -use std::time::{Duration, Instant}; +use std::time::Duration; use num_cpus; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun}; +use super::super::types::bind::*; +use super::super::types::tun::*; +use super::super::types::*; + use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; extern crate test; @@ -145,8 +145,9 @@ mod tests { } // create device - let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> = - Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new()); + let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); + let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> = + Device::new(num_cpus::get(), tun_writer); // add new peer let opaque = Arc::new(AtomicUsize::new(0)); @@ -174,8 +175,9 @@ mod tests { init(); // create device - let router: Device<TestCallbacks, _, _> = - Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new()); + let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap(); + let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer); + router.set_outbound_writer(dummy::VoidBind::new()); let tests = vec![ ("192.168.1.0", 24, "192.168.1.20", true), @@ -315,12 +317,18 @@ mod tests { ]; for (stage, p1, p2) in tests.iter() { - // create matching devices - let (bind1, bind2) = dummy::PairBind::pair(); - let router1: Device<TestCallbacks, _, _> = - Device::new(1, dummy::TunTest::new(), bind1.clone()); - let router2: Device<TestCallbacks, _, _> = - Device::new(1, dummy::TunTest::new(), bind2.clone()); + let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = + dummy::PairBind::pair(); + + // create matching device + let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap(); + let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap(); + + let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); + router1.set_outbound_writer(bind_writer1); + + let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2); + router2.set_outbound_writer(bind_writer2); // prepare opaque values for tracing callbacks @@ -339,7 +347,7 @@ mod tests { let peer2 = router2.new_peer(opaq2.clone()); let mask: IpAddr = mask.parse().unwrap(); peer2.add_subnet(mask, *len); - peer2.set_endpoint("127.0.0.1:8080".parse().unwrap()); + peer2.set_endpoint(dummy::UnitEndpoint::new()); if *stage { // stage a packet which can be used for confirmation (in place of a keepalive) @@ -372,7 +380,7 @@ mod tests { // read confirming message received by the other end ("across the internet") let mut buf = vec![0u8; 2048]; - let (len, from) = bind1.recv(&mut buf).unwrap(); + let (len, from) = bind_reader1.read(&mut buf).unwrap(); buf.truncate(len); router1.recv(from, buf).unwrap(); @@ -411,7 +419,7 @@ mod tests { // receive ("across the internet") on the other end let mut buf = vec![0u8; 2048]; - let (len, from) = bind2.recv(&mut buf).unwrap(); + let (len, from) = bind_reader2.read(&mut buf).unwrap(); buf.truncate(len); router2.recv(from, buf).unwrap(); diff --git a/src/router/types.rs b/src/router/types.rs index b7c3ae0..4a72c27 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -1,6 +1,8 @@ use std::error::Error; use std::fmt; +use super::super::types::Endpoint; + pub trait Opaque: Send + Sync + 'static {} impl<T> Opaque for T where T: Send + Sync + 'static {} diff --git a/src/router/workers.rs b/src/router/workers.rs index 6710816..2e89bb0 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::{Bind, Tun}; +use super::super::types::{Endpoint, tun, bind}; use super::ip::*; const SIZE_TAG: usize = 16; @@ -38,18 +38,18 @@ pub struct JobBuffer { pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer); #[allow(type_alias_bounds)] -pub type JobInbound<C, T, B: Bind> = ( - Arc<DecryptionState<C, T, B>>, - B::Endpoint, +pub type JobInbound<E, C, T, B: bind::Writer<E>> = ( + Arc<DecryptionState<E, C, T, B>>, + E, oneshot::Receiver<JobBuffer>, ); pub type JobOutbound = oneshot::Receiver<JobBuffer>; #[inline(always)] -fn check_route<C: Callbacks, T: Tun, B: Bind>( - device: &Arc<DeviceInner<C, T, B>>, - peer: &Arc<PeerInner<C, T, B>>, +fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: &Arc<DeviceInner<E, C, T, B>>, + peer: &Arc<PeerInner<E, C, T, B>>, packet: &[u8], ) -> Option<usize> { match packet[0] >> 4 { @@ -93,10 +93,10 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>( } } -pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( - device: Arc<DeviceInner<C, T, B>>, // related device - peer: Arc<PeerInner<C, T, B>>, // related peer - receiver: Receiver<JobInbound<C, T, B>>, +pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, // related device + peer: Arc<PeerInner<E, C, T, B>>, // related peer + receiver: Receiver<JobInbound<E, C, T, B>>, ) { loop { // fetch job @@ -153,7 +153,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { debug_assert!(inner_len <= length, "should be validated"); if inner_len <= length { - sent = match device.tun.write(&packet[..inner_len]) { + sent = match device.inbound.write(&packet[..inner_len]) { Err(e) => { debug!("failed to write inbound packet to TUN: {:?}", e); false @@ -176,9 +176,9 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( } } -pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>( - device: Arc<DeviceInner<C, T, B>>, // related device - peer: Arc<PeerInner<C, T, B>>, // related peer +pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( + device: Arc<DeviceInner<E, C, T, B>>, // related device + peer: Arc<PeerInner<E, C, T, B>>, // related peer receiver: Receiver<JobOutbound>, ) { loop { @@ -198,12 +198,17 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>( if buf.okay { // write to UDP bind let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() { - match device.bind.send(&buf.msg[..], dst) { - Err(e) => { - debug!("failed to send outbound packet: {:?}", e); - false + let send : &Option<B> = &*device.outbound.read(); + if let Some(writer) = send.as_ref() { + match writer.write(&buf.msg[..], dst) { + Err(e) => { + debug!("failed to send outbound packet: {:?}", e); + false + } + Ok(_) => true, } - Ok(_) => true, + } else { + false } } else { false |