aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/config.rs183
-rw-r--r--src/handshake/device.rs9
-rw-r--r--src/main.rs7
-rw-r--r--src/router/device.rs55
-rw-r--r--src/router/peer.rs65
-rw-r--r--src/router/tests.rs46
-rw-r--r--src/router/types.rs2
-rw-r--r--src/router/workers.rs45
-rw-r--r--src/timers.rs8
-rw-r--r--src/types/bind.rs79
-rw-r--r--src/types/dummy.rs172
-rw-r--r--src/types/endpoint.rs2
-rw-r--r--src/types/mod.rs8
-rw-r--r--src/types/tun.rs43
-rw-r--r--src/types/udp.rs29
-rw-r--r--src/wireguard.rs200
16 files changed, 610 insertions, 343 deletions
diff --git a/src/config.rs b/src/config.rs
new file mode 100644
index 0000000..60faf43
--- /dev/null
+++ b/src/config.rs
@@ -0,0 +1,183 @@
+use std::error::Error;
+use std::net::{IpAddr, SocketAddr};
+use x25519_dalek::{PublicKey, StaticSecret};
+
+use crate::wireguard::Wireguard;
+use crate::types::{Bind, Endpoint, Tun};
+
+///
+/// The goal of the configuration interface is, among others,
+/// to hide the IO implementations (over which the WG device is generic),
+/// from the configuration and UAPI code.
+
+/// Describes a snapshot of the state of a peer
+pub struct PeerState {
+ rx_bytes: u64,
+ tx_bytes: u64,
+ last_handshake_time_sec: u64,
+ last_handshake_time_nsec: u64,
+ public_key: PublicKey,
+ allowed_ips: Vec<(IpAddr, u32)>,
+}
+
+pub enum ConfigError {
+ NoSuchPeer
+}
+
+impl ConfigError {
+
+ fn errno(&self) -> i32 {
+ match self {
+ NoSuchPeer => 1,
+ }
+ }
+}
+
+/// Exposed configuration interface
+pub trait Configuration {
+ /// Updates the private key of the device
+ ///
+ /// # Arguments
+ ///
+ /// - `sk`: The new private key (or None, if the private key should be cleared)
+ fn set_private_key(&self, sk: Option<StaticSecret>);
+
+ /// Returns the private key of the device
+ ///
+ /// # Returns
+ ///
+ /// The private if set, otherwise None.
+ fn get_private_key(&self) -> Option<StaticSecret>;
+
+ /// Returns the protocol version of the device
+ ///
+ /// # Returns
+ ///
+ /// An integer indicating the protocol version
+ fn get_protocol_version(&self) -> usize;
+
+ fn set_listen_port(&self, port: u16) -> Option<ConfigError>;
+
+ /// Set the firewall mark (or similar, depending on platform)
+ ///
+ /// # Arguments
+ ///
+ /// - `mark`: The fwmark value
+ ///
+ /// # Returns
+ ///
+ /// An error if this operation is not supported by the underlying
+ /// "bind" implementation.
+ fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError>;
+
+ /// Removes all peers from the device
+ fn replace_peers(&self);
+
+ /// Remove the peer from the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer to remove
+ ///
+ /// # Returns
+ ///
+ /// If the peer does not exists this operation is a noop
+ fn remove_peer(&self, peer: PublicKey);
+
+ /// Adds a new peer to the device
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer to add
+ ///
+ /// # Returns
+ ///
+ /// A bool indicating if the peer was added.
+ ///
+ /// If the peer already exists this operation is a noop
+ fn add_peer(&self, peer: PublicKey) -> bool;
+
+ /// Update the psk of a peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer
+ /// - `psk`: The new psk or None if the psk should be unset
+ ///
+ /// # Returns
+ ///
+ /// An error if no such peer exists
+ fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError>;
+
+ /// Update the endpoint of the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ /// - `psk`
+ fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError>;
+
+ /// Update the endpoint of the
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ /// - `psk`
+ fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>;
+
+ /// Remove all allowed IPs from the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer': The public key of the peer
+ ///
+ /// # Returns
+ ///
+ /// An error if no such peer exists
+ fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError>;
+
+ /// Add a new allowed subnet to the peer
+ ///
+ /// # Arguments
+ ///
+ /// - `peer`: The public key of the peer
+ /// - `ip`: Subnet mask
+ /// - `masklen`:
+ ///
+ /// # Returns
+ ///
+ /// An error if the peer does not exist
+ ///
+ /// # Note:
+ ///
+ /// The API must itself sanitize the (ip, masklen) set:
+ /// The ip should be masked to remove any set bits right of the first "masklen" bits.
+ fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError>;
+
+ /// Returns the state of all peers
+ ///
+ /// # Returns
+ ///
+ /// A list of structures describing the state of each peer
+ fn get_peers(&self) -> Vec<PeerState>;
+}
+
+impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> {
+
+ fn set_private_key(&self, sk : Option<StaticSecret>) {
+ self.set_key(sk)
+ }
+
+ fn get_private_key(&self) -> Option<StaticSecret> {
+ self.get_sk()
+ }
+
+ fn get_protocol_version(&self) -> usize {
+ 1
+ }
+
+ fn set_listen_port(&self, port : u16) -> Option<ConfigError> {
+
+ }
+
+} \ No newline at end of file
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index 6178831..6a55f6e 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -76,6 +76,15 @@ impl Device {
}
}
+ /// Return the secret key of the device
+ ///
+ /// # Returns
+ ///
+ /// A secret key (x25519 scalar)
+ pub fn get_sk(&self) -> StaticSecret {
+ StaticSecret::from(self.sk.to_bytes())
+ }
+
/// Add a new public key to the state machine
/// To remove public keys, you must create a new machine instance
///
diff --git a/src/main.rs b/src/main.rs
index 6133884..7a31119 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,6 +5,7 @@ extern crate jemallocator;
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
+// mod config;
mod constants;
mod handshake;
mod router;
@@ -14,7 +15,8 @@ mod wireguard;
#[cfg(test)]
mod tests {
- use crate::types::{dummy, Bind};
+ use crate::types::tun::Tun;
+ use crate::types::{bind, dummy, tun};
use crate::wireguard::Wireguard;
use std::thread;
@@ -27,7 +29,8 @@ mod tests {
#[test]
fn test_pure_wireguard() {
init();
- let wg = Wireguard::new(dummy::TunTest::new(), dummy::VoidBind::new());
+ let (reader, writer, mtu) = dummy::TunTest::create("name").unwrap();
+ let wg: Wireguard<dummy::TunTest, dummy::PairBind> = Wireguard::new(reader, writer, mtu);
thread::sleep(Duration::from_millis(500));
}
}
diff --git a/src/router/device.rs b/src/router/device.rs
index d126959..989c2c2 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -17,21 +17,23 @@ use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
-use super::types::{Callbacks, Opaque, RouterError};
+use super::types::{Callbacks, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
-use super::super::types::{Bind, KeyPair, Tun};
+use super::super::types::{KeyPair, Endpoint, bind, tun};
-pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
- // IO & timer callbacks
- pub tun: T,
- pub bind: B,
+pub struct DeviceInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ // inbound writer (TUN)
+ pub inbound: T,
+
+ // outbound writer (Bind)
+ pub outbound: RwLock<Option<B>>,
// routing
- pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
- pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
- pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
+ pub recv: RwLock<HashMap<u32, Arc<DecryptionState<E, C, T, B>>>>, // receiver id -> decryption state
+ pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv4 cryptkey routing
+ pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<E, C, T, B>>>>, // ipv6 cryptkey routing
// work queues
pub queue_next: AtomicUsize, // next round-robin index
@@ -45,20 +47,20 @@ pub struct EncryptionState {
pub death: Instant, // (birth + reject-after-time - keepalive-timeout - rekey-timeout)
}
-pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
+pub struct DecryptionState<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
pub protector: Mutex<AntiReplay>,
- pub peer: Arc<PeerInner<C, T, B>>,
+ pub peer: Arc<PeerInner<E, C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
-pub struct Device<C: Callbacks, T: Tun, B: Bind> {
- state: Arc<DeviceInner<C, T, B>>, // reference to device state
+pub struct Device<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<DeviceInner<E, C, T, B>>, // reference to device state
handles: Vec<thread::JoinHandle<()>>, // join handles for workers
}
-impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Device<E, C, T, B> {
fn drop(&mut self) {
debug!("router: dropping device");
@@ -83,10 +85,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
}
#[inline(always)]
-fn get_route<C: Callbacks, T: Tun, B: Bind>(
- device: &Arc<DeviceInner<C, T, B>>,
+fn get_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
packet: &[u8],
-) -> Option<Arc<PeerInner<C, T, B>>> {
+) -> Option<Arc<PeerInner<E, C, T, B>>> {
// ensure version access within bounds
if packet.len() < 1 {
return None;
@@ -122,12 +124,12 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
}
-impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
- pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Device<E, C, T, B> {
+ pub fn new(num_workers: usize, tun: T) -> Device<E, C, T, B> {
// allocate shared device state
let mut inner = DeviceInner {
- tun,
- bind,
+ inbound: tun,
+ outbound: RwLock::new(None),
queues: Mutex::new(Vec::with_capacity(num_workers)),
queue_next: AtomicUsize::new(0),
recv: RwLock::new(HashMap::new()),
@@ -159,7 +161,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// # Returns
///
/// A atomic ref. counted peer (with liftime matching the device)
- pub fn new_peer(&self, opaque: C::Opaque) -> Peer<C, T, B> {
+ pub fn new_peer(&self, opaque: C::Opaque) -> Peer<E, C, T, B> {
new_peer(self.state.clone(), opaque)
}
@@ -199,7 +201,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// # Returns
///
///
- pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
+ pub fn recv(&self, src: E, msg: Vec<u8>) -> Result<(), RouterError> {
// parse / cast
let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
Some(v) => v,
@@ -231,4 +233,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
Ok(())
}
+
+ /// Set outbound writer
+ ///
+ ///
+ pub fn set_outbound_writer(&self, new : B) {
+ *self.state.outbound.write() = Some(new);
+ }
}
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 86723bb..189904c 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -14,7 +14,7 @@ use treebitmap::IpLookupTable;
use zerocopy::LayoutVerified;
use super::super::constants::*;
-use super::super::types::{Bind, Endpoint, KeyPair, Tun};
+use super::super::types::{Endpoint, KeyPair, bind, tun};
use super::anti_replay::AntiReplay;
use super::device::DecryptionState;
@@ -39,28 +39,28 @@ pub struct KeyWheel {
retired: Vec<u32>, // retired ids
}
-pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
- pub device: Arc<DeviceInner<C, T, B>>,
+pub struct PeerInner<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ pub device: Arc<DeviceInner<E, C, T, B>>,
pub opaque: C::Opaque,
pub outbound: Mutex<SyncSender<JobOutbound>>,
- pub inbound: Mutex<SyncSender<JobInbound<C, T, B>>>,
+ pub inbound: Mutex<SyncSender<JobInbound<E, C, T, B>>>,
pub staged_packets: Mutex<ArrayDeque<[Vec<u8>; MAX_STAGED_PACKETS], Wrapping>>,
pub keys: Mutex<KeyWheel>,
pub ekey: Mutex<Option<EncryptionState>>,
- pub endpoint: Mutex<Option<B::Endpoint>>,
+ pub endpoint: Mutex<Option<E>>,
}
-pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
- state: Arc<PeerInner<C, T, B>>,
+pub struct Peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> {
+ state: Arc<PeerInner<E, C, T, B>>,
thread_outbound: Option<thread::JoinHandle<()>>,
thread_inbound: Option<thread::JoinHandle<()>>,
}
-fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
- peer: &Arc<PeerInner<C, T, B>>,
- table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
- callback: Box<dyn Fn(A, u32) -> E>,
-) -> Vec<E>
+fn treebit_list<A, R, E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Arc<PeerInner<E, C, T, B>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
+ callback: Box<dyn Fn(A, u32) -> R>,
+) -> Vec<R>
where
A: Address,
{
@@ -74,9 +74,9 @@ where
res
}
-fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
- peer: &Peer<C, T, B>,
- table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
+fn treebit_remove<E : Endpoint, A: Address, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ peer: &Peer<E, C, T, B>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<E, C, T, B>>>>,
) {
let mut m = table.write();
@@ -107,8 +107,8 @@ impl EncryptionState {
}
}
-impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
- fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> DecryptionState<E, C, T, B> {
+ fn new(peer: &Arc<PeerInner<E, C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<E, C, T, B> {
DecryptionState {
confirmed: AtomicBool::new(keypair.initiator),
keypair: keypair.clone(),
@@ -119,7 +119,7 @@ impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
}
}
-impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Drop for Peer<E, C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
@@ -167,10 +167,10 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
}
}
-pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
- device: Arc<DeviceInner<C, T, B>>,
+pub fn new_peer<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>,
opaque: C::Opaque,
-) -> Peer<C, T, B> {
+) -> Peer<E, C, T, B> {
let (out_tx, out_rx) = sync_channel(128);
let (in_tx, in_rx) = sync_channel(128);
@@ -215,7 +215,7 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
}
}
-impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> PeerInner<E, C, T, B> {
fn send_staged(&self) -> bool {
debug!("peer.send_staged");
let mut sent = false;
@@ -286,8 +286,8 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
pub fn recv_job(
&self,
- src: B::Endpoint,
- dec: Arc<DecryptionState<C, T, B>>,
+ src: E,
+ dec: Arc<DecryptionState<E, C, T, B>>,
mut msg: Vec<u8>,
) -> Option<JobParallel> {
let (tx, rx) = oneshot();
@@ -370,7 +370,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
}
}
-impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
+impl<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>> Peer<E, C, T, B> {
/// Set the endpoint of the peer
///
/// # Arguments
@@ -381,9 +381,9 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
///
/// This API still permits support for the "sticky socket" behavior,
/// as sockets should be "unsticked" when manually updating the endpoint
- pub fn set_endpoint(&self, address: SocketAddr) {
+ pub fn set_endpoint(&self, endpoint: E) {
debug!("peer.set_endpoint");
- *self.state.endpoint.lock() = Some(B::Endpoint::from_address(address));
+ *self.state.endpoint.lock() = Some(endpoint);
}
/// Returns the current endpoint of the peer (for configuration)
@@ -591,11 +591,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
debug!("peer.send");
let inner = &self.state;
match inner.endpoint.lock().as_ref() {
- Some(endpoint) => inner
- .device
- .bind
- .send(msg, endpoint)
- .map_err(|_| RouterError::SendError),
+ Some(endpoint) => inner.device
+ .outbound
+ .read()
+ .as_ref()
+ .ok_or(RouterError::SendError)
+ .and_then(|w| w.write(msg, endpoint).map_err(|_| RouterError::SendError) ),
None => Err(RouterError::NoEndpoint),
}
}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index f42e1f6..3b6b941 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -1,18 +1,18 @@
-use std::error::Error;
-use std::fmt;
-use std::net::{IpAddr, SocketAddr};
+use std::net::IpAddr;
use std::sync::atomic::Ordering;
-use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
-use std::time::{Duration, Instant};
+use std::time::Duration;
use num_cpus;
use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
-use super::super::types::{dummy, Bind, Endpoint, Key, KeyPair, Tun};
+use super::super::types::bind::*;
+use super::super::types::tun::*;
+use super::super::types::*;
+
use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
@@ -145,8 +145,9 @@ mod tests {
}
// create device
- let router: Device<BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
- Device::new(num_cpus::get(), dummy::TunTest {}, dummy::VoidBind::new());
+ let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
+ let router: Device<_, BencherCallbacks, dummy::TunTest, dummy::VoidBind> =
+ Device::new(num_cpus::get(), tun_writer);
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
@@ -174,8 +175,9 @@ mod tests {
init();
// create device
- let router: Device<TestCallbacks, _, _> =
- Device::new(1, dummy::TunTest::new(), dummy::VoidBind::new());
+ let (_reader, tun_writer, _mtu) = dummy::TunTest::create("name").unwrap();
+ let router: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer);
+ router.set_outbound_writer(dummy::VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
@@ -315,12 +317,18 @@ mod tests {
];
for (stage, p1, p2) in tests.iter() {
- // create matching devices
- let (bind1, bind2) = dummy::PairBind::pair();
- let router1: Device<TestCallbacks, _, _> =
- Device::new(1, dummy::TunTest::new(), bind1.clone());
- let router2: Device<TestCallbacks, _, _> =
- Device::new(1, dummy::TunTest::new(), bind2.clone());
+ let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) =
+ dummy::PairBind::pair();
+
+ // create matching device
+ let (tun_writer1, _, _) = dummy::TunTest::create("tun1").unwrap();
+ let (tun_writer2, _, _) = dummy::TunTest::create("tun1").unwrap();
+
+ let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1);
+ router1.set_outbound_writer(bind_writer1);
+
+ let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2);
+ router2.set_outbound_writer(bind_writer2);
// prepare opaque values for tracing callbacks
@@ -339,7 +347,7 @@ mod tests {
let peer2 = router2.new_peer(opaq2.clone());
let mask: IpAddr = mask.parse().unwrap();
peer2.add_subnet(mask, *len);
- peer2.set_endpoint("127.0.0.1:8080".parse().unwrap());
+ peer2.set_endpoint(dummy::UnitEndpoint::new());
if *stage {
// stage a packet which can be used for confirmation (in place of a keepalive)
@@ -372,7 +380,7 @@ mod tests {
// read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
- let (len, from) = bind1.recv(&mut buf).unwrap();
+ let (len, from) = bind_reader1.read(&mut buf).unwrap();
buf.truncate(len);
router1.recv(from, buf).unwrap();
@@ -411,7 +419,7 @@ mod tests {
// receive ("across the internet") on the other end
let mut buf = vec![0u8; 2048];
- let (len, from) = bind2.recv(&mut buf).unwrap();
+ let (len, from) = bind_reader2.read(&mut buf).unwrap();
buf.truncate(len);
router2.recv(from, buf).unwrap();
diff --git a/src/router/types.rs b/src/router/types.rs
index b7c3ae0..4a72c27 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -1,6 +1,8 @@
use std::error::Error;
use std::fmt;
+use super::super::types::Endpoint;
+
pub trait Opaque: Send + Sync + 'static {}
impl<T> Opaque for T where T: Send + Sync + 'static {}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 6710816..2e89bb0 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::PeerInner;
use super::types::Callbacks;
-use super::super::types::{Bind, Tun};
+use super::super::types::{Endpoint, tun, bind};
use super::ip::*;
const SIZE_TAG: usize = 16;
@@ -38,18 +38,18 @@ pub struct JobBuffer {
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
#[allow(type_alias_bounds)]
-pub type JobInbound<C, T, B: Bind> = (
- Arc<DecryptionState<C, T, B>>,
- B::Endpoint,
+pub type JobInbound<E, C, T, B: bind::Writer<E>> = (
+ Arc<DecryptionState<E, C, T, B>>,
+ E,
oneshot::Receiver<JobBuffer>,
);
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
#[inline(always)]
-fn check_route<C: Callbacks, T: Tun, B: Bind>(
- device: &Arc<DeviceInner<C, T, B>>,
- peer: &Arc<PeerInner<C, T, B>>,
+fn check_route<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: &Arc<DeviceInner<E, C, T, B>>,
+ peer: &Arc<PeerInner<E, C, T, B>>,
packet: &[u8],
) -> Option<usize> {
match packet[0] >> 4 {
@@ -93,10 +93,10 @@ fn check_route<C: Callbacks, T: Tun, B: Bind>(
}
}
-pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
- device: Arc<DeviceInner<C, T, B>>, // related device
- peer: Arc<PeerInner<C, T, B>>, // related peer
- receiver: Receiver<JobInbound<C, T, B>>,
+pub fn worker_inbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
+ receiver: Receiver<JobInbound<E, C, T, B>>,
) {
loop {
// fetch job
@@ -153,7 +153,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
debug_assert!(inner_len <= length, "should be validated");
if inner_len <= length {
- sent = match device.tun.write(&packet[..inner_len]) {
+ sent = match device.inbound.write(&packet[..inner_len]) {
Err(e) => {
debug!("failed to write inbound packet to TUN: {:?}", e);
false
@@ -176,9 +176,9 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
}
-pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
- device: Arc<DeviceInner<C, T, B>>, // related device
- peer: Arc<PeerInner<C, T, B>>, // related peer
+pub fn worker_outbound<E : Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
+ device: Arc<DeviceInner<E, C, T, B>>, // related device
+ peer: Arc<PeerInner<E, C, T, B>>, // related peer
receiver: Receiver<JobOutbound>,
) {
loop {
@@ -198,12 +198,17 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
if buf.okay {
// write to UDP bind
let xmit = if let Some(dst) = peer.endpoint.lock().as_ref() {
- match device.bind.send(&buf.msg[..], dst) {
- Err(e) => {
- debug!("failed to send outbound packet: {:?}", e);
- false
+ let send : &Option<B> = &*device.outbound.read();
+ if let Some(writer) = send.as_ref() {
+ match writer.write(&buf.msg[..], dst) {
+ Err(e) => {
+ debug!("failed to send outbound packet: {:?}", e);
+ false
+ }
+ Ok(_) => true,
}
- Ok(_) => true,
+ } else {
+ false
}
} else {
false
diff --git a/src/timers.rs b/src/timers.rs
index 303fd35..67ece06 100644
--- a/src/timers.rs
+++ b/src/timers.rs
@@ -7,7 +7,7 @@ use hjul::{Runner, Timer};
use crate::constants::*;
use crate::router::Callbacks;
-use crate::types::{Bind, Tun};
+use crate::types::{tun, bind};
use crate::wireguard::{Peer, PeerInner};
pub struct Timers {
@@ -23,8 +23,8 @@ pub struct Timers {
impl Timers {
pub fn new<T, B>(runner: &Runner, peer: Peer<T, B>) -> Timers
where
- T: Tun,
- B: Bind,
+ T: tun::Tun,
+ B: bind::Bind,
{
// create a timer instance for the provided peer
Timers {
@@ -103,7 +103,7 @@ impl Timers {
pub struct Events<T, B>(PhantomData<(T, B)>);
-impl<T: Tun, B: Bind> Callbacks for Events<T, B> {
+impl<T: tun::Tun, B: bind::Bind> Callbacks for Events<T, B> {
type Opaque = Arc<PeerInner<B>>;
fn send(peer: &Self::Opaque, size: usize, data: bool, sent: bool) {
diff --git a/src/types/bind.rs b/src/types/bind.rs
index 62adbbb..fcc38c8 100644
--- a/src/types/bind.rs
+++ b/src/types/bind.rs
@@ -1,73 +1,28 @@
use super::Endpoint;
-use std::error;
+use std::error::Error;
-/// Traits representing the "internet facing" end of the VPN.
-///
-/// In practice this is a UDP socket (but the router interface is agnostic).
-/// Often these traits will be implemented on the same type.
+pub trait Reader<E: Endpoint>: Send + Sync {
+ type Error: Error;
-/// Bind interface provided to the router code
-pub trait RouterBind: Send + Sync {
- type Error: error::Error;
- type Endpoint: Endpoint;
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>;
+}
- /// Receive a buffer on the bind
- ///
- /// # Arguments
- ///
- /// - `buf`, buffer for storing the packet. If the buffer is too short, the packet should just be truncated.
- ///
- /// # Note
- ///
- /// The size of the buffer is derieved from the MTU of the Tun device.
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
+pub trait Writer<E: Endpoint>: Send + Sync + Clone + 'static {
+ type Error: Error;
- /// Send a buffer to the endpoint
- ///
- /// # Arguments
- ///
- /// - `buf`, packet src buffer (in practice the body of a UDP datagram)
- /// - `dst`, destination endpoint (in practice, src: (ip, port) + dst: (ip, port) for sticky sockets)
- ///
- /// # Returns
- ///
- /// The unit type or an error if transmission failed
- fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
+ fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>;
}
-/// Bind interface provided for configuration (setting / getting the port)
-pub trait ConfigBind {
- type Error: error::Error;
-
- /// Return a new (unbound) instance of a configuration bind
- fn new() -> Self;
+pub trait Bind: Send + Sync + 'static {
+ type Error: Error;
+ type Endpoint: Endpoint;
- /// Updates the port of the bind
- ///
- /// # Arguments
- ///
- /// - `port`, the new port to bind to. 0 means any available port.
- ///
- /// # Returns
- ///
- /// The unit type or an error, if binding fails
- fn set_port(&self, port: u16) -> Result<(), Self::Error>;
+ /* Until Rust gets type equality constraints these have to be generic */
+ type Writer: Writer<Self::Endpoint>;
+ type Reader: Reader<Self::Endpoint>;
- /// Returns the current port of the bind
- fn get_port(&self) -> Option<u16>;
+ /* Used to close the reader/writer when binding to a new port */
+ type Closer;
- /// Set the mark (e.g. on Linus this is the fwmark) on the bind
- ///
- /// # Arguments
- ///
- /// - `mark`, the mark to set
- ///
- /// # Note
- ///
- /// The mark should be retained accross calls to `set_port`.
- ///
- /// # Returns
- ///
- /// The unit type or an error, if the operation fails due to permission errors
- fn set_mark(&self, mark: u16) -> Result<(), Self::Error>;
+ fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error>;
}
diff --git a/src/types/dummy.rs b/src/types/dummy.rs
index e15abb0..40a3bdd 100644
--- a/src/types/dummy.rs
+++ b/src/types/dummy.rs
@@ -5,8 +5,9 @@ use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Instant;
+use std::marker;
-use super::{Bind, Endpoint, Key, KeyPair, Tun};
+use super::*;
/* This submodule provides pure/dummy implementations of the IO interfaces
* for use in unit tests thoughout the project.
@@ -72,104 +73,103 @@ impl Endpoint for UnitEndpoint {
}
}
+impl UnitEndpoint {
+ pub fn new() -> UnitEndpoint {
+ UnitEndpoint{}
+ }
+}
+
+/* */
+
#[derive(Clone, Copy)]
pub struct TunTest {}
-impl Tun for TunTest {
+impl tun::Reader for TunTest {
type Error = TunError;
+ fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
+ Ok(0)
+ }
+}
+
+impl tun::MTU for TunTest {
fn mtu(&self) -> usize {
1500
}
+}
- fn read(&self, _buf: &mut [u8], _offset: usize) -> Result<usize, Self::Error> {
- Ok(0)
- }
+impl tun::Writer for TunTest {
+ type Error = TunError;
fn write(&self, _src: &[u8]) -> Result<(), Self::Error> {
Ok(())
}
}
+impl tun::Tun for TunTest {
+ type Writer = TunTest;
+ type Reader = TunTest;
+ type MTU = TunTest;
+ type Error = TunError;
+}
+
impl TunTest {
- pub fn new() -> TunTest {
- TunTest {}
+ pub fn create(_name: &str) -> Result<(TunTest, TunTest, TunTest), TunError> {
+ Ok((TunTest {},TunTest {}, TunTest{}))
}
}
-/* Bind implemenentations */
+/* Void Bind */
#[derive(Clone, Copy)]
pub struct VoidBind {}
-impl Bind for VoidBind {
+impl bind::Reader<UnitEndpoint> for VoidBind {
type Error = BindError;
- type Endpoint = UnitEndpoint;
- fn new() -> VoidBind {
- VoidBind {}
- }
-
- fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
- Ok(())
- }
-
- fn get_port(&self) -> Option<u16> {
- None
- }
-
- fn recv(&self, _buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+ fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
Ok((0, UnitEndpoint {}))
}
-
- fn send(&self, _buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
- Ok(())
- }
}
-#[derive(Clone)]
-pub struct PairBind {
- send: Arc<Mutex<SyncSender<Vec<u8>>>>,
- recv: Arc<Mutex<Receiver<Vec<u8>>>>,
-}
+impl bind::Writer<UnitEndpoint> for VoidBind {
+ type Error = BindError;
-impl PairBind {
- pub fn pair() -> (PairBind, PairBind) {
- let (tx1, rx1) = sync_channel(128);
- let (tx2, rx2) = sync_channel(128);
- (
- PairBind {
- send: Arc::new(Mutex::new(tx1)),
- recv: Arc::new(Mutex::new(rx2)),
- },
- PairBind {
- send: Arc::new(Mutex::new(tx2)),
- recv: Arc::new(Mutex::new(rx1)),
- },
- )
+ fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
+ Ok(())
}
}
-impl Bind for PairBind {
+impl bind::Bind for VoidBind {
type Error = BindError;
type Endpoint = UnitEndpoint;
- fn new() -> PairBind {
- PairBind {
- send: Arc::new(Mutex::new(sync_channel(0).0)),
- recv: Arc::new(Mutex::new(sync_channel(0).1)),
- }
- }
+ type Reader = VoidBind;
+ type Writer = VoidBind;
+ type Closer = ();
- fn set_port(&self, _port: u16) -> Result<(), Self::Error> {
- Ok(())
+ fn bind(_ : u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
+ Ok((VoidBind{}, VoidBind{}, (), 2600))
}
+}
- fn get_port(&self) -> Option<u16> {
- None
+impl VoidBind {
+ pub fn new() -> VoidBind {
+ VoidBind{}
}
+}
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error> {
+/* Pair Bind */
+
+#[derive(Clone)]
+pub struct PairReader<E> {
+ recv: Arc<Mutex<Receiver<Vec<u8>>>>,
+ _marker: marker::PhantomData<E>,
+}
+
+impl bind::Reader<UnitEndpoint> for PairReader<UnitEndpoint> {
+ type Error = BindError;
+ fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> {
let vec = self
.recv
.lock()
@@ -180,8 +180,11 @@ impl Bind for PairBind {
buf[..len].copy_from_slice(&vec[..]);
Ok((vec.len(), UnitEndpoint {}))
}
+}
- fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
+impl bind::Writer<UnitEndpoint> for PairWriter<UnitEndpoint> {
+ type Error = BindError;
+ fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
@@ -190,6 +193,57 @@ impl Bind for PairBind {
}
}
+#[derive(Clone)]
+pub struct PairWriter<E> {
+ send: Arc<Mutex<SyncSender<Vec<u8>>>>,
+ _marker: marker::PhantomData<E>,
+}
+
+#[derive(Clone)]
+pub struct PairBind {}
+
+impl PairBind {
+ pub fn pair<E>() -> ((PairReader<E>, PairWriter<E>), (PairReader<E>, PairWriter<E>)) {
+ let (tx1, rx1) = sync_channel(128);
+ let (tx2, rx2) = sync_channel(128);
+ (
+ (
+ PairReader{
+
+ recv: Arc::new(Mutex::new(rx1)),
+ _marker: marker::PhantomData
+ },
+ PairWriter{
+ send: Arc::new(Mutex::new(tx2)),
+ _marker: marker::PhantomData
+ }
+ ),
+ (
+ PairReader{
+ recv: Arc::new(Mutex::new(rx2)),
+ _marker: marker::PhantomData
+ },
+ PairWriter{
+ send: Arc::new(Mutex::new(tx1)),
+ _marker: marker::PhantomData
+ }
+ ),
+ )
+ }
+}
+
+impl bind::Bind for PairBind {
+ type Closer = ();
+ type Error = BindError;
+ type Endpoint = UnitEndpoint;
+ type Reader = PairReader<Self::Endpoint>;
+ type Writer = PairWriter<Self::Endpoint>;
+
+ fn bind(_port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer, u16), Self::Error> {
+ Err(BindError::Disconnected)
+ }
+}
+
pub fn keypair(initiator: bool) -> KeyPair {
let k1 = Key {
key: [0x53u8; 32],
diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs
index 261203f..74796aa 100644
--- a/src/types/endpoint.rs
+++ b/src/types/endpoint.rs
@@ -1,6 +1,6 @@
use std::net::SocketAddr;
-pub trait Endpoint: Send {
+pub trait Endpoint: Send + 'static {
fn from_address(addr: SocketAddr) -> Self;
fn into_address(&self) -> SocketAddr;
}
diff --git a/src/types/mod.rs b/src/types/mod.rs
index 07ca44d..e0725f3 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -1,12 +1,10 @@
mod endpoint;
mod keys;
-mod tun;
-mod udp;
+pub mod tun;
+pub mod bind;
#[cfg(test)]
pub mod dummy;
pub use endpoint::Endpoint;
-pub use keys::{Key, KeyPair};
-pub use tun::Tun;
-pub use udp::Bind;
+pub use keys::{Key, KeyPair}; \ No newline at end of file
diff --git a/src/types/tun.rs b/src/types/tun.rs
index fc8044a..2ba16ff 100644
--- a/src/types/tun.rs
+++ b/src/types/tun.rs
@@ -1,18 +1,22 @@
-use std::error;
+use std::error::Error;
-pub trait Tun: Send + Sync + Clone + 'static {
- type Error: error::Error;
+pub trait Writer: Send + Sync + 'static {
+ type Error: Error;
- /// Returns the MTU of the device
+ /// Receive a cryptkey routed IP packet
///
- /// This function needs to be efficient (called for every read).
- /// The goto implementation strategy is to .load an atomic variable,
- /// then use e.g. netlink to update the variable in a separate thread.
+ /// # Arguments
+ ///
+ /// - src: Buffer containing the IP packet to be written
///
/// # Returns
///
- /// The MTU of the interface in bytes
- fn mtu(&self) -> usize;
+ /// Unit type or an error
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
+}
+
+pub trait Reader: Send + 'static {
+ type Error: Error;
/// Reads an IP packet into dst[offset:] from the tunnel device
///
@@ -29,15 +33,24 @@ pub trait Tun: Send + Sync + Clone + 'static {
///
/// The size of the IP packet (ignoring the header) or an std::error::Error instance:
fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
+}
- /// Writes an IP packet to the tunnel device
- ///
- /// # Arguments
+pub trait MTU: Send + Sync + Clone + 'static {
+ /// Returns the MTU of the device
///
- /// - src: Buffer containing the IP packet to be written
+ /// This function needs to be efficient (called for every read).
+ /// The goto implementation strategy is to .load an atomic variable,
+ /// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///
- /// Unit type or an error
- fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
+ /// The MTU of the interface in bytes
+ fn mtu(&self) -> usize;
+}
+
+pub trait Tun: Send + Sync + 'static {
+ type Writer: Writer;
+ type Reader: Reader;
+ type MTU: MTU;
+ type Error: Error;
}
diff --git a/src/types/udp.rs b/src/types/udp.rs
deleted file mode 100644
index 943bf94..0000000
--- a/src/types/udp.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-use super::Endpoint;
-use std::error;
-
-/* Often times an a file descriptor in an atomic might suffice.
- */
-pub trait Bind: Send + Sync + Clone + 'static {
- type Error: error::Error + Send;
- type Endpoint: Endpoint;
-
- fn new() -> Self;
-
- /// Updates the port of the Bind
- ///
- /// # Arguments
- ///
- /// - port, The new port to bind to. 0 means any available port.
- ///
- /// # Returns
- ///
- /// The unit type or an error, if binding fails
- fn set_port(&self, port: u16) -> Result<(), Self::Error>;
-
- /// Returns the current port of the bind
- fn get_port(&self) -> Option<u16>;
-
- fn recv(&self, buf: &mut [u8]) -> Result<(usize, Self::Endpoint), Self::Error>;
-
- fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error>;
-}
diff --git a/src/wireguard.rs b/src/wireguard.rs
index ea600d0..ba81f47 100644
--- a/src/wireguard.rs
+++ b/src/wireguard.rs
@@ -2,11 +2,13 @@ use crate::constants::*;
use crate::handshake;
use crate::router;
use crate::timers::{Events, Timers};
-use crate::types::{Bind, Endpoint, Tun};
+
+use crate::types::Endpoint;
+use crate::types::tun::{Tun, Reader, MTU};
+use crate::types::bind::{Bind, Writer};
use hjul::Runner;
-use std::cmp;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
@@ -27,12 +29,20 @@ const SIZE_HANDSHAKE_QUEUE: usize = 128;
const THRESHOLD_UNDER_LOAD: usize = SIZE_HANDSHAKE_QUEUE / 4;
const DURATION_UNDER_LOAD: Duration = Duration::from_millis(10_000);
-#[derive(Clone)]
pub struct Peer<T: Tun, B: Bind> {
- pub router: Arc<router::Peer<Events<T, B>, T, B>>,
+ pub router: Arc<router::Peer<B::Endpoint, Events<T, B>, T::Writer, B::Writer>>,
pub state: Arc<PeerInner<B>>,
}
+impl <T : Tun, B : Bind> Clone for Peer<T, B > {
+ fn clone(&self) -> Peer<T, B> {
+ Peer{
+ router: self.router.clone(),
+ state: self.state.clone()
+ }
+ }
+}
+
pub struct PeerInner<B: Bind> {
pub keepalive: AtomicUsize, // keepalive interval
pub rx_bytes: AtomicU64,
@@ -66,20 +76,22 @@ pub enum HandshakeJob<E> {
}
struct WireguardInner<T: Tun, B: Bind> {
+ // provides access to the MTU value of the tun device
+ // (otherwise owned solely by the router and a dedicated read IO thread)
+ mtu: T::MTU,
+ send: RwLock<Option<B::Writer>>,
+
// identify and configuration map
peers: RwLock<HashMap<[u8; 32], Peer<T, B>>>,
// cryptkey router
- router: router::Device<Events<T, B>, T, B>,
+ router: router::Device<B::Endpoint, Events<T, B>, T::Writer, B::Writer>,
// handshake related state
handshake: RwLock<Handshake>,
under_load: AtomicBool,
pending: AtomicUsize, // num of pending handshake packets in queue
queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>,
-
- // IO
- bind: B,
}
pub struct Wireguard<T: Tun, B: Bind> {
@@ -87,6 +99,17 @@ pub struct Wireguard<T: Tun, B: Bind> {
state: Arc<WireguardInner<T, B>>,
}
+/* Returns the padded length of a message:
+ *
+ * # Arguments
+ *
+ * - `size` : Size of unpadded message
+ * - `mtu` : Maximum transmission unit of the device
+ *
+ * # Returns
+ *
+ * The padded length (always less than or equal to the MTU)
+ */
#[inline(always)]
const fn padding(size: usize, mtu: usize) -> usize {
#[inline(always)]
@@ -114,6 +137,15 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
}
}
+ pub fn get_sk(&self) -> Option<StaticSecret> {
+ let mut handshake = self.state.handshake.read();
+ if handshake.active {
+ Some(handshake.device.get_sk())
+ } else {
+ None
+ }
+ }
+
pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> {
let state = Arc::new(PeerInner {
pk,
@@ -137,20 +169,92 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
peer
}
- pub fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ pub fn new_bind(
+ reader: B::Reader,
+ writer: B::Writer,
+ closer: B::Closer
+ ) {
+
+ // drop existing closer
+
+
+ // swap IO thread for new reader
+
+
+ // start UDP read IO thread
+
+ /*
+ {
+ let wg = wg.clone();
+ let mtu = mtu.clone();
+ thread::spawn(move || {
+ let mut last_under_load =
+ Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
+
+ loop {
+ // create vector big enough for any message given current MTU
+ let size = mtu.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
+ let mut msg: Vec<u8> = Vec::with_capacity(size);
+ msg.resize(size, 0);
+
+ // read UDP packet into vector
+ let (size, src) = reader.read(&mut msg).unwrap(); // TODO handle error
+ msg.truncate(size);
+
+ // message type de-multiplexer
+ if msg.len() < std::mem::size_of::<u32>() {
+ continue;
+ }
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // update under_load flag
+ if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
+ last_under_load = Instant::now();
+ wg.under_load.store(true, Ordering::SeqCst);
+ } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
+ wg.under_load.store(false, Ordering::SeqCst);
+ }
+
+ wg.queue
+ .lock()
+ .send(HandshakeJob::Message(msg, src))
+ .unwrap();
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+ let _ = wg.router.recv(src, msg);
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+ */
+
+
+ }
+
+ pub fn new(
+ reader: T::Reader,
+ writer: T::Writer,
+ mtu: T::MTU,
+ ) -> Wireguard<T, B> {
// create device state
let mut rng = OsRng::new().unwrap();
let (tx, rx): (Sender<HandshakeJob<B::Endpoint>>, _) = bounded(SIZE_HANDSHAKE_QUEUE);
let wg = Arc::new(WireguardInner {
+ mtu: mtu.clone(),
peers: RwLock::new(HashMap::new()),
- router: router::Device::new(num_cpus::get(), tun.clone(), bind.clone()),
+ send: RwLock::new(None),
+ router: router::Device::new(num_cpus::get(), writer), // router owns the writing half
pending: AtomicUsize::new(0),
handshake: RwLock::new(Handshake {
device: handshake::Device::new(StaticSecret::new(&mut rng)),
active: false,
}),
under_load: AtomicBool::new(false),
- bind: bind.clone(),
queue: Mutex::new(tx),
});
@@ -158,7 +262,6 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
for _ in 0..num_cpus::get() {
let wg = wg.clone();
let rx = rx.clone();
- let bind = bind.clone();
thread::spawn(move || {
// prepare OsRng instance for this thread
let mut rng = OsRng::new().unwrap();
@@ -189,19 +292,22 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
Ok((pk, msg, keypair)) => {
// send response
if let Some(msg) = msg {
- let _ = bind.send(&msg[..], &src).map_err(|e| {
- debug!(
- "handshake worker, failed to send response, error = {:?}",
- e
- )
- });
+ let send : &Option<B::Writer> = &*wg.send.read();
+ if let Some(writer) = send.as_ref() {
+ let _ = writer.write(&msg[..], &src).map_err(|e| {
+ debug!(
+ "handshake worker, failed to send response, error = {:?}",
+ e
+ )
+ });
+ }
}
// update timers
if let Some(pk) = pk {
if let Some(peer) = wg.peers.read().get(pk.as_bytes()) {
- // update endpoint (DISCUSS: right semantics?)
- peer.router.set_endpoint(src_validate);
+ // update endpoint
+ peer.router.set_endpoint(src);
// add keypair to peer and free any unused ids
if let Some(keypair) = keypair {
@@ -227,68 +333,18 @@ impl<T: Tun, B: Bind> Wireguard<T, B> {
});
}
- // start UDP read IO thread
- {
- let wg = wg.clone();
- let tun = tun.clone();
- let bind = bind.clone();
- thread::spawn(move || {
- let mut last_under_load =
- Instant::now() - DURATION_UNDER_LOAD - Duration::from_millis(1000);
-
- loop {
- // create vector big enough for any message given current MTU
- let size = tun.mtu() + handshake::MAX_HANDSHAKE_MSG_SIZE;
- let mut msg: Vec<u8> = Vec::with_capacity(size);
- msg.resize(size, 0);
-
- // read UDP packet into vector
- let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
- msg.truncate(size);
-
- // message type de-multiplexer
- if msg.len() < std::mem::size_of::<u32>() {
- continue;
- }
- match LittleEndian::read_u32(&msg[..]) {
- handshake::TYPE_COOKIE_REPLY
- | handshake::TYPE_INITIATION
- | handshake::TYPE_RESPONSE => {
- // update under_load flag
- if wg.pending.fetch_add(1, Ordering::SeqCst) > THRESHOLD_UNDER_LOAD {
- last_under_load = Instant::now();
- wg.under_load.store(true, Ordering::SeqCst);
- } else if last_under_load.elapsed() > DURATION_UNDER_LOAD {
- wg.under_load.store(false, Ordering::SeqCst);
- }
-
- wg.queue
- .lock()
- .send(HandshakeJob::Message(msg, src))
- .unwrap();
- }
- router::TYPE_TRANSPORT => {
- // transport message
- let _ = wg.router.recv(src, msg);
- }
- _ => (),
- }
- }
- });
- }
-
// start TUN read IO thread
{
let wg = wg.clone();
thread::spawn(move || loop {
// create vector big enough for any transport message (based on MTU)
- let mtu = tun.mtu();
+ let mtu = mtu.mtu();
let size = mtu + router::SIZE_MESSAGE_PREFIX;
let mut msg: Vec<u8> = Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
msg.resize(size, 0);
// read a new IP packet
- let payload = tun.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
+ let payload = reader.read(&mut msg[..], router::SIZE_MESSAGE_PREFIX).unwrap();
debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu);
// truncate padding