summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-31 20:25:16 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-31 20:25:16 +0200
commit46d76b80c6b1b3b1c549b770b1a5ba791b49da8a (patch)
tree0c880785943dd00e66cff6d8cd7560d42dc68c24 /src
parentExplicitly clear t0 in KDF macro (diff)
downloadwireguard-rs-46d76b80c6b1b3b1c549b770b1a5ba791b49da8a.tar.xz
wireguard-rs-46d76b80c6b1b3b1c549b770b1a5ba791b49da8a.zip
Reduce number of type parameters in router
Merge multiple related type parameters into trait, allowing for easier refactoring and better maintainability.
Diffstat (limited to 'src')
-rw-r--r--src/handshake/noise.rs2
-rw-r--r--src/main.rs40
-rw-r--r--src/router/device.rs60
-rw-r--r--src/router/messages.rs4
-rw-r--r--src/router/mod.rs2
-rw-r--r--src/router/peer.rs48
-rw-r--r--src/router/types.rs23
-rw-r--r--src/router/workers.rs30
8 files changed, 137 insertions, 72 deletions
diff --git a/src/handshake/noise.rs b/src/handshake/noise.rs
index 1e7c50d..9fc0eb4 100644
--- a/src/handshake/noise.rs
+++ b/src/handshake/noise.rs
@@ -1,5 +1,3 @@
-#![allow(dead_code)]
-
// DH
use x25519_dalek::PublicKey;
use x25519_dalek::StaticSecret;
diff --git a/src/main.rs b/src/main.rs
index 600e144..cfe93eb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,7 +13,44 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
-use types::{Bind, KeyPair};
+use types::{Bind, KeyPair, Tun};
+
+#[derive(Debug)]
+enum TunError {}
+
+impl Error for TunError {
+ fn description(&self) -> &str {
+ "Generic Tun Error"
+ }
+
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ None
+ }
+}
+
+impl fmt::Display for TunError {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "Not Possible")
+ }
+}
+
+struct TunTest {}
+
+impl Tun for TunTest {
+ type Error = TunError;
+
+ fn mtu(&self) -> usize {
+ 1500
+ }
+
+ fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> {
+ Ok(0)
+ }
+
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
struct Test {}
@@ -73,6 +110,7 @@ fn main() {
{
let router = router::Device::new(
4,
+ TunTest {},
|t: &PeerTimer, data: bool, sent: bool| t.a.reset(Duration::from_millis(1000)),
|t: &PeerTimer, data: bool, sent: bool| t.b.reset(Duration::from_millis(1000)),
|t: &PeerTimer| println!("new key requested"),
diff --git a/src/router/device.rs b/src/router/device.rs
index a7f0590..84f25c6 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -5,7 +5,7 @@ use std::sync::{Arc, Weak};
use std::thread;
use std::time::Instant;
-use crossbeam_deque::{Injector, Steal, Stealer, Worker};
+use crossbeam_deque::{Injector, Worker};
use spin;
use treebitmap::IpLookupTable;
@@ -15,24 +15,25 @@ use super::anti_replay::AntiReplay;
use super::peer;
use super::peer::{Peer, PeerInner};
-use super::types::{Callback, KeyCallback, Opaque};
+use super::types::{Callback, Callbacks, CallbacksPhantom, KeyCallback, Opaque};
use super::workers::{worker_parallel, JobParallel};
-pub struct DeviceInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> {
+pub struct DeviceInner<C: Callbacks, T: Tun> {
+ // IO & timer generics
+ pub tun: T,
+ pub call_recv: C::CallbackRecv,
+ pub call_send: C::CallbackSend,
+ pub call_need_key: C::CallbackKey,
+
// threading and workers
pub running: AtomicBool, // workers running?
pub parked: AtomicBool, // any workers parked?
pub injector: Injector<JobParallel>, // parallel enc/dec task injector
- // unboxed callbacks (used for timers and handshake requests)
- pub event_send: S, // called when authenticated message send
- pub event_recv: R, // called when authenticated message received
- pub event_need_key: K, // called when new key material is required
-
// routing
- pub recv: spin::RwLock<HashMap<u32, DecryptionState<T, S, R, K>>>, // receiver id -> decryption state
- pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<T, S, R, K>>>>, // ipv4 cryptkey routing
- pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<T, S, R, K>>>>, // ipv6 cryptkey routing
+ pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T>>>, // receiver id -> decryption state
+ pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T>>>>, // ipv4 cryptkey routing
+ pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T>>>>, // ipv6 cryptkey routing
}
pub struct EncryptionState {
@@ -43,21 +44,18 @@ pub struct EncryptionState {
// (birth + reject-after-time - keepalive-timeout - rekey-timeout)
}
-pub struct DecryptionState<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> {
+pub struct DecryptionState<C: Callbacks, T: Tun> {
pub key: [u8; 32],
pub keypair: Weak<KeyPair>,
pub confirmed: AtomicBool,
pub protector: spin::Mutex<AntiReplay>,
- pub peer: Weak<PeerInner<T, S, R, K>>,
+ pub peer: Weak<PeerInner<C, T>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
-pub struct Device<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- Arc<DeviceInner<T, S, R, K>>,
- Vec<thread::JoinHandle<()>>,
-);
+pub struct Device<C: Callbacks, T: Tun>(Arc<DeviceInner<C, T>>, Vec<thread::JoinHandle<()>>);
-impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Device<T, S, R, K> {
+impl<C: Callbacks, T: Tun> Drop for Device<C, T> {
fn drop(&mut self) {
// mark device as stopped
let device = &self.0;
@@ -75,18 +73,22 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Devi
}
}
-impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S, R, K> {
+impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun>
+ Device<CallbacksPhantom<O, R, S, K>, T>
+{
pub fn new(
num_workers: usize,
- event_recv: R,
- event_send: S,
- event_need_key: K,
- ) -> Device<T, S, R, K> {
+ tun: T,
+ call_recv: R,
+ call_send: S,
+ call_need_key: K,
+ ) -> Device<CallbacksPhantom<O, R, S, K>, T> {
// allocate shared device state
let inner = Arc::new(DeviceInner {
- event_recv,
- event_send,
- event_need_key,
+ tun,
+ call_recv,
+ call_send,
+ call_need_key,
parked: AtomicBool::new(false),
running: AtomicBool::new(true),
injector: Injector::new(),
@@ -95,7 +97,7 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S,
ipv6: spin::RwLock::new(IpLookupTable::new()),
});
- // alloacate work pool resources
+ // allocate work pool resources
let mut workers = Vec::with_capacity(num_workers);
let mut stealers = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
@@ -118,13 +120,15 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Device<T, S,
// return exported device handle
Device(inner, threads)
}
+}
+impl<C: Callbacks, T: Tun> Device<C, T> {
/// Adds a new peer to the device
///
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: T) -> Peer<T, S, R, K> {
+ pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T> {
peer::new_peer(self.0.clone(), opaque)
}
diff --git a/src/router/messages.rs b/src/router/messages.rs
index d09bbb3..bec24ac 100644
--- a/src/router/messages.rs
+++ b/src/router/messages.rs
@@ -7,5 +7,5 @@ use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
pub struct TransportHeader {
pub f_type: U32<LittleEndian>,
pub f_receiver: U32<LittleEndian>,
- pub f_counter: U64<LittleEndian>
-} \ No newline at end of file
+ pub f_counter: U64<LittleEndian>,
+}
diff --git a/src/router/mod.rs b/src/router/mod.rs
index 70ac868..c1ecf1c 100644
--- a/src/router/mod.rs
+++ b/src/router/mod.rs
@@ -1,9 +1,9 @@
mod anti_replay;
mod device;
+mod messages;
mod peer;
mod types;
mod workers;
-mod messages;
pub use device::Device;
pub use peer::Peer;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 234c353..647d24f 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -12,7 +12,7 @@ use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use super::super::constants::*;
-use super::super::types::KeyPair;
+use super::super::types::{KeyPair, Tun};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
@@ -20,7 +20,7 @@ use super::device::DeviceInner;
use super::device::EncryptionState;
use super::workers::{worker_inbound, worker_outbound, JobInbound, JobOutbound};
-use super::types::{Callback, KeyCallback, Opaque};
+use super::types::Callbacks;
const MAX_STAGED_PACKETS: usize = 128;
@@ -31,14 +31,14 @@ pub struct KeyWheel {
retired: Option<u32>, // retired id (previous id, after confirming key-pair)
}
-pub struct PeerInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> {
+pub struct PeerInner<C: Callbacks, T: Tun> {
pub stopped: AtomicBool,
- pub opaque: T,
- pub device: Arc<DeviceInner<T, S, R, K>>,
+ pub opaque: C::Opaque,
+ pub device: Arc<DeviceInner<C, T>>,
pub thread_outbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub thread_inbound: spin::Mutex<Option<thread::JoinHandle<()>>>,
pub queue_outbound: SyncSender<JobOutbound>,
- pub queue_inbound: SyncSender<JobInbound<T, S, R, K>>,
+ pub queue_inbound: SyncSender<JobInbound<C, T>>,
pub staged_packets: spin::Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake
pub rx_bytes: AtomicU64, // received bytes
pub tx_bytes: AtomicU64, // transmitted bytes
@@ -47,15 +47,15 @@ pub struct PeerInner<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T
pub endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
}
-pub struct Peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- Arc<PeerInner<T, S, R, K>>,
+pub struct Peer<C: Callbacks, T: Tun>(
+ Arc<PeerInner<C, T>>,
);
-fn treebit_list<A, O, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- peer: &Arc<PeerInner<T, S, R, K>>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<T, S, R, K>>>>,
- callback: Box<dyn Fn(A, u32) -> O>,
-) -> Vec<O>
+fn treebit_list<A, E, C: Callbacks, T: Tun>(
+ peer: &Arc<PeerInner<C, T>>,
+ table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>,
+ callback: Box<dyn Fn(A, u32) -> E>,
+) -> Vec<E>
where
A: Address,
{
@@ -71,9 +71,9 @@ where
res
}
-fn treebit_remove<A: Address, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- peer: &Peer<T, S, R, K>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<T, S, R, K>>>>,
+fn treebit_remove<A: Address, C: Callbacks, T: Tun>(
+ peer: &Peer<C, T>,
+ table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T>>>>,
) {
let mut m = table.write();
@@ -95,7 +95,7 @@ fn treebit_remove<A: Address, T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyC
}
}
-impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Peer<T, S, R, K> {
+impl<C: Callbacks, T: Tun> Drop for Peer<C, T> {
fn drop(&mut self) {
// mark peer as stopped
@@ -150,10 +150,10 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Drop for Peer
}
}
-pub fn new_peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- device: Arc<DeviceInner<T, S, R, K>>,
- opaque: T,
-) -> Peer<T, S, R, K> {
+pub fn new_peer<C: Callbacks, T: Tun>(
+ device: Arc<DeviceInner<C, T>>,
+ opaque: C::Opaque,
+) -> Peer<C, T> {
// allocate in-order queues
let (send_inbound, recv_inbound) = sync_channel(MAX_STAGED_PACKETS);
let (send_outbound, recv_outbound) = sync_channel(MAX_STAGED_PACKETS);
@@ -204,7 +204,7 @@ pub fn new_peer<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
Peer(peer)
}
-impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> PeerInner<T, S, R, K> {
+impl<C: Callbacks, T: Tun> PeerInner<C, T> {
pub fn confirm_key(&self, kp: Weak<KeyPair>) {
// upgrade key-pair to strong reference
@@ -214,8 +214,8 @@ impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> PeerInner<T,
}
}
-impl<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>> Peer<T, S, R, K> {
- fn new(inner: PeerInner<T, S, R, K>) -> Peer<T, S, R, K> {
+impl<C: Callbacks, T: Tun> Peer<C, T> {
+ fn new(inner: PeerInner<C, T>) -> Peer<C, T> {
Peer(Arc::new(inner))
}
diff --git a/src/router/types.rs b/src/router/types.rs
index 3d486bc..f6a0311 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -1,3 +1,5 @@
+use std::marker::PhantomData;
+
pub trait Opaque: Send + Sync + 'static {}
impl<T> Opaque for T where T: Send + Sync + 'static {}
@@ -23,3 +25,24 @@ pub trait TunCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
pub trait BindCallback<T>: Fn(&T, bool, bool) -> () + Sync + Send + 'static {}
pub trait Endpoint: Send + Sync {}
+
+pub trait Callbacks: Send + Sync + 'static {
+ type Opaque: Opaque;
+ type CallbackRecv: Callback<Self::Opaque>;
+ type CallbackSend: Callback<Self::Opaque>;
+ type CallbackKey: KeyCallback<Self::Opaque>;
+}
+
+pub struct CallbacksPhantom<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> {
+ _phantom_opaque: PhantomData<O>,
+ _phantom_recv: PhantomData<R>,
+ _phantom_send: PhantomData<S>,
+ _phantom_key: PhantomData<K>
+}
+
+impl <O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks for CallbacksPhantom<O, R, S, K> {
+ type Opaque = O;
+ type CallbackRecv = R;
+ type CallbackSend = S;
+ type CallbackKey = K;
+} \ No newline at end of file
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 4861847..2b0b9ec 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -15,7 +15,9 @@ use super::device::DecryptionState;
use super::device::DeviceInner;
use super::messages::TransportHeader;
use super::peer::PeerInner;
-use super::types::{Callback, KeyCallback, Opaque};
+use super::types::Callbacks;
+
+use super::super::types::Tun;
#[derive(PartialEq, Debug)]
pub enum Operation {
@@ -39,7 +41,7 @@ pub struct JobInner {
pub type JobBuffer = Arc<spin::Mutex<JobInner>>;
pub type JobParallel = (Arc<thread::JoinHandle<()>>, JobBuffer);
-pub type JobInbound<T, S, R, K> = (Weak<DecryptionState<T, S, R, K>>, JobBuffer);
+pub type JobInbound<C, T> = (Weak<DecryptionState<C, T>>, JobBuffer);
pub type JobOutbound = JobBuffer;
/* Strategy for workers acquiring a new job:
@@ -87,10 +89,10 @@ fn wait_recv<T>(running: &AtomicBool, recv: &Receiver<T>) -> Result<T, TryRecvEr
return Err(TryRecvError::Disconnected);
}
-pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- device: Arc<DeviceInner<T, S, R, K>>, // related device
- peer: Arc<PeerInner<T, S, R, K>>, // related peer
- recv: Receiver<JobInbound<T, S, R, K>>, // in order queue
+pub fn worker_inbound<C: Callbacks, T: Tun>(
+ device: Arc<DeviceInner<C, T>>, // related device
+ peer: Arc<PeerInner<C, T>>, // related peer
+ recv: Receiver<JobInbound<C, T>>, // in order queue
) {
loop {
match wait_recv(&peer.stopped, &recv) {
@@ -134,7 +136,7 @@ pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<
packet.len() >= CHACHA20_POLY1305.nonce_len(),
"this should be checked earlier in the pipeline"
);
- (device.event_recv)(
+ (device.call_recv)(
&peer.opaque,
packet.len() > CHACHA20_POLY1305.nonce_len(),
true,
@@ -155,10 +157,10 @@ pub fn worker_inbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<
}
}
-pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- device: Arc<DeviceInner<T, S, R, K>>, // related device
- peer: Arc<PeerInner<T, S, R, K>>, // related peer
- recv: Receiver<JobOutbound>, // in order queue
+pub fn worker_outbound<C: Callbacks, T: Tun>(
+ device: Arc<DeviceInner<C, T>>, // related device
+ peer: Arc<PeerInner<C, T>>, // related peer
+ recv: Receiver<JobOutbound>, // in order queue
) {
loop {
match wait_recv(&peer.stopped, &recv) {
@@ -180,7 +182,7 @@ pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback
let xmit = false;
// trigger callback
- (device.event_send)(
+ (device.call_send)(
&peer.opaque,
buf.msg.len()
> CHACHA20_POLY1305.nonce_len()
@@ -203,8 +205,8 @@ pub fn worker_outbound<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback
}
}
-pub fn worker_parallel<T: Opaque, S: Callback<T>, R: Callback<T>, K: KeyCallback<T>>(
- device: Arc<DeviceInner<T, S, R, K>>,
+pub fn worker_parallel<C: Callbacks, T: Tun>(
+ device: Arc<DeviceInner<C, T>>,
local: Worker<JobParallel>, // local job queue (local to thread)
stealers: Vec<Stealer<JobParallel>>, // stealers (from other threads)
) {