diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-31 21:00:10 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-08-31 21:00:10 +0200 |
commit | d16521f4c75ebee6fa068461693755a5c1863e9f (patch) | |
tree | 4eb1ace931180cee33d956da4a000c02377eb525 /src | |
parent | Reduce number of type parameters in router (diff) | |
download | wireguard-rs-d16521f4c75ebee6fa068461693755a5c1863e9f.tar.xz wireguard-rs-d16521f4c75ebee6fa068461693755a5c1863e9f.zip |
Added Bind trait to router
Diffstat (limited to 'src')
-rw-r--r-- | src/main.rs | 9 | ||||
-rw-r--r-- | src/router/device.rs | 34 | ||||
-rw-r--r-- | src/router/peer.rs | 42 | ||||
-rw-r--r-- | src/router/types.rs | 19 | ||||
-rw-r--r-- | src/router/workers.rs | 24 | ||||
-rw-r--r-- | src/types/udp.rs | 2 |
6 files changed, 70 insertions, 60 deletions
diff --git a/src/main.rs b/src/main.rs index cfe93eb..6d1d2e1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -52,14 +52,14 @@ impl Tun for TunTest { } } -struct Test {} +struct BindTest {} -impl Bind for Test { +impl Bind for BindTest { type Error = BindError; type Endpoint = SocketAddr; - fn new() -> Test { - Test {} + fn new() -> BindTest { + BindTest {} } fn set_port(&self, port: u16) -> Result<(), Self::Error> { @@ -111,6 +111,7 @@ fn main() { let router = router::Device::new( 4, TunTest {}, + BindTest {}, |t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)), |t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)), |t: &PeerTimer| println!("new key requested"), diff --git a/src/router/device.rs b/src/router/device.rs index 84f25c6..f04cf97 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -15,12 +15,13 @@ use super::anti_replay::AntiReplay; use super::peer; use super::peer::{Peer, PeerInner}; -use super::types::{Callback, Callbacks, CallbacksPhantom, KeyCallback, Opaque}; +use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks}; use super::workers::{worker_parallel, JobParallel}; -pub struct DeviceInner<C: Callbacks, T: Tun> { +pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { // IO & timer generics pub tun: T, + pub bind: B, pub call_recv: C::CallbackRecv, pub call_send: C::CallbackSend, pub call_need_key: C::CallbackKey, @@ -31,9 +32,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun> { pub injector: Injector<JobParallel>, // parallel enc/dec task injector // routing - pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T>>>, // receiver id -> decryption state - pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T>>>>, // ipv4 cryptkey routing - pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T>>>>, // ipv6 cryptkey routing + pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state + pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing + pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing } pub struct EncryptionState { @@ -44,18 +45,21 @@ pub struct EncryptionState { // (birth + reject-after-time - keepalive-timeout - rekey-timeout) } -pub struct DecryptionState<C: Callbacks, T: Tun> { +pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> { pub key: [u8; 32], pub keypair: Weak<KeyPair>, pub confirmed: AtomicBool, pub protector: spin::Mutex<AntiReplay>, - pub peer: Weak<PeerInner<C, T>>, + pub peer: Weak<PeerInner<C, T, B>>, pub death: Instant, // time when the key can no longer be used for decryption } -pub struct Device<C: Callbacks, T: Tun>(Arc<DeviceInner<C, T>>, Vec<thread::JoinHandle<()>>); +pub struct Device<C: Callbacks, T: Tun, B: Bind>( + Arc<DeviceInner<C, T, B>>, + Vec<thread::JoinHandle<()>>, +); -impl<C: Callbacks, T: Tun> Drop for Device<C, T> { +impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { fn drop(&mut self) { // mark device as stopped let device = &self.0; @@ -73,19 +77,21 @@ impl<C: Callbacks, T: Tun> Drop for Device<C, T> { } } -impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun> - Device<CallbacksPhantom<O, R, S, K>, T> +impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bind> + Device<PhantomCallbacks<O, R, S, K>, T, B> { pub fn new( num_workers: usize, tun: T, + bind: B, call_recv: R, call_send: S, call_need_key: K, - ) -> Device<CallbacksPhantom<O, R, S, K>, T> { + ) -> Device<PhantomCallbacks<O, R, S, K>, T, B> { // allocate shared device state let inner = Arc::new(DeviceInner { tun, + bind, call_recv, call_send, call_need_key, @@ -122,13 +128,13 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun> } } -impl<C: Callbacks, T: Tun> Device<C, T> { +impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// Adds a new peer to the device /// /// # Returns /// /// A atomic ref. counted peer (with liftime matching the device) - pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T> { + pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> { peer::new_peer(self.0.clone(), opaque) } diff --git a/src/router/peer.rs b/src/router/peer.rs index 647d24f..e21e69c 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -12,7 +12,7 @@ use treebitmap::address::Address; use treebitmap::IpLookupTable; use super::super::constants::*; -use super::super::types::{KeyPair, Tun}; +use super::super::types::{KeyPair, Tun, Bind}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; @@ -31,29 +31,29 @@ pub struct KeyWheel { retired: Option<u32>, // retired id (previous id, after confirming key-pair) } -pub struct PeerInner<C: Callbacks, T: Tun> { +pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> { pub stopped: AtomicBool, pub opaque: C::Opaque, - pub device: Arc<DeviceInner<C, T>>, + pub device: Arc<DeviceInner<C, T, B>>, pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>, pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>, pub queue_outbound: SyncSender<JobOutbound>, - pub queue_inbound: SyncSender<JobInbound<C, T>>, + pub queue_inbound: SyncSender<JobInbound<C, T, B>>, pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake - pub rx_bytes: AtomicU64, // received bytes - pub tx_bytes: AtomicU64, // transmitted bytes + pub rx_bytes: AtomicU64, // received bytes + pub tx_bytes: AtomicU64, // transmitted bytes pub keys: spin::Mutex<KeyWheel>, // key-wheel pub ekey: spin::Mutex<Option<EncryptionState>>, // encryption state pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>, } -pub struct Peer<C: Callbacks, T: Tun>( - Arc<PeerInner<C, T>>, +pub struct Peer<C: Callbacks, T: Tun, B: Bind>( + Arc<PeerInner<C, T, B>>, ); -fn treebit_list<A, E, C: Callbacks, T: Tun>( - peer: &Arc<PeerInner<C, T>>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>, +fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>( + peer: &Arc<PeerInner<C, T, B>>, + table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, callback: Box<dyn Fn(A, u32) -> E>, ) -> Vec<E> where @@ -71,9 +71,9 @@ where res } -fn treebit_remove<A: Address, C: Callbacks, T: Tun>( - peer: &Peer<C, T>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>, +fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( + peer: &Peer<C, T, B>, + table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, ) { let mut m = table.write(); @@ -95,7 +95,7 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun>( } } -impl<C: Callbacks, T: Tun> Drop for Peer<C, T> { +impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { fn drop(&mut self) { // mark peer as stopped @@ -150,10 +150,10 @@ impl<C: Callbacks, T: Tun> Drop for Peer<C, T> { } } -pub fn new_peer<C: Callbacks, T: Tun>( - device: Arc<DeviceInner<C, T>>, +pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( + device: Arc<DeviceInner<C, T, B>>, opaque: C::Opaque, -) -> Peer<C, T> { +) -> Peer<C, T, B> { // allocate in-order queues let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS); let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS); @@ -204,7 +204,7 @@ pub fn new_peer<C: Callbacks, T: Tun>( Peer(peer) } -impl<C: Callbacks, T: Tun> PeerInner<C, T> { +impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { pub fn confirm_key(&self, kp: Weak<KeyPair>) { // upgrade key-pair to strong reference @@ -214,8 +214,8 @@ impl<C: Callbacks, T: Tun> PeerInner<C, T> { } } -impl<C: Callbacks, T: Tun> Peer<C, T> { - fn new(inner: PeerInner<C, T>) -> Peer<C, T> { +impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { + fn new(inner: PeerInner<C, T, B>) -> Peer<C, T, B> { Peer(Arc::new(inner)) } diff --git a/src/router/types.rs b/src/router/types.rs index f6a0311..82dcd09 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -20,10 +20,6 @@ pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {} impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {} -pub trait TunCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} - -pub trait BindCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {} - pub trait Endpoint: Send + Sync {} pub trait Callbacks: Send + Sync + 'static { @@ -33,16 +29,23 @@ pub trait Callbacks: Send + Sync + 'static { type CallbackKey: KeyCallback<Self::Opaque>; } -pub struct CallbacksPhantom<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> { +/* Concrete implementation of "Callbacks", + * used to hide the constituent type parameters. + * + * This type is never instantiated. + */ +pub struct PhantomCallbacks<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> { _phantom_opaque: PhantomData<O>, _phantom_recv: PhantomData<R>, _phantom_send: PhantomData<S>, - _phantom_key: PhantomData<K> + _phantom_key: PhantomData<K>, } -impl <O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks for CallbacksPhantom<O, R, S, K> { +impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks + for PhantomCallbacks<O, R, S, K> +{ type Opaque = O; type CallbackRecv = R; type CallbackSend = S; type CallbackKey = K; -}
\ No newline at end of file +} diff --git a/src/router/workers.rs b/src/router/workers.rs index 2b0b9ec..c4a9f18 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -17,7 +17,7 @@ use super::messages::TransportHeader; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::Tun; +use super::super::types::{Tun, Bind}; #[derive(PartialEq, Debug)] pub enum Operation { @@ -41,7 +41,7 @@ pub struct JobInner { pub type JobBuffer = Arc<spin::Mutex<JobInner>>; pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer); -pub type JobInbound<C, T> = (Weak<DecryptionState<C, T>>, JobBuffer); +pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, JobBuffer); pub type JobOutbound = JobBuffer; /* Strategy for workers acquiring a new job: @@ -89,10 +89,10 @@ fn wait_recv<T>(running: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvEr return Err(TryRecvError::Disconnected); } -pub fn worker_inbound<C: Callbacks, T: Tun>( - device: Arc<DeviceInner<C, T>>, // related device - peer: Arc<PeerInner<C, T>>, // related peer - recv: Receiver<JobInbound<C, T>>, // in order queue +pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( + device: Arc<DeviceInner<C, T, B>>, // related device + peer: Arc<PeerInner<C, T, B>>, // related peer + recv: Receiver<JobInbound<C, T, B>>, // in order queue ) { loop { match wait_recv(&peer.stopped, &recv) { @@ -157,10 +157,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun>( } } -pub fn worker_outbound<C: Callbacks, T: Tun>( - device: Arc<DeviceInner<C, T>>, // related device - peer: Arc<PeerInner<C, T>>, // related peer - recv: Receiver<JobOutbound>, // in order queue +pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>( + device: Arc<DeviceInner<C, T, B>>, // related device + peer: Arc<PeerInner<C, T, B>>, // related peer + recv: Receiver<JobOutbound>, // in order queue ) { loop { match wait_recv(&peer.stopped, &recv) { @@ -205,8 +205,8 @@ pub fn worker_outbound<C: Callbacks, T: Tun>( } } -pub fn worker_parallel<C: Callbacks, T: Tun>( - device: Arc<DeviceInner<C, T>>, +pub fn worker_parallel<C: Callbacks, T: Tun, B: Bind>( + device: Arc<DeviceInner<C, T, B>>, local: Worker<JobParallel>, // local job queue (local to thread) stealers: Vec<Stealer<JobParallel>>, // stealers (from other threads) ) { diff --git a/src/types/udp.rs b/src/types/udp.rs index 4bf0a9c..71d5a79 100644 --- a/src/types/udp.rs +++ b/src/types/udp.rs @@ -3,7 +3,7 @@ use std::error; /* Often times an a file descriptor in an atomic might suffice. */ -pub trait Bind: Send + Sync { +pub trait Bind: Send + Sync + 'static { type Error: error::Error; type Endpoint: Endpoint; |