aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-31 21:00:10 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-31 21:00:10 +0200
commitd16521f4c75ebee6fa068461693755a5c1863e9f (patch)
tree4eb1ace931180cee33d956da4a000c02377eb525
parentReduce number of type parameters in router (diff)
downloadwireguard-rs-d16521f4c75ebee6fa068461693755a5c1863e9f.tar.xz
wireguard-rs-d16521f4c75ebee6fa068461693755a5c1863e9f.zip
Added Bind trait to router
-rw-r--r--src/main.rs9
-rw-r--r--src/router/device.rs34
-rw-r--r--src/router/peer.rs42
-rw-r--r--src/router/types.rs19
-rw-r--r--src/router/workers.rs24
-rw-r--r--src/types/udp.rs2
6 files changed, 70 insertions, 60 deletions
diff --git a/src/main.rs b/src/main.rs
index cfe93eb..6d1d2e1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -52,14 +52,14 @@ impl Tun for TunTest {
}
}
-struct Test {}
+struct BindTest {}
-impl Bind for Test {
+impl Bind for BindTest {
type Error = BindError;
type Endpoint = SocketAddr;
- fn new() -> Test {
- Test {}
+ fn new() -> BindTest {
+ BindTest {}
}
fn set_port(&self, port: u16) -> Result<(), Self::Error> {
@@ -111,6 +111,7 @@ fn main() {
let router = router::Device::new(
4,
TunTest {},
+ BindTest {},
|t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)),
|t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)),
|t: &PeerTimer| println!("new key requested"),
diff --git a/src/router/device.rs b/src/router/device.rs
index 84f25c6..f04cf97 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -15,12 +15,13 @@ use super::anti_replay::AntiReplay;
use super::peer;
use super::peer::{Peer, PeerInner};
-use super::types::{Callback, Callbacks, CallbacksPhantom, KeyCallback, Opaque};
+use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks};
use super::workers::{worker_parallel, JobParallel};
-pub struct DeviceInner<C: Callbacks, T: Tun> {
+pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer generics
pub tun: T,
+ pub bind: B,
pub call_recv: C::CallbackRecv,
pub call_send: C::CallbackSend,
pub call_need_key: C::CallbackKey,
@@ -31,9 +32,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun> {
pub injector: Injector<JobParallel>, // parallel enc/dec task injector
// routing
- pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T>>>, // receiver id -> decryption state
- pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T>>>>, // ipv4 cryptkey routing
- pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T>>>>, // ipv6 cryptkey routing
+ pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state
+ pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
+ pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
}
pub struct EncryptionState {
@@ -44,18 +45,21 @@ pub struct EncryptionState {
// (birth + reject-after-time - keepalive-timeout - rekey-timeout)
}
-pub struct DecryptionState<C: Callbacks, T: Tun> {
+pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
pub key: [u8; 32],
pub keypair: Weak<KeyPair>,
pub confirmed: AtomicBool,
pub protector: spin::Mutex<AntiReplay>,
- pub peer: Weak<PeerInner<C, T>>,
+ pub peer: Weak<PeerInner<C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
-pub struct Device<C: Callbacks, T: Tun>(Arc<DeviceInner<C, T>>, Vec<thread::JoinHandle<()>>);
+pub struct Device<C: Callbacks, T: Tun, B: Bind>(
+ Arc<DeviceInner<C, T, B>>,
+ Vec<thread::JoinHandle<()>>,
+);
-impl<C: Callbacks, T: Tun> Drop for Device<C, T> {
+impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
fn drop(&mut self) {
// mark device as stopped
let device = &self.0;
@@ -73,19 +77,21 @@ impl<C: Callbacks, T: Tun> Drop for Device<C, T> {
}
}
-impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun>
- Device<CallbacksPhantom<O, R, S, K>, T>
+impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bind>
+ Device<PhantomCallbacks<O, R, S, K>, T, B>
{
pub fn new(
num_workers: usize,
tun: T,
+ bind: B,
call_recv: R,
call_send: S,
call_need_key: K,
- ) -> Device<CallbacksPhantom<O, R, S, K>, T> {
+ ) -> Device<PhantomCallbacks<O, R, S, K>, T, B> {
// allocate shared device state
let inner = Arc::new(DeviceInner {
tun,
+ bind,
call_recv,
call_send,
call_need_key,
@@ -122,13 +128,13 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun>
}
}
-impl<C: Callbacks, T: Tun> Device<C, T> {
+impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// Adds a new peer to the device
///
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T> {
+ pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> {
peer::new_peer(self.0.clone(), opaque)
}
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 647d24f..e21e69c 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -12,7 +12,7 @@ use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use super::super::constants::*;
-use super::super::types::{KeyPair, Tun};
+use super::super::types::{KeyPair, Tun, Bind};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
@@ -31,29 +31,29 @@ pub struct KeyWheel {
retired: Option<u32>, // retired id (previous id, after confirming key-pair)
}
-pub struct PeerInner<C: Callbacks, T: Tun> {
+pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub stopped: AtomicBool,
pub opaque: C::Opaque,
- pub device: Arc<DeviceInner<C, T>>,
+ pub device: Arc<DeviceInner<C, T, B>>,
pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub queue_outbound: SyncSender<JobOutbound>,
- pub queue_inbound: SyncSender<JobInbound<C, T>>,
+ pub queue_inbound: SyncSender<JobInbound<C, T, B>>,
pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
- pub rx_bytes: AtomicU64, // received bytes
- pub tx_bytes: AtomicU64, // transmitted bytes
+ pub rx_bytes: AtomicU64, // received bytes
+ pub tx_bytes: AtomicU64, // transmitted bytes
pub keys: spin::Mutex<KeyWheel>, // key-wheel
pub ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
}
-pub struct Peer<C: Callbacks, T: Tun>(
- Arc<PeerInner<C, T>>,
+pub struct Peer<C: Callbacks, T: Tun, B: Bind>(
+ Arc<PeerInner<C, T, B>>,
);
-fn treebit_list<A, E, C: Callbacks, T: Tun>(
- peer: &Arc<PeerInner<C, T>>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>,
+fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
+ peer: &Arc<PeerInner<C, T, B>>,
+ table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> E>,
) -> Vec<E>
where
@@ -71,9 +71,9 @@ where
res
}
-fn treebit_remove<A: Address, C: Callbacks, T: Tun>(
- peer: &Peer<C, T>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>,
+fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
+ peer: &Peer<C, T, B>,
+ table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
) {
let mut m = table.write();
@@ -95,7 +95,7 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun>(
}
}
-impl<C: Callbacks, T: Tun> Drop for Peer<C, T> {
+impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
fn drop(&mut self) {
// mark peer as stopped
@@ -150,10 +150,10 @@ impl<C: Callbacks, T: Tun> Drop for Peer<C, T> {
}
}
-pub fn new_peer<C: Callbacks, T: Tun>(
- device: Arc<DeviceInner<C, T>>,
+pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
+ device: Arc<DeviceInner<C, T, B>>,
opaque: C::Opaque,
-) -> Peer<C, T> {
+) -> Peer<C, T, B> {
// allocate in-order queues
let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS);
let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS);
@@ -204,7 +204,7 @@ pub fn new_peer<C: Callbacks, T: Tun>(
Peer(peer)
}
-impl<C: Callbacks, T: Tun> PeerInner<C, T> {
+impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
pub fn confirm_key(&self, kp: Weak<KeyPair>) {
// upgrade key-pair to strong reference
@@ -214,8 +214,8 @@ impl<C: Callbacks, T: Tun> PeerInner<C, T> {
}
}
-impl<C: Callbacks, T: Tun> Peer<C, T> {
- fn new(inner: PeerInner<C, T>) -> Peer<C, T> {
+impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
+ fn new(inner: PeerInner<C, T, B>) -> Peer<C, T, B> {
Peer(Arc::new(inner))
}
diff --git a/src/router/types.rs b/src/router/types.rs
index f6a0311..82dcd09 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -20,10 +20,6 @@ pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
-pub trait TunCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
-
-pub trait BindCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
-
pub trait Endpoint: Send + Sync {}
pub trait Callbacks: Send + Sync + 'static {
@@ -33,16 +29,23 @@ pub trait Callbacks: Send + Sync + 'static {
type CallbackKey: KeyCallback<Self::Opaque>;
}
-pub struct CallbacksPhantom<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> {
+/* Concrete implementation of "Callbacks",
+ * used to hide the constituent type parameters.
+ *
+ * This type is never instantiated.
+ */
+pub struct PhantomCallbacks<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> {
_phantom_opaque: PhantomData<O>,
_phantom_recv: PhantomData<R>,
_phantom_send: PhantomData<S>,
- _phantom_key: PhantomData<K>
+ _phantom_key: PhantomData<K>,
}
-impl <O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks for CallbacksPhantom<O, R, S, K> {
+impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks
+ for PhantomCallbacks<O, R, S, K>
+{
type Opaque = O;
type CallbackRecv = R;
type CallbackSend = S;
type CallbackKey = K;
-} \ No newline at end of file
+}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 2b0b9ec..c4a9f18 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -17,7 +17,7 @@ use super::messages::TransportHeader;
use super::peer::PeerInner;
use super::types::Callbacks;
-use super::super::types::Tun;
+use super::super::types::{Tun, Bind};
#[derive(PartialEq, Debug)]
pub enum Operation {
@@ -41,7 +41,7 @@ pub struct JobInner {
pub type JobBuffer = Arc<spin::Mutex<JobInner>>;
pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
-pub type JobInbound<C, T> = (Weak<DecryptionState<C, T>>, JobBuffer);
+pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, JobBuffer);
pub type JobOutbound = JobBuffer;
/* Strategy for workers acquiring a new job:
@@ -89,10 +89,10 @@ fn wait_recv<T>(running: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvEr
return Err(TryRecvError::Disconnected);
}
-pub fn worker_inbound<C: Callbacks, T: Tun>(
- device: Arc<DeviceInner<C, T>>, // related device
- peer: Arc<PeerInner<C, T>>, // related peer
- recv: Receiver<JobInbound<C, T>>, // in order queue
+pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
+ device: Arc<DeviceInner<C, T, B>>, // related device
+ peer: Arc<PeerInner<C, T, B>>, // related peer
+ recv: Receiver<JobInbound<C, T, B>>, // in order queue
) {
loop {
match wait_recv(&peer.stopped, &recv) {
@@ -157,10 +157,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun>(
}
}
-pub fn worker_outbound<C: Callbacks, T: Tun>(
- device: Arc<DeviceInner<C, T>>, // related device
- peer: Arc<PeerInner<C, T>>, // related peer
- recv: Receiver<JobOutbound>, // in order queue
+pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
+ device: Arc<DeviceInner<C, T, B>>, // related device
+ peer: Arc<PeerInner<C, T, B>>, // related peer
+ recv: Receiver<JobOutbound>, // in order queue
) {
loop {
match wait_recv(&peer.stopped, &recv) {
@@ -205,8 +205,8 @@ pub fn worker_outbound<C: Callbacks, T: Tun>(
}
}
-pub fn worker_parallel<C: Callbacks, T: Tun>(
- device: Arc<DeviceInner<C, T>>,
+pub fn worker_parallel<C: Callbacks, T: Tun, B: Bind>(
+ device: Arc<DeviceInner<C, T, B>>,
local: Worker<JobParallel>, // local job queue (local to thread)
stealers: Vec<Stealer<JobParallel>>, // stealers (from other threads)
) {
diff --git a/src/types/udp.rs b/src/types/udp.rs
index 4bf0a9c..71d5a79 100644
--- a/src/types/udp.rs
+++ b/src/types/udp.rs
@@ -3,7 +3,7 @@ use std::error;
/* Often times an a file descriptor in an atomic might suffice.
*/
-pub trait Bind: Send + Sync {
+pub trait Bind: Send + Sync + 'static {
type Error: error::Error;
type Endpoint: Endpoint;