From ee3599d5507ceee23ef3382dbda9de8e73c54a00 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Wed, 23 Oct 2019 12:08:35 +0200 Subject: Moved IO traits into platform module --- src/configuration/config.rs | 25 +-- src/configuration/mod.rs | 6 +- src/main.rs | 19 ++- src/platform/bind.rs | 43 +++++ src/platform/dummy.rs | 22 --- src/platform/dummy/bind.rs | 211 +++++++++++++++++++++++++ src/platform/dummy/endpoint.rs | 1 + src/platform/dummy/mod.rs | 13 ++ src/platform/dummy/tun.rs | 172 ++++++++++++++++++++ src/platform/endpoint.rs | 7 + src/platform/linux/tun.rs | 6 +- src/platform/linux/udp.rs | 3 +- src/platform/mod.rs | 32 +--- src/platform/tun.rs | 61 ++++++++ src/tests.rs | 1 + src/wireguard/mod.rs | 10 +- src/wireguard/router/device.rs | 2 +- src/wireguard/router/peer.rs | 2 +- src/wireguard/router/tests.rs | 14 +- src/wireguard/router/workers.rs | 2 +- src/wireguard/tests.rs | 3 +- src/wireguard/timers.rs | 2 +- src/wireguard/types.rs | 63 ++++++++ src/wireguard/types/bind.rs | 23 --- src/wireguard/types/dummy.rs | 339 ---------------------------------------- src/wireguard/types/endpoint.rs | 7 - src/wireguard/types/keys.rs | 36 ----- src/wireguard/types/mod.rs | 11 -- src/wireguard/types/tun.rs | 56 ------- src/wireguard/wireguard.rs | 8 +- 30 files changed, 641 insertions(+), 559 deletions(-) create mode 100644 src/platform/bind.rs delete mode 100644 src/platform/dummy.rs create mode 100644 src/platform/dummy/bind.rs create mode 100644 src/platform/dummy/endpoint.rs create mode 100644 src/platform/dummy/mod.rs create mode 100644 src/platform/dummy/tun.rs create mode 100644 src/platform/endpoint.rs create mode 100644 src/platform/tun.rs create mode 100644 src/tests.rs create mode 100644 src/wireguard/types.rs delete mode 100644 src/wireguard/types/bind.rs delete mode 100644 src/wireguard/types/dummy.rs delete mode 100644 src/wireguard/types/endpoint.rs delete mode 100644 src/wireguard/types/keys.rs delete mode 100644 src/wireguard/types/mod.rs delete mode 100644 src/wireguard/types/tun.rs diff --git a/src/configuration/config.rs b/src/configuration/config.rs index 24b1349..f42b53b 100644 --- a/src/configuration/config.rs +++ b/src/configuration/config.rs @@ -2,10 +2,8 @@ use spin::Mutex; use std::net::{IpAddr, SocketAddr}; use x25519_dalek::{PublicKey, StaticSecret}; -use super::BindOwner; -use super::PlatformBind; -use super::Tun; -use super::Wireguard; +use super::*; +use bind::Owner; /// The goal of the configuration interface is, among others, /// to hide the IO implementations (over which the WG device is generic), @@ -21,17 +19,26 @@ pub struct PeerState { allowed_ips: Vec<(IpAddr, u32)>, } -struct UDPState { +struct UDPState { fwmark: Option, owner: O, port: u16, } -pub struct WireguardConfig { +pub struct WireguardConfig { wireguard: Wireguard, network: Mutex>>, } +impl WireguardConfig { + fn new(wg: Wireguard) -> WireguardConfig { + WireguardConfig { + wireguard: wg, + network: Mutex::new(None), + } + } +} + pub enum ConfigError { NoSuchPeer, NotListening, @@ -41,8 +48,8 @@ impl ConfigError { fn errno(&self) -> i32 { // TODO: obtain the correct error values match self { - NoSuchPeer => 1, - NotListening => 2, + ConfigError::NoSuchPeer => 1, + ConfigError::NotListening => 2, } } } @@ -180,7 +187,7 @@ pub trait Configuration { fn get_peers(&self) -> Vec; } -impl Configuration for WireguardConfig { +impl Configuration for WireguardConfig { fn set_private_key(&self, sk: Option) { self.wireguard.set_key(sk) } diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs index 56a83e2..520b397 100644 --- a/src/configuration/mod.rs +++ b/src/configuration/mod.rs @@ -1,5 +1,7 @@ mod config; -use super::platform::{BindOwner, PlatformBind}; -use super::wireguard::tun::Tun; +use super::platform::{bind, tun}; use super::wireguard::Wireguard; + +pub use config::Configuration; +pub use config::WireguardConfig; diff --git a/src/main.rs b/src/main.rs index 4dac3cd..5aaeb25 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,8 +10,23 @@ mod configuration; mod platform; mod wireguard; -use platform::PlatformTun; +mod tests; + +use platform::tun; + +use configuration::WireguardConfig; fn main() { - let (readers, writer, mtu) = platform::TunInstance::create("test").unwrap(); + /* + let (mut readers, writer, mtu) = platform::TunInstance::create("test").unwrap(); + let wg = wireguard::Wireguard::new(readers, writer, mtu); + */ +} + +/* +fn test_wg_configuration() { + let (mut readers, writer, mtu) = platform::dummy:: + + let wg = wireguard::Wireguard::new(readers, writer, mtu); } +*/ diff --git a/src/platform/bind.rs b/src/platform/bind.rs new file mode 100644 index 0000000..f22a5d7 --- /dev/null +++ b/src/platform/bind.rs @@ -0,0 +1,43 @@ +use super::Endpoint; +use std::error::Error; + +pub trait Reader: Send + Sync { + type Error: Error; + + fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; +} + +pub trait Writer: Send + Sync + Clone + 'static { + type Error: Error; + + fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; +} + +pub trait Bind: Send + Sync + 'static { + type Error: Error; + type Endpoint: Endpoint; + + /* Until Rust gets type equality constraints these have to be generic */ + type Writer: Writer; + type Reader: Reader; +} + +/// On platforms where fwmark can be set and the +/// implementation can bind to a new port during later configuration (UAPI support), +/// this type provides the ability to set the fwmark and close the socket (by dropping the instance) +pub trait Owner: Send { + type Error: Error; + + fn set_fwmark(&self, value: Option) -> Option; +} + +/// On some platforms the application can itself bind to a socket. +/// This enables configuration using the UAPI interface. +pub trait Platform: Bind { + type Owner: Owner; + + /// Bind to a new port, returning the reader/writer and + /// an associated instance of the owner type, which closes the UDP socket upon "drop" + /// and enables configuration of the fwmark value. + fn bind(port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error>; +} diff --git a/src/platform/dummy.rs b/src/platform/dummy.rs deleted file mode 100644 index 208febe..0000000 --- a/src/platform/dummy.rs +++ /dev/null @@ -1,22 +0,0 @@ -#[cfg(test)] -use super::super::wireguard::dummy; -use super::BindOwner; -use super::PlatformBind; - -pub struct VoidOwner {} - -impl BindOwner for VoidOwner { - type Error = dummy::BindError; - - fn set_fwmark(&self, value: Option) -> Option { - None - } -} - -impl PlatformBind for dummy::PairBind { - type Owner = VoidOwner; - - fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { - Err(dummy::BindError::Disconnected) - } -} diff --git a/src/platform/dummy/bind.rs b/src/platform/dummy/bind.rs new file mode 100644 index 0000000..14143ae --- /dev/null +++ b/src/platform/dummy/bind.rs @@ -0,0 +1,211 @@ +use std::error::Error; +use std::fmt; +use std::marker; +use std::net::SocketAddr; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; +use std::sync::Mutex; + +use super::super::bind::*; +use super::super::Endpoint; + +pub struct VoidOwner {} + +#[derive(Debug)] +pub enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +/* TUN implementation */ + +#[derive(Debug)] +pub enum TunError { + Disconnected, +} + +impl Error for TunError { + fn description(&self) -> &str { + "Generic Tun Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for TunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} + +/* Endpoint implementation */ + +#[derive(Clone, Copy)] +pub struct UnitEndpoint {} + +impl Endpoint for UnitEndpoint { + fn from_address(_: SocketAddr) -> UnitEndpoint { + UnitEndpoint {} + } + + fn into_address(&self) -> SocketAddr { + "127.0.0.1:8080".parse().unwrap() + } + + fn clear_src(&mut self) {} +} + +impl UnitEndpoint { + pub fn new() -> UnitEndpoint { + UnitEndpoint {} + } +} + +#[derive(Clone, Copy)] +pub struct VoidBind {} + +impl Reader for VoidBind { + type Error = BindError; + + fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + Ok((0, UnitEndpoint {})) + } +} + +impl Writer for VoidBind { + type Error = BindError; + + fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + Ok(()) + } +} + +impl Bind for VoidBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + + type Reader = VoidBind; + type Writer = VoidBind; +} + +impl VoidBind { + pub fn new() -> VoidBind { + VoidBind {} + } +} + +/* Pair Bind */ + +#[derive(Clone)] +pub struct PairReader { + recv: Arc>>>, + _marker: marker::PhantomData, +} + +impl Reader for PairReader { + type Error = BindError; + fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { + let vec = self + .recv + .lock() + .unwrap() + .recv() + .map_err(|_| BindError::Disconnected)?; + let len = vec.len(); + buf[..len].copy_from_slice(&vec[..]); + Ok((vec.len(), UnitEndpoint {})) + } +} + +impl Writer for PairWriter { + type Error = BindError; + fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { + let owned = buf.to_owned(); + match self.send.lock().unwrap().send(owned) { + Err(_) => Err(BindError::Disconnected), + Ok(_) => Ok(()), + } + } +} + +#[derive(Clone)] +pub struct PairWriter { + send: Arc>>>, + _marker: marker::PhantomData, +} + +#[derive(Clone)] +pub struct PairBind {} + +impl PairBind { + pub fn pair() -> ( + (PairReader, PairWriter), + (PairReader, PairWriter), + ) { + let (tx1, rx1) = sync_channel(128); + let (tx2, rx2) = sync_channel(128); + ( + ( + PairReader { + recv: Arc::new(Mutex::new(rx1)), + _marker: marker::PhantomData, + }, + PairWriter { + send: Arc::new(Mutex::new(tx2)), + _marker: marker::PhantomData, + }, + ), + ( + PairReader { + recv: Arc::new(Mutex::new(rx2)), + _marker: marker::PhantomData, + }, + PairWriter { + send: Arc::new(Mutex::new(tx1)), + _marker: marker::PhantomData, + }, + ), + ) + } +} + +impl Bind for PairBind { + type Error = BindError; + type Endpoint = UnitEndpoint; + type Reader = PairReader; + type Writer = PairWriter; +} + +impl Owner for VoidOwner { + type Error = BindError; + + fn set_fwmark(&self, _value: Option) -> Option { + None + } +} + +impl Platform for PairBind { + type Owner = VoidOwner; + fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { + Err(BindError::Disconnected) + } +} diff --git a/src/platform/dummy/endpoint.rs b/src/platform/dummy/endpoint.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/platform/dummy/endpoint.rs @@ -0,0 +1 @@ + diff --git a/src/platform/dummy/mod.rs b/src/platform/dummy/mod.rs new file mode 100644 index 0000000..884bd7e --- /dev/null +++ b/src/platform/dummy/mod.rs @@ -0,0 +1,13 @@ +mod bind; +mod endpoint; +mod tun; + +/* A pure dummy platform available during "test-time" + * + * The use of the dummy platform is to enable unit testing of full WireGuard, + * the configuration interface and the UAPI parser. + */ + +pub use bind::*; +pub use endpoint::*; +pub use tun::*; diff --git a/src/platform/dummy/tun.rs b/src/platform/dummy/tun.rs new file mode 100644 index 0000000..9fe9480 --- /dev/null +++ b/src/platform/dummy/tun.rs @@ -0,0 +1,172 @@ +use std::error::Error; +use std::fmt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; +use std::sync::Arc; +use std::sync::Mutex; + +use super::super::tun::*; + +/* This submodule provides pure/dummy implementations of the IO interfaces + * for use in unit tests thoughout the project. + */ + +/* Error implementation */ + +#[derive(Debug)] +pub enum BindError { + Disconnected, +} + +impl Error for BindError { + fn description(&self) -> &str { + "Generic Bind Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for BindError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + BindError::Disconnected => write!(f, "PairBind disconnected"), + } + } +} + +#[derive(Debug)] +pub enum TunError { + Disconnected, +} + +impl Error for TunError { + fn description(&self) -> &str { + "Generic Tun Error" + } + + fn source(&self) -> Option<&(dyn Error + 'static)> { + None + } +} + +impl fmt::Display for TunError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Not Possible") + } +} + +pub struct TunTest {} + +pub struct TunFakeIO { + store: bool, + tx: SyncSender>, + rx: Receiver>, +} + +pub struct TunReader { + rx: Receiver>, +} + +pub struct TunWriter { + store: bool, + tx: Mutex>>, +} + +#[derive(Clone)] +pub struct TunMTU { + mtu: Arc, +} + +impl Reader for TunReader { + type Error = TunError; + + fn read(&self, buf: &mut [u8], offset: usize) -> Result { + match self.rx.recv() { + Ok(m) => { + buf[offset..].copy_from_slice(&m[..]); + Ok(m.len()) + } + Err(_) => Err(TunError::Disconnected), + } + } +} + +impl Writer for TunWriter { + type Error = TunError; + + fn write(&self, src: &[u8]) -> Result<(), Self::Error> { + if self.store { + let m = src.to_owned(); + match self.tx.lock().unwrap().send(m) { + Ok(_) => Ok(()), + Err(_) => Err(TunError::Disconnected), + } + } else { + Ok(()) + } + } +} + +impl MTU for TunMTU { + fn mtu(&self) -> usize { + self.mtu.load(Ordering::Acquire) + } +} + +impl Tun for TunTest { + type Writer = TunWriter; + type Reader = TunReader; + type MTU = TunMTU; + type Error = TunError; +} + +impl TunFakeIO { + pub fn write(&self, msg: Vec) { + if self.store { + self.tx.send(msg).unwrap(); + } + } + + pub fn read(&self) -> Vec { + self.rx.recv().unwrap() + } +} + +impl TunTest { + pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + let (tx1, rx1) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + let (tx2, rx2) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + + let fake = TunFakeIO { + tx: tx1, + rx: rx2, + store, + }; + let reader = TunReader { rx: rx1 }; + let writer = TunWriter { + tx: Mutex::new(tx2), + store, + }; + let mtu = TunMTU { + mtu: Arc::new(AtomicUsize::new(mtu)), + }; + + (fake, reader, writer, mtu) + } +} + +impl Platform for TunTest { + fn create(_name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { + Err(TunError::Disconnected) + } +} diff --git a/src/platform/endpoint.rs b/src/platform/endpoint.rs new file mode 100644 index 0000000..4702aab --- /dev/null +++ b/src/platform/endpoint.rs @@ -0,0 +1,7 @@ +use std::net::SocketAddr; + +pub trait Endpoint: Send + 'static { + fn from_address(addr: SocketAddr) -> Self; + fn into_address(&self) -> SocketAddr; + fn clear_src(&mut self); +} diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 5b7b105..090569a 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -1,6 +1,4 @@ -use super::super::super::wireguard::tun::*; -use super::super::PlatformTun; -use super::super::Tun; +use super::super::tun::*; use libc::*; @@ -127,7 +125,7 @@ impl Tun for LinuxTun { type MTU = LinuxTunMTU; } -impl PlatformTun for LinuxTun { +impl Platform for LinuxTun { fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error> { // construct request struct let mut req = Ifreq { diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index 0a1a186..52e4c45 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -1,6 +1,5 @@ -use super::super::Bind; +use super::super::bind::*; use super::super::Endpoint; -use super::super::PlatformBind; use std::net::SocketAddr; diff --git a/src/platform/mod.rs b/src/platform/mod.rs index a0bbc13..ecd559a 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -1,33 +1,15 @@ -use std::error::Error; +mod endpoint; -use super::wireguard::bind::Bind; -use super::wireguard::tun::Tun; -use super::wireguard::Endpoint; +pub mod bind; +pub mod tun; -#[cfg(test)] -mod dummy; +pub use endpoint::Endpoint; #[cfg(target_os = "linux")] mod linux; +#[cfg(test)] +pub mod dummy; + #[cfg(target_os = "linux")] pub use linux::LinuxTun as TunInstance; - -pub trait BindOwner: Send { - type Error: Error; - - fn set_fwmark(&self, value: Option) -> Option; -} - -pub trait PlatformBind: Bind { - type Owner: BindOwner; - - /// Bind to a new port, returning the reader/writer and - /// an associated instance of the owner type, which closes the UDP socket upon "drop" - /// and enables configuration of the fwmark value. - fn bind(port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error>; -} - -pub trait PlatformTun: Tun { - fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error>; -} diff --git a/src/platform/tun.rs b/src/platform/tun.rs new file mode 100644 index 0000000..f49d4af --- /dev/null +++ b/src/platform/tun.rs @@ -0,0 +1,61 @@ +use std::error::Error; + +pub trait Writer: Send + Sync + 'static { + type Error: Error; + + /// Receive a cryptkey routed IP packet + /// + /// # Arguments + /// + /// - src: Buffer containing the IP packet to be written + /// + /// # Returns + /// + /// Unit type or an error + fn write(&self, src: &[u8]) -> Result<(), Self::Error>; +} + +pub trait Reader: Send + 'static { + type Error: Error; + + /// Reads an IP packet into dst[offset:] from the tunnel device + /// + /// The reason for providing space for a prefix + /// is to efficiently accommodate platforms on which the packet is prefaced by a header. + /// This space is later used to construct the transport message inplace. + /// + /// # Arguments + /// + /// - buf: Destination buffer (enough space for MTU bytes + header) + /// - offset: Offset for the beginning of the IP packet + /// + /// # Returns + /// + /// The size of the IP packet (ignoring the header) or an std::error::Error instance: + fn read(&self, buf: &mut [u8], offset: usize) -> Result; +} + +pub trait MTU: Send + Sync + Clone + 'static { + /// Returns the MTU of the device + /// + /// This function needs to be efficient (called for every read). + /// The goto implementation strategy is to .load an atomic variable, + /// then use e.g. netlink to update the variable in a separate thread. + /// + /// # Returns + /// + /// The MTU of the interface in bytes + fn mtu(&self) -> usize; +} + +pub trait Tun: Send + Sync + 'static { + type Writer: Writer; + type Reader: Reader; + type MTU: MTU; + type Error: Error; +} + +/// On some platforms the application can create the TUN device itself. +pub trait Platform: Tun { + fn create(name: &str) -> Result<(Vec, Self::Writer, Self::MTU), Self::Error>; +} diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1 @@ + diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 563a22f..c3e9c58 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -16,9 +16,11 @@ mod tests; /// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint" pub use wireguard::{Peer, Wireguard}; -pub use types::bind; -pub use types::tun; -pub use types::Endpoint; +#[cfg(test)] +pub use types::dummy_keypair; #[cfg(test)] -pub use types::dummy; +use super::platform::dummy; + +use super::platform::{bind, tun, Endpoint}; +use types::{Key, KeyPair}; diff --git a/src/wireguard/router/device.rs b/src/wireguard/router/device.rs index 455020c..b122bf4 100644 --- a/src/wireguard/router/device.rs +++ b/src/wireguard/router/device.rs @@ -21,7 +21,7 @@ use super::types::{Callbacks, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; -use super::super::types::{bind, tun, Endpoint, KeyPair}; +use super::super::{bind, tun, Endpoint, KeyPair}; pub struct DeviceInner> { // inbound writer (TUN) diff --git a/src/wireguard/router/peer.rs b/src/wireguard/router/peer.rs index 4f47604..0b193a4 100644 --- a/src/wireguard/router/peer.rs +++ b/src/wireguard/router/peer.rs @@ -14,7 +14,7 @@ use treebitmap::IpLookupTable; use zerocopy::LayoutVerified; use super::super::constants::*; -use super::super::types::{bind, tun, Endpoint, KeyPair}; +use super::super::{bind, tun, Endpoint, KeyPair}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; diff --git a/src/wireguard/router/tests.rs b/src/wireguard/router/tests.rs index 93c0773..d44a612 100644 --- a/src/wireguard/router/tests.rs +++ b/src/wireguard/router/tests.rs @@ -9,9 +9,9 @@ use num_cpus; use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; -use super::super::types::bind::*; -use super::super::types::*; - +use super::super::bind::*; +use super::super::dummy; +use super::super::dummy_keypair; use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; extern crate test; @@ -151,7 +151,7 @@ mod tests { // add new peer let opaque = Arc::new(AtomicUsize::new(0)); let peer = router.new_peer(opaque.clone()); - peer.add_keypair(dummy::keypair(true)); + peer.add_keypair(dummy_keypair(true)); // add subnet to peer let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); @@ -211,7 +211,7 @@ mod tests { let peer = router.new_peer(opaque.clone()); let mask: IpAddr = mask.parse().unwrap(); if set_key { - peer.add_keypair(dummy::keypair(true)); + peer.add_keypair(dummy_keypair(true)); } // map subnet to peer @@ -340,7 +340,7 @@ mod tests { let peer1 = router1.new_peer(opaq1.clone()); let mask: IpAddr = mask.parse().unwrap(); peer1.add_subnet(mask, *len); - peer1.add_keypair(dummy::keypair(false)); + peer1.add_keypair(dummy_keypair(false)); let (mask, len, _ip, _okay) = p2; let peer2 = router2.new_peer(opaq2.clone()); @@ -370,7 +370,7 @@ mod tests { // this should cause a key-confirmation packet (keepalive or staged packet) // this also causes peer1 to learn the "endpoint" for peer2 assert!(peer1.get_endpoint().is_none()); - peer2.add_keypair(dummy::keypair(true)); + peer2.add_keypair(dummy_keypair(true)); wait(); assert!(opaq2.send().is_some()); diff --git a/src/wireguard/router/workers.rs b/src/wireguard/router/workers.rs index 61a7620..8ebb246 100644 --- a/src/wireguard/router/workers.rs +++ b/src/wireguard/router/workers.rs @@ -17,7 +17,7 @@ use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::PeerInner; use super::types::Callbacks; -use super::super::types::{bind, tun, Endpoint}; +use super::super::{bind, tun, Endpoint}; use super::ip::*; pub const SIZE_TAG: usize = 16; diff --git a/src/wireguard/tests.rs b/src/wireguard/tests.rs index 0148d5d..4ecd43b 100644 --- a/src/wireguard/tests.rs +++ b/src/wireguard/tests.rs @@ -1,6 +1,5 @@ -use super::types::tun::Tun; -use super::types::{bind, dummy, tun}; use super::wireguard::Wireguard; +use super::{bind, dummy, tun}; use std::thread; use std::time::Duration; diff --git a/src/wireguard/timers.rs b/src/wireguard/timers.rs index 1d9b8a0..40717f8 100644 --- a/src/wireguard/timers.rs +++ b/src/wireguard/timers.rs @@ -7,9 +7,9 @@ use log::info; use hjul::{Runner, Timer}; +use super::{bind, tun}; use super::constants::*; use super::router::{Callbacks, message_data_len}; -use super::types::{bind, tun}; use super::wireguard::{Peer, PeerInner}; pub struct Timers { diff --git a/src/wireguard/types.rs b/src/wireguard/types.rs new file mode 100644 index 0000000..51898a0 --- /dev/null +++ b/src/wireguard/types.rs @@ -0,0 +1,63 @@ +use clear_on_drop::clear::Clear; +use std::time::Instant; + +#[cfg(test)] +pub fn dummy_keypair(initiator: bool) -> KeyPair { + let k1 = Key { + key: [0x53u8; 32], + id: 0x646e6573, + }; + let k2 = Key { + key: [0x52u8; 32], + id: 0x76636572, + }; + if initiator { + KeyPair { + birth: Instant::now(), + initiator: true, + send: k1, + recv: k2, + } + } else { + KeyPair { + birth: Instant::now(), + initiator: false, + send: k2, + recv: k1, + } + } +} + +#[derive(Debug, Clone)] +pub struct Key { + pub key: [u8; 32], + pub id: u32, +} + +// zero key on drop +impl Drop for Key { + fn drop(&mut self) { + self.key.clear() + } +} + +#[cfg(test)] +impl PartialEq for Key { + fn eq(&self, other: &Self) -> bool { + self.id == other.id && self.key[..] == other.key[..] + } +} + +#[derive(Debug, Clone)] +pub struct KeyPair { + pub birth: Instant, // when was the key-pair created + pub initiator: bool, // has the key-pair been confirmed? + pub send: Key, // key for outbound messages + pub recv: Key, // key for inbound messages +} + +impl KeyPair { + pub fn local_id(&self) -> u32 { + self.recv.id + } +} diff --git a/src/wireguard/types/bind.rs b/src/wireguard/types/bind.rs deleted file mode 100644 index 3d3f187..0000000 --- a/src/wireguard/types/bind.rs +++ /dev/null @@ -1,23 +0,0 @@ -use super::Endpoint; -use std::error::Error; - -pub trait Reader: Send + Sync { - type Error: Error; - - fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; -} - -pub trait Writer: Send + Sync + Clone + 'static { - type Error: Error; - - fn write(&self, buf: &[u8], dst: &E) -> Result<(), Self::Error>; -} - -pub trait Bind: Send + Sync + 'static { - type Error: Error; - type Endpoint: Endpoint; - - /* Until Rust gets type equality constraints these have to be generic */ - type Writer: Writer; - type Reader: Reader; -} diff --git a/src/wireguard/types/dummy.rs b/src/wireguard/types/dummy.rs deleted file mode 100644 index 384f123..0000000 --- a/src/wireguard/types/dummy.rs +++ /dev/null @@ -1,339 +0,0 @@ -use std::error::Error; -use std::fmt; -use std::marker; -use std::net::SocketAddr; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; -use std::sync::Arc; -use std::sync::Mutex; -use std::time::Instant; - -use super::*; - -/* This submodule provides pure/dummy implementations of the IO interfaces - * for use in unit tests thoughout the project. - */ - -/* Error implementation */ - -#[derive(Debug)] -pub enum BindError { - Disconnected, -} - -impl Error for BindError { - fn description(&self) -> &str { - "Generic Bind Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for BindError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - BindError::Disconnected => write!(f, "PairBind disconnected"), - } - } -} - -/* TUN implementation */ - -#[derive(Debug)] -pub enum TunError { - Disconnected, -} - -impl Error for TunError { - fn description(&self) -> &str { - "Generic Tun Error" - } - - fn source(&self) -> Option<&(dyn Error + 'static)> { - None - } -} - -impl fmt::Display for TunError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Not Possible") - } -} - -/* Endpoint implementation */ - -#[derive(Clone, Copy)] -pub struct UnitEndpoint {} - -impl Endpoint for UnitEndpoint { - fn from_address(_: SocketAddr) -> UnitEndpoint { - UnitEndpoint {} - } - - fn into_address(&self) -> SocketAddr { - "127.0.0.1:8080".parse().unwrap() - } - - fn clear_src(&mut self) {} -} - -impl UnitEndpoint { - pub fn new() -> UnitEndpoint { - UnitEndpoint {} - } -} - -/* */ - -pub struct TunTest {} - -pub struct TunFakeIO { - store: bool, - tx: SyncSender>, - rx: Receiver>, -} - -pub struct TunReader { - rx: Receiver>, -} - -pub struct TunWriter { - store: bool, - tx: Mutex>>, -} - -#[derive(Clone)] -pub struct TunMTU { - mtu: Arc, -} - -impl tun::Reader for TunReader { - type Error = TunError; - - fn read(&self, buf: &mut [u8], offset: usize) -> Result { - match self.rx.recv() { - Ok(m) => { - buf[offset..].copy_from_slice(&m[..]); - Ok(m.len()) - } - Err(_) => Err(TunError::Disconnected), - } - } -} - -impl tun::Writer for TunWriter { - type Error = TunError; - - fn write(&self, src: &[u8]) -> Result<(), Self::Error> { - if self.store { - let m = src.to_owned(); - match self.tx.lock().unwrap().send(m) { - Ok(_) => Ok(()), - Err(_) => Err(TunError::Disconnected), - } - } else { - Ok(()) - } - } -} - -impl tun::MTU for TunMTU { - fn mtu(&self) -> usize { - self.mtu.load(Ordering::Acquire) - } -} - -impl tun::Tun for TunTest { - type Writer = TunWriter; - type Reader = TunReader; - type MTU = TunMTU; - type Error = TunError; -} - -impl TunFakeIO { - pub fn write(&self, msg: Vec) { - if self.store { - self.tx.send(msg).unwrap(); - } - } - - pub fn read(&self) -> Vec { - self.rx.recv().unwrap() - } -} - -impl TunTest { - pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { - let (tx1, rx1) = if store { - sync_channel(32) - } else { - sync_channel(1) - }; - let (tx2, rx2) = if store { - sync_channel(32) - } else { - sync_channel(1) - }; - - let fake = TunFakeIO { - tx: tx1, - rx: rx2, - store, - }; - let reader = TunReader { rx: rx1 }; - let writer = TunWriter { - tx: Mutex::new(tx2), - store, - }; - let mtu = TunMTU { - mtu: Arc::new(AtomicUsize::new(mtu)), - }; - - (fake, reader, writer, mtu) - } -} - -/* Void Bind */ - -#[derive(Clone, Copy)] -pub struct VoidBind {} - -impl bind::Reader for VoidBind { - type Error = BindError; - - fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { - Ok((0, UnitEndpoint {})) - } -} - -impl bind::Writer for VoidBind { - type Error = BindError; - - fn write(&self, _buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { - Ok(()) - } -} - -impl bind::Bind for VoidBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - - type Reader = VoidBind; - type Writer = VoidBind; -} - -impl VoidBind { - pub fn new() -> VoidBind { - VoidBind {} - } -} - -/* Pair Bind */ - -#[derive(Clone)] -pub struct PairReader { - recv: Arc>>>, - _marker: marker::PhantomData, -} - -impl bind::Reader for PairReader { - type Error = BindError; - fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { - let vec = self - .recv - .lock() - .unwrap() - .recv() - .map_err(|_| BindError::Disconnected)?; - let len = vec.len(); - buf[..len].copy_from_slice(&vec[..]); - Ok((vec.len(), UnitEndpoint {})) - } -} - -impl bind::Writer for PairWriter { - type Error = BindError; - fn write(&self, buf: &[u8], _dst: &UnitEndpoint) -> Result<(), Self::Error> { - let owned = buf.to_owned(); - match self.send.lock().unwrap().send(owned) { - Err(_) => Err(BindError::Disconnected), - Ok(_) => Ok(()), - } - } -} - -#[derive(Clone)] -pub struct PairWriter { - send: Arc>>>, - _marker: marker::PhantomData, -} - -#[derive(Clone)] -pub struct PairBind {} - -impl PairBind { - pub fn pair() -> ( - (PairReader, PairWriter), - (PairReader, PairWriter), - ) { - let (tx1, rx1) = sync_channel(128); - let (tx2, rx2) = sync_channel(128); - ( - ( - PairReader { - recv: Arc::new(Mutex::new(rx1)), - _marker: marker::PhantomData, - }, - PairWriter { - send: Arc::new(Mutex::new(tx2)), - _marker: marker::PhantomData, - }, - ), - ( - PairReader { - recv: Arc::new(Mutex::new(rx2)), - _marker: marker::PhantomData, - }, - PairWriter { - send: Arc::new(Mutex::new(tx1)), - _marker: marker::PhantomData, - }, - ), - ) - } -} - -impl bind::Bind for PairBind { - type Error = BindError; - type Endpoint = UnitEndpoint; - type Reader = PairReader; - type Writer = PairWriter; -} - -pub fn keypair(initiator: bool) -> KeyPair { - let k1 = Key { - key: [0x53u8; 32], - id: 0x646e6573, - }; - let k2 = Key { - key: [0x52u8; 32], - id: 0x76636572, - }; - if initiator { - KeyPair { - birth: Instant::now(), - initiator: true, - send: k1, - recv: k2, - } - } else { - KeyPair { - birth: Instant::now(), - initiator: false, - send: k2, - recv: k1, - } - } -} diff --git a/src/wireguard/types/endpoint.rs b/src/wireguard/types/endpoint.rs deleted file mode 100644 index 4702aab..0000000 --- a/src/wireguard/types/endpoint.rs +++ /dev/null @@ -1,7 +0,0 @@ -use std::net::SocketAddr; - -pub trait Endpoint: Send + 'static { - fn from_address(addr: SocketAddr) -> Self; - fn into_address(&self) -> SocketAddr; - fn clear_src(&mut self); -} diff --git a/src/wireguard/types/keys.rs b/src/wireguard/types/keys.rs deleted file mode 100644 index 282c4ae..0000000 --- a/src/wireguard/types/keys.rs +++ /dev/null @@ -1,36 +0,0 @@ -use clear_on_drop::clear::Clear; -use std::time::Instant; - -#[derive(Debug, Clone)] -pub struct Key { - pub key: [u8; 32], - pub id: u32, -} - -// zero key on drop -impl Drop for Key { - fn drop(&mut self) { - self.key.clear() - } -} - -#[cfg(test)] -impl PartialEq for Key { - fn eq(&self, other: &Self) -> bool { - self.id == other.id && self.key[..] == other.key[..] - } -} - -#[derive(Debug, Clone)] -pub struct KeyPair { - pub birth: Instant, // when was the key-pair created - pub initiator: bool, // has the key-pair been confirmed? - pub send: Key, // key for outbound messages - pub recv: Key, // key for inbound messages -} - -impl KeyPair { - pub fn local_id(&self) -> u32 { - self.recv.id - } -} diff --git a/src/wireguard/types/mod.rs b/src/wireguard/types/mod.rs deleted file mode 100644 index 20a1238..0000000 --- a/src/wireguard/types/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod endpoint; -mod keys; - -pub mod bind; -pub mod tun; - -#[cfg(test)] -pub mod dummy; - -pub use endpoint::Endpoint; -pub use keys::{Key, KeyPair}; diff --git a/src/wireguard/types/tun.rs b/src/wireguard/types/tun.rs deleted file mode 100644 index 2ba16ff..0000000 --- a/src/wireguard/types/tun.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::error::Error; - -pub trait Writer: Send + Sync + 'static { - type Error: Error; - - /// Receive a cryptkey routed IP packet - /// - /// # Arguments - /// - /// - src: Buffer containing the IP packet to be written - /// - /// # Returns - /// - /// Unit type or an error - fn write(&self, src: &[u8]) -> Result<(), Self::Error>; -} - -pub trait Reader: Send + 'static { - type Error: Error; - - /// Reads an IP packet into dst[offset:] from the tunnel device - /// - /// The reason for providing space for a prefix - /// is to efficiently accommodate platforms on which the packet is prefaced by a header. - /// This space is later used to construct the transport message inplace. - /// - /// # Arguments - /// - /// - buf: Destination buffer (enough space for MTU bytes + header) - /// - offset: Offset for the beginning of the IP packet - /// - /// # Returns - /// - /// The size of the IP packet (ignoring the header) or an std::error::Error instance: - fn read(&self, buf: &mut [u8], offset: usize) -> Result; -} - -pub trait MTU: Send + Sync + Clone + 'static { - /// Returns the MTU of the device - /// - /// This function needs to be efficient (called for every read). - /// The goto implementation strategy is to .load an atomic variable, - /// then use e.g. netlink to update the variable in a separate thread. - /// - /// # Returns - /// - /// The MTU of the interface in bytes - fn mtu(&self) -> usize; -} - -pub trait Tun: Send + Sync + 'static { - type Writer: Writer; - type Reader: Reader; - type MTU: MTU; - type Error: Error; -} diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 9bcac0a..96a134c 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -3,10 +3,10 @@ use super::handshake; use super::router; use super::timers::{Events, Timers}; -use super::types::bind::Reader as BindReader; -use super::types::bind::{Bind, Writer}; -use super::types::tun::{Reader, Tun, MTU}; -use super::types::Endpoint; +use super::bind::Reader as BindReader; +use super::bind::{Bind, Writer}; +use super::tun::{Reader, Tun, MTU}; +use super::Endpoint; use hjul::Runner; -- cgit v1.2.3-59-g8ed1b