diff options
author | 2019-10-23 10:32:18 +0200 | |
---|---|---|
committer | 2019-10-23 10:32:18 +0200 | |
commit | 3fa928b3158ce33a57e9ba2c1913485eb409ff4b (patch) | |
tree | 87562a6c84bf421a19d01ef153927f3f8315cf98 | |
parent | Work on porting timer semantics and linux platform (diff) | |
download | wireguard-rs-3fa928b3158ce33a57e9ba2c1913485eb409ff4b.tar.xz wireguard-rs-3fa928b3158ce33a57e9ba2c1913485eb409ff4b.zip |
Work on platform specific code (Linux)
-rw-r--r-- | src/configuration/config.rs (renamed from src/wireguard/config.rs) | 102 | ||||
-rw-r--r-- | src/configuration/mod.rs | 5 | ||||
-rw-r--r-- | src/main.rs | 6 | ||||
-rw-r--r-- | src/platform/dummy.rs | 22 | ||||
-rw-r--r-- | src/platform/linux/mod.rs | 3 | ||||
-rw-r--r-- | src/platform/linux/tun.rs | 39 | ||||
-rw-r--r-- | src/platform/linux/udp.rs | 27 | ||||
-rw-r--r-- | src/platform/mod.rs | 23 | ||||
-rw-r--r-- | src/wireguard/handshake/device.rs | 6 | ||||
-rw-r--r-- | src/wireguard/mod.rs | 13 | ||||
-rw-r--r-- | src/wireguard/types/dummy.rs | 54 | ||||
-rw-r--r-- | src/wireguard/types/endpoint.rs | 2 | ||||
-rw-r--r-- | src/wireguard/types/mod.rs | 5 | ||||
-rw-r--r-- | src/wireguard/wireguard.rs | 50 |
14 files changed, 277 insertions, 80 deletions
diff --git a/src/wireguard/config.rs b/src/configuration/config.rs index 0f2953d..24b1349 100644 --- a/src/wireguard/config.rs +++ b/src/configuration/config.rs @@ -1,9 +1,11 @@ +use spin::Mutex; use std::net::{IpAddr, SocketAddr}; use x25519_dalek::{PublicKey, StaticSecret}; -use super::wireguard::Wireguard; -use super::types::bind::Bind; -use super::types::tun::Tun; +use super::BindOwner; +use super::PlatformBind; +use super::Tun; +use super::Wireguard; /// The goal of the configuration interface is, among others, /// to hide the IO implementations (over which the WG device is generic), @@ -19,15 +21,28 @@ pub struct PeerState { allowed_ips: Vec<(IpAddr, u32)>, } +struct UDPState<O: BindOwner> { + fwmark: Option<u32>, + owner: O, + port: u16, +} + +pub struct WireguardConfig<T: Tun, B: PlatformBind> { + wireguard: Wireguard<T, B>, + network: Mutex<Option<UDPState<B::Owner>>>, +} + pub enum ConfigError { - NoSuchPeer + NoSuchPeer, + NotListening, } impl ConfigError { - fn errno(&self) -> i32 { + // TODO: obtain the correct error values match self { NoSuchPeer => 1, + NotListening => 2, } } } @@ -122,7 +137,11 @@ pub trait Configuration { /// /// - `peer': The public key of the peer /// - `psk` - fn set_persistent_keepalive_interval(&self, peer: PublicKey) -> Option<ConfigError>; + fn set_persistent_keepalive_interval( + &self, + peer: PublicKey, + interval: usize, + ) -> Option<ConfigError>; /// Remove all allowed IPs from the peer /// @@ -161,26 +180,81 @@ pub trait Configuration { fn get_peers(&self) -> Vec<PeerState>; } -impl <T : Tun, B : Bind>Configuration for Wireguard<T, B> { - - fn set_private_key(&self, sk : Option<StaticSecret>) { - self.set_key(sk) +impl<T: Tun, B: PlatformBind> Configuration for WireguardConfig<T, B> { + fn set_private_key(&self, sk: Option<StaticSecret>) { + self.wireguard.set_key(sk) } fn get_private_key(&self) -> Option<StaticSecret> { - self.get_sk() + self.wireguard.get_sk() } fn get_protocol_version(&self) -> usize { 1 } - fn set_listen_port(&self, port : u16) -> Option<ConfigError> { + fn set_listen_port(&self, port: u16) -> Option<ConfigError> { + let mut udp = self.network.lock(); + + // close the current listener + *udp = None; + None } - + fn set_fwmark(&self, mark: Option<u32>) -> Option<ConfigError> { + match self.network.lock().as_mut() { + Some(mut bind) => { + // there is a active bind + // set the fwmark (the IO operation) + bind.owner.set_fwmark(mark).unwrap(); // TODO: handle + + // update stored value + bind.fwmark = mark; + None + } + None => Some(ConfigError::NotListening), + } + } + + fn replace_peers(&self) { + self.wireguard.clear_peers(); + } + + fn remove_peer(&self, peer: PublicKey) { + self.wireguard.remove_peer(peer); + } + + fn add_peer(&self, peer: PublicKey) -> bool { + self.wireguard.new_peer(peer); + false + } + + fn set_preshared_key(&self, peer: PublicKey, psk: Option<[u8; 32]>) -> Option<ConfigError> { + None + } + + fn set_endpoint(&self, peer: PublicKey, addr: SocketAddr) -> Option<ConfigError> { + None + } + + fn set_persistent_keepalive_interval( + &self, + peer: PublicKey, + interval: usize, + ) -> Option<ConfigError> { None } -}
\ No newline at end of file + fn replace_allowed_ips(&self, peer: PublicKey) -> Option<ConfigError> { + None + } + + fn add_allowed_ip(&self, peer: PublicKey, ip: IpAddr, masklen: u32) -> Option<ConfigError> { + None + } + + fn get_peers(&self) -> Vec<PeerState> { + vec![] + } +} diff --git a/src/configuration/mod.rs b/src/configuration/mod.rs new file mode 100644 index 0000000..56a83e2 --- /dev/null +++ b/src/configuration/mod.rs @@ -0,0 +1,5 @@ +mod config; + +use super::platform::{BindOwner, PlatformBind}; +use super::wireguard::tun::Tun; +use super::wireguard::Wireguard; diff --git a/src/main.rs b/src/main.rs index aad8d02..4dac3cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,12 +6,12 @@ extern crate jemallocator; #[global_allocator] static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; -// mod config; +mod configuration; mod platform; mod wireguard; -use platform::TunBind; +use platform::PlatformTun; fn main() { - let (readers, writers, mtu) = platform::PlatformTun::create("test").unwrap(); + let (readers, writer, mtu) = platform::TunInstance::create("test").unwrap(); } diff --git a/src/platform/dummy.rs b/src/platform/dummy.rs new file mode 100644 index 0000000..208febe --- /dev/null +++ b/src/platform/dummy.rs @@ -0,0 +1,22 @@ +#[cfg(test)] +use super::super::wireguard::dummy; +use super::BindOwner; +use super::PlatformBind; + +pub struct VoidOwner {} + +impl BindOwner for VoidOwner { + type Error = dummy::BindError; + + fn set_fwmark(&self, value: Option<u32>) -> Option<Self::Error> { + None + } +} + +impl PlatformBind for dummy::PairBind { + type Owner = VoidOwner; + + fn bind(_port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error> { + Err(dummy::BindError::Disconnected) + } +} diff --git a/src/platform/linux/mod.rs b/src/platform/linux/mod.rs index 7a456ad..7d6a61c 100644 --- a/src/platform/linux/mod.rs +++ b/src/platform/linux/mod.rs @@ -1,4 +1,5 @@ mod tun; mod udp; -pub use tun::PlatformTun; +pub use tun::LinuxTun; +pub use udp::LinuxBind; diff --git a/src/platform/linux/tun.rs b/src/platform/linux/tun.rs index 17390a1..5b7b105 100644 --- a/src/platform/linux/tun.rs +++ b/src/platform/linux/tun.rs @@ -1,6 +1,6 @@ use super::super::super::wireguard::tun::*; +use super::super::PlatformTun; use super::super::Tun; -use super::super::TunBind; use libc::*; @@ -32,13 +32,13 @@ struct Ifreq { _pad: [u8; 64], } -pub struct PlatformTun {} +pub struct LinuxTun {} -pub struct PlatformTunReader { +pub struct LinuxTunReader { fd: RawFd, } -pub struct PlatformTunWriter { +pub struct LinuxTunWriter { fd: RawFd, } @@ -46,7 +46,7 @@ pub struct PlatformTunWriter { * announcing an MTU update for the interface */ #[derive(Clone)] -pub struct PlatformTunMTU { +pub struct LinuxTunMTU { value: Arc<AtomicUsize>, } @@ -83,14 +83,14 @@ impl Error for LinuxTunError { } } -impl MTU for PlatformTunMTU { +impl MTU for LinuxTunMTU { #[inline(always)] fn mtu(&self) -> usize { self.value.load(Ordering::Relaxed) } } -impl Reader for PlatformTunReader { +impl Reader for LinuxTunReader { type Error = LinuxTunError; fn read(&self, buf: &mut [u8], offset: usize) -> Result<usize, Self::Error> { @@ -109,7 +109,7 @@ impl Reader for PlatformTunReader { } } -impl Writer for PlatformTunWriter { +impl Writer for LinuxTunWriter { type Error = LinuxTunError; fn write(&self, src: &[u8]) -> Result<(), Self::Error> { @@ -120,14 +120,14 @@ impl Writer for PlatformTunWriter { } } -impl Tun for PlatformTun { +impl Tun for LinuxTun { type Error = LinuxTunError; - type Reader = PlatformTunReader; - type Writer = PlatformTunWriter; - type MTU = PlatformTunMTU; + type Reader = LinuxTunReader; + type Writer = LinuxTunWriter; + type MTU = LinuxTunMTU; } -impl TunBind for PlatformTun { +impl PlatformTun for LinuxTun { fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error> { // construct request struct let mut req = Ifreq { @@ -157,10 +157,10 @@ impl TunBind for PlatformTun { // create PlatformTunMTU instance Ok(( - vec![PlatformTunReader { fd }], // TODO: enable multi-queue for Linux - PlatformTunWriter { fd }, - PlatformTunMTU { - value: Arc::new(AtomicUsize::new(1500)), + vec![LinuxTunReader { fd }], // TODO: enable multi-queue for Linux + LinuxTunWriter { fd }, + LinuxTunMTU { + value: Arc::new(AtomicUsize::new(1500)), // TODO: fetch and update }, )) } @@ -174,7 +174,7 @@ mod tests { fn is_root() -> bool { match env::var("USER") { Ok(val) => val == "root", - Err(e) => false, + Err(_) => false, } } @@ -183,6 +183,7 @@ mod tests { if !is_root() { return; } - let (readers, writers, mtu) = PlatformTun::create("test").unwrap(); + let (readers, writers, mtu) = LinuxTun::create("test").unwrap(); + // TODO: test (any good idea how?) } } diff --git a/src/platform/linux/udp.rs b/src/platform/linux/udp.rs index e69de29..0a1a186 100644 --- a/src/platform/linux/udp.rs +++ b/src/platform/linux/udp.rs @@ -0,0 +1,27 @@ +use super::super::Bind; +use super::super::Endpoint; +use super::super::PlatformBind; + +use std::net::SocketAddr; + +pub struct LinuxEndpoint {} + +pub struct LinuxBind {} + +impl Endpoint for LinuxEndpoint { + fn clear_src(&mut self) {} + + fn from_address(addr: SocketAddr) -> Self { + LinuxEndpoint {} + } + + fn into_address(&self) -> SocketAddr { + "127.0.0.1:6060".parse().unwrap() + } +} + +/* +impl Bind for PlatformBind { + type Endpoint = PlatformEndpoint; +} +*/ diff --git a/src/platform/mod.rs b/src/platform/mod.rs index de33714..a0bbc13 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -2,21 +2,32 @@ use std::error::Error; use super::wireguard::bind::Bind; use super::wireguard::tun::Tun; +use super::wireguard::Endpoint; + +#[cfg(test)] +mod dummy; #[cfg(target_os = "linux")] mod linux; #[cfg(target_os = "linux")] -pub use linux::PlatformTun; +pub use linux::LinuxTun as TunInstance; + +pub trait BindOwner: Send { + type Error: Error; + + fn set_fwmark(&self, value: Option<u32>) -> Option<Self::Error>; +} -pub trait UDPBind: Bind { - type Closer; +pub trait PlatformBind: Bind { + type Owner: BindOwner; /// Bind to a new port, returning the reader/writer and - /// an associated instance of the Closer type, which closes the UDP socket upon "drop". - fn bind(port: u16) -> Result<(Self::Reader, Self::Writer, Self::Closer), Self::Error>; + /// an associated instance of the owner type, which closes the UDP socket upon "drop" + /// and enables configuration of the fwmark value. + fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error>; } -pub trait TunBind: Tun { +pub trait PlatformTun: Tun { fn create(name: &str) -> Result<(Vec<Self::Reader>, Self::Writer, Self::MTU), Self::Error>; } diff --git a/src/wireguard/handshake/device.rs b/src/wireguard/handshake/device.rs index 6a55f6e..c2e3a6e 100644 --- a/src/wireguard/handshake/device.rs +++ b/src/wireguard/handshake/device.rs @@ -77,9 +77,9 @@ impl Device { } /// Return the secret key of the device - /// + /// /// # Returns - /// + /// /// A secret key (x25519 scalar) pub fn get_sk(&self) -> StaticSecret { StaticSecret::from(self.sk.to_bytes()) @@ -95,7 +95,7 @@ impl Device { pub fn add(&mut self, pk: PublicKey) -> Result<(), ConfigError> { // check that the pk is not added twice if let Some(_) = self.pk_map.get(pk.as_bytes()) { - return Err(ConfigError::new("Duplicate public key")); + return Ok(()); }; // check that the pk is not that of the device diff --git a/src/wireguard/mod.rs b/src/wireguard/mod.rs index 9417e57..563a22f 100644 --- a/src/wireguard/mod.rs +++ b/src/wireguard/mod.rs @@ -1,7 +1,6 @@ -mod wireguard; -// mod config; mod constants; mod timers; +mod wireguard; mod handshake; mod router; @@ -12,12 +11,14 @@ mod tests; /// The WireGuard sub-module contains a pure, configurable implementation of WireGuard. /// The implementation is generic over: -/// +/// /// - TUN type, specifying how packets are received on the interface side: a reader/writer and MTU reporting interface. /// - Bind type, specifying how WireGuard messages are sent/received from the internet and what constitutes an "endpoint" - -pub use wireguard::{Wireguard, Peer}; +pub use wireguard::{Peer, Wireguard}; pub use types::bind; pub use types::tun; -pub use types::Endpoint;
\ No newline at end of file +pub use types::Endpoint; + +#[cfg(test)] +pub use types::dummy; diff --git a/src/wireguard/types/dummy.rs b/src/wireguard/types/dummy.rs index 2403c9b..384f123 100644 --- a/src/wireguard/types/dummy.rs +++ b/src/wireguard/types/dummy.rs @@ -2,11 +2,11 @@ use std::error::Error; use std::fmt; use std::marker; use std::net::SocketAddr; +use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; use std::sync::Arc; use std::sync::Mutex; use std::time::Instant; -use std::sync::atomic::{Ordering, AtomicUsize}; use super::*; @@ -43,7 +43,7 @@ impl fmt::Display for BindError { #[derive(Debug)] pub enum TunError { - Disconnected + Disconnected, } impl Error for TunError { @@ -76,7 +76,7 @@ impl Endpoint for UnitEndpoint { "127.0.0.1:8080".parse().unwrap() } - fn clear_src(&self) {} + fn clear_src(&mut self) {} } impl UnitEndpoint { @@ -92,21 +92,21 @@ pub struct TunTest {} pub struct TunFakeIO { store: bool, tx: SyncSender<Vec<u8>>, - rx: Receiver<Vec<u8>> + rx: Receiver<Vec<u8>>, } pub struct TunReader { - rx: Receiver<Vec<u8>> + rx: Receiver<Vec<u8>>, } pub struct TunWriter { store: bool, - tx: Mutex<SyncSender<Vec<u8>>> + tx: Mutex<SyncSender<Vec<u8>>>, } #[derive(Clone)] pub struct TunMTU { - mtu: Arc<AtomicUsize> + mtu: Arc<AtomicUsize>, } impl tun::Reader for TunReader { @@ -118,7 +118,7 @@ impl tun::Reader for TunReader { buf[offset..].copy_from_slice(&m[..]); Ok(m.len()) } - Err(_) => Err(TunError::Disconnected) + Err(_) => Err(TunError::Disconnected), } } } @@ -131,7 +131,7 @@ impl tun::Writer for TunWriter { let m = src.to_owned(); match self.tx.lock().unwrap().send(m) { Ok(_) => Ok(()), - Err(_) => Err(TunError::Disconnected) + Err(_) => Err(TunError::Disconnected), } } else { Ok(()) @@ -153,7 +153,7 @@ impl tun::Tun for TunTest { } impl TunFakeIO { - pub fn write(&self, msg : Vec<u8>) { + pub fn write(&self, msg: Vec<u8>) { if self.store { self.tx.send(msg).unwrap(); } @@ -165,15 +165,31 @@ impl TunFakeIO { } impl TunTest { - pub fn create(mtu : usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { - - let (tx1, rx1) = if store { sync_channel(32) } else { sync_channel(1) }; - let (tx2, rx2) = if store { sync_channel(32) } else { sync_channel(1) }; - - let fake = TunFakeIO{tx: tx1, rx: rx2, store}; - let reader = TunReader{rx : rx1}; - let writer = TunWriter{tx : Mutex::new(tx2), store}; - let mtu = TunMTU{mtu : Arc::new(AtomicUsize::new(mtu))}; + pub fn create(mtu: usize, store: bool) -> (TunFakeIO, TunReader, TunWriter, TunMTU) { + let (tx1, rx1) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + let (tx2, rx2) = if store { + sync_channel(32) + } else { + sync_channel(1) + }; + + let fake = TunFakeIO { + tx: tx1, + rx: rx2, + store, + }; + let reader = TunReader { rx: rx1 }; + let writer = TunWriter { + tx: Mutex::new(tx2), + store, + }; + let mtu = TunMTU { + mtu: Arc::new(AtomicUsize::new(mtu)), + }; (fake, reader, writer, mtu) } diff --git a/src/wireguard/types/endpoint.rs b/src/wireguard/types/endpoint.rs index f4f93da..4702aab 100644 --- a/src/wireguard/types/endpoint.rs +++ b/src/wireguard/types/endpoint.rs @@ -3,5 +3,5 @@ use std::net::SocketAddr; pub trait Endpoint: Send + 'static { fn from_address(addr: SocketAddr) -> Self; fn into_address(&self) -> SocketAddr; - fn clear_src(&self); + fn clear_src(&mut self); } diff --git a/src/wireguard/types/mod.rs b/src/wireguard/types/mod.rs index e0725f3..20a1238 100644 --- a/src/wireguard/types/mod.rs +++ b/src/wireguard/types/mod.rs @@ -1,10 +1,11 @@ mod endpoint; mod keys; -pub mod tun; + pub mod bind; +pub mod tun; #[cfg(test)] pub mod dummy; pub use endpoint::Endpoint; -pub use keys::{Key, KeyPair};
\ No newline at end of file +pub use keys::{Key, KeyPair}; diff --git a/src/wireguard/wireguard.rs b/src/wireguard/wireguard.rs index 1363c27..9bcac0a 100644 --- a/src/wireguard/wireguard.rs +++ b/src/wireguard/wireguard.rs @@ -54,7 +54,7 @@ pub struct PeerInner<B: Bind> { pub handshake_queued: AtomicBool, pub queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, // handshake queue - pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. + pub pk: PublicKey, // DISCUSS: Change layout in handshake module (adopt pattern of router), to avoid this. TODO: remove pub timers: RwLock<Timers>, // } @@ -99,7 +99,7 @@ pub enum HandshakeJob<E> { New(PublicKey), } -struct WireguardInner<T: Tun, B: Bind> { +pub struct WireguardInner<T: Tun, B: Bind> { // provides access to the MTU value of the tun device // (otherwise owned solely by the router and a dedicated read IO thread) mtu: T::MTU, @@ -118,9 +118,21 @@ struct WireguardInner<T: Tun, B: Bind> { queue: Mutex<Sender<HandshakeJob<B::Endpoint>>>, } +#[derive(Clone)] +pub struct WireguardHandle<T: Tun, B: Bind> { + inner: Arc<WireguardInner<T, B>>, +} + +impl<T: Tun, B: Bind> Deref for WireguardHandle<T, B> { + type Target = Arc<WireguardInner<T, B>>; + fn deref(&self) -> &Self::Target { + &self.inner + } +} + pub struct Wireguard<T: Tun, B: Bind> { runner: Runner, - state: Arc<WireguardInner<T, B>>, + state: WireguardHandle<T, B>, } /* Returns the padded length of a message: @@ -146,6 +158,24 @@ const fn padding(size: usize, mtu: usize) -> usize { } impl<T: Tun, B: Bind> Wireguard<T, B> { + pub fn clear_peers(&self) { + self.state.peers.write().clear(); + } + + pub fn remove_peer(&self, pk: PublicKey) { + self.state.peers.write().remove(pk.as_bytes()); + } + + pub fn list_peers(&self) -> Vec<Peer<T, B>> { + let peers = self.state.peers.read(); + let mut list = Vec::with_capacity(peers.len()); + for (k, v) in peers.iter() { + debug_assert!(k == v.pk.as_bytes()); + list.push(v.clone()); + } + list + } + pub fn set_key(&self, sk: Option<StaticSecret>) { let mut handshake = self.state.handshake.write(); match sk { @@ -170,7 +200,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } } - pub fn new_peer(&self, pk: PublicKey) -> Peer<T, B> { + pub fn new_peer(&self, pk: PublicKey) { let state = Arc::new(PeerInner { pk, last_handshake: Mutex::new(SystemTime::UNIX_EPOCH), @@ -182,8 +212,13 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { timers: RwLock::new(Timers::dummy(&self.runner)), }); + // create a router peer let router = Arc::new(self.state.router.new_peer(state.clone())); + // add to the handshake device + self.state.handshake.write().device.add(pk).unwrap(); // TODO: handle adding of public key for interface + + // form WireGuard peer let peer = Peer { router, state }; /* The need for dummy timers arises from the chicken-egg @@ -193,7 +228,10 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { * TODO: Consider the ease of using atomic pointers instead. */ *peer.timers.write() = Timers::new(&self.runner, peer.clone()); - peer + + // finally, add the peer to the wireguard device + let mut peers = self.state.peers.write(); + peers.entry(*pk.as_bytes()).or_insert(peer); } /* Begin consuming messages from the reader. @@ -417,7 +455,7 @@ impl<T: Tun, B: Bind> Wireguard<T, B> { } Wireguard { - state: wg, + state: WireguardHandle { inner: wg }, runner: Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY), } } |