diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-31 20:25:16 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-31 20:25:16 +0200 |
commit | 46d76b80c6b1b3b1c549b770b1a5ba791b49da8a (patch) | |
tree | 0c880785943dd00e66cff6d8cd7560d42dc68c24 /src | |
parent | Explicitly clear t0 in KDF macro (diff) | |
download | wireguard-rs-46d76b80c6b1b3b1c549b770b1a5ba791b49da8a.tar.xz wireguard-rs-46d76b80c6b1b3b1c549b770b1a5ba791b49da8a.zip |
Reduce number of type parameters in router
Merge multiple related type parameters into trait,
allowing for easier refactoring and better maintainability.
Diffstat (limited to 'src')
-rw-r--r-- | src/handshake/noise.rs | 2 | ||||
-rw-r--r-- | src/main.rs | 40 | ||||
-rw-r--r-- | src/router/device.rs | 60 | ||||
-rw-r--r-- | src/router/messages.rs | 4 | ||||
-rw-r--r-- | src/router/mod.rs | 2 | ||||
-rw-r--r-- | src/router/peer.rs | 48 | ||||
-rw-r--r-- | src/router/types.rs | 23 | ||||
-rw-r--r-- | src/router/workers.rs | 30 |
8 files changed, 137 insertions, 72 deletions
diff --git a/src/handshake/noise.rs b/src/handshake/noise.rs index 1e7c50d..9fc0eb4 100644 --- a/src/handshake/noise.rs +++ b/src/handshake/noise.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - // DH use x25519_dalek::PublicKey; use x25519_dalek::StaticSecret; diff --git a/src/main.rs b/src/main.rs index 600e144..cfe93eb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,44 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use types::{Bind, KeyPair}; +use types::{Bind, KeyPair, Tun}; + +#[derive(Debug)] +enum TunError {} + +impl Error for TunError { + fn description(&self) -> &str { + "Generic Tun Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for TunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} + +struct TunTest {} + +impl Tun for TunTest { + type Error = TunError; + + fn mtu(&self) -> usize { + 1500 + } + + fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> { + Ok(0) + } + + fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + Ok(()) + } +} struct Test {} @@ -73,6 +110,7 @@ fn main() { { let router = router::Device::new( 4, + TunTest {}, |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)), |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)), |t: &PeerTimer| println!("new key requested"), diff --git a/src/router/device.rs b/src/router/device.rs index a7f0590..84f25c6 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Weak}; use std::thread; use std::time::Instant; -use crossbeam_deque::{Injector, Steal, Stealer, Worker}; +use crossbeam_deque::{Injector, Worker}; use spin; use treebitmap::IpLookupTable; @@ -15,24 +15,25 @@ use super::anti_replay::AntiReplay; use super::peer; use super::peer::{Peer, PeerInner}; -use super::types::{Callback, KeyCallback, Opaque}; +use super::types::{Callback, Callbacks, CallbacksPhantom, KeyCallback, Opaque}; use super::workers::{worker_parallel, JobParallel}; -pub struct DeviceInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> { +pub struct DeviceInner<C: Callbacks, T: Tun> { + // IO & timer generics + pub tun: T, + pub call_recv: C::CallbackRecv, + pub call_send: C::CallbackSend, + pub call_need_key: C::CallbackKey, + // threading and workers pub running: AtomicBool, // workers running? pub parked: AtomicBool, // any workers parked? pub injector: Injector<JobParallel>, // parallel enc/dec task injector - // unboxed callbacks (used for timers and handshake requests) - pub event_send: S, // called when authenticated message send - pub event_recv: R, // called when authenticated message received - pub event_need_key: K, // called when new key material is required - // routing - pub recv: spin::RwLock<HashMap<u32, DecryptionState<T, S, R, K>>>, // receiver id -> decryption state - pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<T, S, R, K>>>>, // ipv4 cryptkey routing - pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<T, S, R, K>>>>, // ipv6 cryptkey routing + pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T>>>, // receiver id -> decryption state + pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T>>>>, // ipv4 cryptkey routing + pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T>>>>, // ipv6 cryptkey routing } pub struct EncryptionState { @@ -43,21 +44,18 @@ pub struct EncryptionState { // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> { +pub struct DecryptionState<C: Callbacks, T: Tun> { pub key: [u8; 32], pub keypair: Weak<KeyPair>, pub confirmed: AtomicBool, pub protector: spin::Mutex<AntiReplay>, - pub peer: Weak<PeerInner<T, S, R, K>>, + pub peer: Weak<PeerInner<C, T>>, pub death: Instant, // time when the key can no longer be used for decryption } -pub struct Device<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - Arc<DeviceInner<T, S, R, K>>, - Vec<thread::JoinHandle<()>>, -); +pub struct Device<C: Callbacks, T: Tun>(Arc<DeviceInner<C, T>>, Vec<thread::JoinHandle<()>>); -impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Device<T, S, R, K> { +impl<C: Callbacks, T: Tun> Drop for Device<C, T> { fn drop(&mut self) { // mark device as stopped let device = &self.0; @@ -75,18 +73,22 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Devi } } -impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S, R, K> { +impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun> + Device<CallbacksPhantom<O, R, S, K>, T> +{ pub fn new( num_workers: usize, - event_recv: R, - event_send: S, - event_need_key: K, - ) -> Device<T, S, R, K> { + tun: T, + call_recv: R, + call_send: S, + call_need_key: K, + ) -> Device<CallbacksPhantom<O, R, S, K>, T> { // allocate shared device state let inner = Arc::new(DeviceInner { - event_recv, - event_send, - event_need_key, + tun, + call_recv, + call_send, + call_need_key, parked: AtomicBool::new(false), running: AtomicBool::new(true), injector: Injector::new(), @@ -95,7 +97,7 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S, ipv6: spin::RwLock::new(IpLookupTable::new()), }); - // alloacate work pool resources + // allocate work pool resources let mut workers = Vec::with_capacity(num_workers); let mut stealers = Vec::with_capacity(num_workers); for _ in 0..num_workers { @@ -118,13 +120,15 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S, // return exported device handle Device(inner, threads) } +} +impl<C: Callbacks, T: Tun> Device<C, T> { /// Adds a new peer to the device /// /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: T) -> Peer<T, S, R, K> { + pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T> { peer::new_peer(self.0.clone(), opaque) } diff --git a/src/router/messages.rs b/src/router/messages.rs index d09bbb3..bec24ac 100644 --- a/src/router/messages.rs +++ b/src/router/messages.rs @@ -7,5 +7,5 @@ use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; pub struct TransportHeader { pub f_type: U32<LittleEndian>, pub f_receiver: U32<LittleEndian>, - pub f_counter: U64<LittleEndian> -}
\ No newline at end of file + pub f_counter: U64<LittleEndian>, +} diff --git a/src/router/mod.rs b/src/router/mod.rs index 70ac868..c1ecf1c 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,9 +1,9 @@ mod anti_replay; mod device; +mod messages; mod peer; mod types; mod workers; -mod messages; pub use device::Device; pub use peer::Peer; diff --git a/src/router/peer.rs b/src/router/peer.rs index 234c353..647d24f 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -12,7 +12,7 @@ use treebitmap::address::Address; use treebitmap::IpLookupTable; use super::super::constants::*; -use super::super::types::KeyPair; +use super::super::types::{KeyPair, Tun}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -20,7 +20,7 @@ use super::device::DeviceInner; use super::device::EncryptionState; use super::workers::{worker_inbound, worker_outbound, JobInbound, JobOutbound}; -use super::types::{Callback, KeyCallback, Opaque}; +use super::types::Callbacks; const MAX_STAGED_PACKETS: usize = 128; @@ -31,14 +31,14 @@ pub struct KeyWheel { retired: Option<u32>, // retired id (previous id, after confirming key-pair) } -pub struct PeerInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> { +pub struct PeerInner<C: Callbacks, T: Tun> { pub stopped: AtomicBool, - pub opaque: T, - pub device: Arc<DeviceInner<T, S, R, K>>, + pub opaque: C::Opaque, + pub device: Arc<DeviceInner<C, T>>, pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>, pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>, pub queue_outbound: SyncSender<JobOutbound>, - pub queue_inbound: SyncSender<JobInbound<T, S, R, K>>, + pub queue_inbound: SyncSender<JobInbound<C, T>>, pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes @@ -47,15 +47,15 @@ pub struct PeerInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>, } -pub struct Peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - Arc<PeerInner<T, S, R, K>>, +pub struct Peer<C: Callbacks, T: Tun>( + Arc<PeerInner<C, T>>, ); -fn treebit_list<A, O, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - peer: &Arc<PeerInner<T, S, R, K>>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<T, S, R, K>>>>, - callback: Box<dyn Fn(A, u32) -> O>, -) -> Vec<O> +fn treebit_list<A, E, C: Callbacks, T: Tun>( + peer: &Arc<PeerInner<C, T>>, + table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>, + callback: Box<dyn Fn(A, u32) -> E>, +) -> Vec<E> where A: Address, { @@ -71,9 +71,9 @@ where res } -fn treebit_remove<A: Address, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - peer: &Peer<T, S, R, K>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<T, S, R, K>>>>, +fn treebit_remove<A: Address, C: Callbacks, T: Tun>( + peer: &Peer<C, T>, + table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>, ) { let mut m = table.write(); @@ -95,7 +95,7 @@ fn treebit_remove<A: Address, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyC } } -impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Peer<T, S, R, K> { +impl<C: Callbacks, T: Tun> Drop for Peer<C, T> { fn drop(&mut self) { // mark peer as stopped @@ -150,10 +150,10 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Peer } } -pub fn new_peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - device: Arc<DeviceInner<T, S, R, K>>, - opaque: T, -) -> Peer<T, S, R, K> { +pub fn new_peer<C: Callbacks, T: Tun>( + device: Arc<DeviceInner<C, T>>, + opaque: C::Opaque, +) -> Peer<C, T> { // allocate in-order queues let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS); let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS); @@ -204,7 +204,7 @@ pub fn new_peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( Peer(peer) } -impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> PeerInner<T, S, R, K> { +impl<C: Callbacks, T: Tun> PeerInner<C, T> { pub fn confirm_key(&self, kp: Weak<KeyPair>) { // upgrade key-pair to strong reference @@ -214,8 +214,8 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> PeerInner<T, } } -impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Peer<T, S, R, K> { - fn new(inner: PeerInner<T, S, R, K>) -> Peer<T, S, R, K> { +impl<C: Callbacks, T: Tun> Peer<C, T> { + fn new(inner: PeerInner<C, T>) -> Peer<C, T> { Peer(Arc::new(inner)) } diff --git a/src/router/types.rs b/src/router/types.rs index 3d486bc..f6a0311 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -1,3 +1,5 @@ +use std::marker::PhantomData; + pub trait Opaque: Send + Sync + 'static {} impl<T> Opaque for T where T: Send + Sync + 'static {} @@ -23,3 +25,24 @@ pub trait TunCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} pub trait BindCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} pub trait Endpoint: Send + Sync {} + +pub trait Callbacks: Send + Sync + 'static { + type Opaque: Opaque; + type CallbackRecv: Callback<Self::Opaque>; + type CallbackSend: Callback<Self::Opaque>; + type CallbackKey: KeyCallback<Self::Opaque>; +} + +pub struct CallbacksPhantom<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> { + _phantom_opaque: PhantomData<O>, + _phantom_recv: PhantomData<R>, + _phantom_send: PhantomData<S>, + _phantom_key: PhantomData<K> +} + +impl <O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks for CallbacksPhantom<O, R, S, K> { + type Opaque = O; + type CallbackRecv = R; + type CallbackSend = S; + type CallbackKey = K; +}
\ No newline at end of file diff --git a/src/router/workers.rs b/src/router/workers.rs index 4861847..2b0b9ec 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -15,7 +15,9 @@ use super::device::DecryptionState; use super::device::DeviceInner; use super::messages::TransportHeader; use super::peer::PeerInner; -use super::types::{Callback, KeyCallback, Opaque}; +use super::types::Callbacks; + +use super::super::types::Tun; #[derive(PartialEq, Debug)] pub enum Operation { @@ -39,7 +41,7 @@ pub struct JobInner { pub type JobBuffer = Arc<spin::Mutex<JobInner>>; pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer); -pub type JobInbound<T, S, R, K> = (Weak<DecryptionState<T, S, R, K>>, JobBuffer); +pub type JobInbound<C, T> = (Weak<DecryptionState<C, T>>, JobBuffer); pub type JobOutbound = JobBuffer; /* Strategy for workers acquiring a new job: @@ -87,10 +89,10 @@ fn wait_recv<T>(running: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvEr return Err(TryRecvError::Disconnected); } -pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - device: Arc<DeviceInner<T, S, R, K>>, // related device - peer: Arc<PeerInner<T, S, R, K>>, // related peer - recv: Receiver<JobInbound<T, S, R, K>>, // in order queue +pub fn worker_inbound<C: Callbacks, T: Tun>( + device: Arc<DeviceInner<C, T>>, // related device + peer: Arc<PeerInner<C, T>>, // related peer + recv: Receiver<JobInbound<C, T>>, // in order queue ) { loop { match wait_recv(&peer.stopped, &recv) { @@ -134,7 +136,7 @@ pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback< packet.len() >= CHACHA20_POLY1305.nonce_len(), "this should be checked earlier in the pipeline" ); - (device.event_recv)( + (device.call_recv)( &peer.opaque, packet.len() > CHACHA20_POLY1305.nonce_len(), true, @@ -155,10 +157,10 @@ pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback< } } -pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - device: Arc<DeviceInner<T, S, R, K>>, // related device - peer: Arc<PeerInner<T, S, R, K>>, // related peer - recv: Receiver<JobOutbound>, // in order queue +pub fn worker_outbound<C: Callbacks, T: Tun>( + device: Arc<DeviceInner<C, T>>, // related device + peer: Arc<PeerInner<C, T>>, // related peer + recv: Receiver<JobOutbound>, // in order queue ) { loop { match wait_recv(&peer.stopped, &recv) { @@ -180,7 +182,7 @@ pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback let xmit = false; // trigger callback - (device.event_send)( + (device.call_send)( &peer.opaque, buf.msg.len() > CHACHA20_POLY1305.nonce_len() @@ -203,8 +205,8 @@ pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback } } -pub fn worker_parallel<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>( - device: Arc<DeviceInner<T, S, R, K>>, +pub fn worker_parallel<C: Callbacks, T: Tun>( + device: Arc<DeviceInner<C, T>>, local: Worker<JobParallel>, // local job queue (local to thread) stealers: Vec<Stealer<JobParallel>>, // stealers (from other threads) ) { |