From b31becda71feace70f96043cd39bbe022a054225 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Sat, 14 Sep 2019 12:43:09 +0200 Subject: Begin work on the pure Wireguard implemenation Start joining the handshake device and router device in the top-level Wireguard implemenation. --- src/handshake/device.rs | 16 ++++++-- src/handshake/messages.rs | 6 +-- src/handshake/mod.rs | 1 + src/main.rs | 1 + src/router/device.rs | 76 +++++++++++++++--------------------- src/router/mod.rs | 4 ++ src/router/peer.rs | 2 +- src/router/tests.rs | 98 +++++++++++++++++++++-------------------------- src/router/types.rs | 33 +++------------- src/router/workers.rs | 4 +- src/types/tun.rs | 6 +-- src/types/udp.rs | 4 +- src/wireguard.rs | 75 ++++++++++++++++++++++++++++++++++++ 13 files changed, 182 insertions(+), 144 deletions(-) create mode 100644 src/wireguard.rs (limited to 'src') diff --git a/src/handshake/device.rs b/src/handshake/device.rs index cf88303..5396854 100644 --- a/src/handshake/device.rs +++ b/src/handshake/device.rs @@ -4,6 +4,8 @@ use std::net::SocketAddr; use std::sync::Mutex; use zerocopy::AsBytes; +use byteorder::{LittleEndian, ByteOrder}; + use rand::prelude::*; use x25519_dalek::PublicKey; @@ -206,8 +208,14 @@ where where &'a S: Into<&'a SocketAddr>, { - match msg.get(0) { - Some(&TYPE_INITIATION) => { + // ensure type read in-range + if msg.len() < 4 { + return Err(HandshakeError::InvalidMessageFormat); + } + + // de-multiplex the message type field + match LittleEndian::read_u32(msg) { + TYPE_INITIATION => { // parse message let msg = Initiation::parse(msg)?; @@ -267,7 +275,7 @@ where Some(keys), )) } - Some(&TYPE_RESPONSE) => { + TYPE_RESPONSE => { let msg = Response::parse(msg)?; // check mac1 field @@ -300,7 +308,7 @@ where // consume inner playload noise::consume_response(self, &msg.noise) } - Some(&TYPE_COOKIE_REPLY) => { + TYPE_COOKIE_REPLY => { let msg = CookieReply::parse(msg)?; // lookup peer diff --git a/src/handshake/messages.rs b/src/handshake/messages.rs index 07c2b1a..8611609 100644 --- a/src/handshake/messages.rs +++ b/src/handshake/messages.rs @@ -17,9 +17,9 @@ const SIZE_COOKIE: usize = 16; // const SIZE_X25519_POINT: usize = 32; // x25519 public key const SIZE_TIMESTAMP: usize = 12; -pub const TYPE_INITIATION: u8 = 1; -pub const TYPE_RESPONSE: u8 = 2; -pub const TYPE_COOKIE_REPLY: u8 = 3; +pub const TYPE_INITIATION: u32 = 1; +pub const TYPE_RESPONSE: u32 = 2; +pub const TYPE_COOKIE_REPLY: u32 = 3; /* Handshake messsages */ diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs index 8095147..8452de8 100644 --- a/src/handshake/mod.rs +++ b/src/handshake/mod.rs @@ -18,3 +18,4 @@ mod types; // publicly exposed interface pub use device::Device; +pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE }; diff --git a/src/main.rs b/src/main.rs index 53b2a51..103bc65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,5 +9,6 @@ mod constants; mod handshake; mod router; mod types; +mod wireguard; fn main() {} diff --git a/src/router/device.rs b/src/router/device.rs index 703fa55..e9e0fb3 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -17,7 +17,7 @@ use super::constants::*; use super::ip::*; use super::messages::{TransportHeader, TYPE_TRANSPORT}; use super::peer::{new_peer, Peer, PeerInner}; -use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; +use super::types::{Callbacks, Opaque, RouterError}; use super::workers::{worker_parallel, JobParallel, Operation}; use super::SIZE_MESSAGE_PREFIX; @@ -27,9 +27,6 @@ pub struct DeviceInner { // IO & timer callbacks pub tun: T, pub bind: B, - pub call_recv: C::CallbackRecv, - pub call_send: C::CallbackSend, - pub call_need_key: C::CallbackKey, // routing pub recv: RwLock>>>, // receiver id -> decryption state @@ -83,47 +80,6 @@ impl Drop for Device { } } -impl, S: Callback, K: KeyCallback, T: Tun, B: Bind> - Device, T, B> -{ - pub fn new( - num_workers: usize, - tun: T, - bind: B, - call_send: S, - call_recv: R, - call_need_key: K, - ) -> Device, T, B> { - // allocate shared device state - let mut inner = DeviceInner { - tun, - bind, - call_recv, - call_send, - queues: Mutex::new(Vec::with_capacity(num_workers)), - queue_next: AtomicUsize::new(0), - call_need_key, - recv: RwLock::new(HashMap::new()), - ipv4: RwLock::new(IpLookupTable::new()), - ipv6: RwLock::new(IpLookupTable::new()), - }; - - // start worker threads - let mut threads = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - inner.queues.lock().push(tx); - threads.push(thread::spawn(move || worker_parallel(rx))); - } - - // return exported device handle - Device { - state: Arc::new(inner), - handles: threads, - } - } -} - #[inline(always)] fn get_route( device: &Arc>, @@ -165,6 +121,34 @@ fn get_route( } impl Device { + + pub fn new(num_workers: usize, tun: T, bind: B) -> Device { + // allocate shared device state + let mut inner = DeviceInner { + tun, + bind, + queues: Mutex::new(Vec::with_capacity(num_workers)), + queue_next: AtomicUsize::new(0), + recv: RwLock::new(HashMap::new()), + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), + }; + + // start worker threads + let mut threads = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); + inner.queues.lock().push(tx); + threads.push(thread::spawn(move || worker_parallel(rx))); + } + + // return exported device handle + Device { + state: Arc::new(inner), + handles: threads, + } + } + /// Adds a new peer to the device /// /// # Returns @@ -228,7 +212,7 @@ impl Device { let dec = self.state.recv.read(); let dec = dec .get(&header.f_receiver.get()) - .ok_or(RouterError::UnkownReceiverId)?; + .ok_or(RouterError::UnknownReceiverId)?; // schedule for decryption and TUN write if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { diff --git a/src/router/mod.rs b/src/router/mod.rs index 8cd0d3b..7a29cd9 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -14,5 +14,9 @@ use messages::TransportHeader; use std::mem; pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); +pub const CAPACITY_MESSAGE_POSTFIX: usize = 16; + +pub use messages::TYPE_TRANSPORT; pub use device::Device; pub use peer::Peer; +pub use types::Callbacks; diff --git a/src/router/peer.rs b/src/router/peer.rs index 43317cc..f032f45 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -280,7 +280,7 @@ impl PeerInner { None => { // add to staged packets (create no job) debug!("execute callback: call_need_key"); - (self.device.call_need_key)(&self.opaque); + C::need_key(&self.opaque); self.staged_packets.lock().push_back(msg); return None; } diff --git a/src/router/tests.rs b/src/router/tests.rs index 80c3ea9..c2ea225 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -13,7 +13,7 @@ use pnet::packet::ipv4::MutableIpv4Packet; use pnet::packet::ipv6::MutableIpv6Packet; use super::super::types::{Bind, Key, KeyPair, Tun}; -use super::{Device, SIZE_MESSAGE_PREFIX}; +use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX}; extern crate test; @@ -82,6 +82,7 @@ impl Into for UnitEndpoint { } } +#[derive(Clone, Copy)] struct TunTest {} impl Tun for TunTest { @@ -102,6 +103,7 @@ impl Tun for TunTest { /* Bind implemenentations */ +#[derive(Clone, Copy)] struct VoidBind {} impl Bind for VoidBind { @@ -166,7 +168,7 @@ impl Bind for PairBind { Ok((vec.len(), UnitEndpoint {})) } - fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> { + fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> { let owned = buf.to_owned(); match self.send.lock().unwrap().send(owned) { Err(_) => Err(BindError::Disconnected), @@ -221,7 +223,7 @@ mod tests { use super::*; use env_logger; use log::debug; - use std::sync::atomic::{AtomicU64, AtomicUsize}; + use std::sync::atomic::AtomicUsize; use test::Bencher; // type for tracking events inside the router module @@ -234,6 +236,8 @@ mod tests { #[derive(Clone)] struct Opaque(Arc); + struct TestCallbacks(); + impl Opaque { fn new() -> Opaque { Opaque(Arc::new(Flags { @@ -269,16 +273,20 @@ mod tests { } } - fn callback_send(t: &Opaque, size: usize, data: bool, sent: bool) { - t.0.send.lock().unwrap().push((size, data, sent)) - } + impl Callbacks for TestCallbacks { + type Opaque = Opaque; - fn callback_recv(t: &Opaque, size: usize, data: bool, sent: bool) { - t.0.recv.lock().unwrap().push((size, data, sent)) - } + fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.send.lock().unwrap().push((size, data, sent)) + } + + fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) { + t.0.recv.lock().unwrap().push((size, data, sent)) + } - fn callback_need_key(t: &Opaque) { - t.0.need_key.lock().unwrap().push(()); + fn need_key(t: &Self::Opaque) { + t.0.need_key.lock().unwrap().push(()); + } } fn init() { @@ -306,19 +314,19 @@ mod tests { #[bench] fn bench_outbound(b: &mut Bencher) { - type Opaque = Arc; + struct BencherCallbacks {} + impl Callbacks for BencherCallbacks { + type Opaque = Arc; + fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) { + t.fetch_add(size, Ordering::SeqCst); + } + fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn need_key(_: &Self::Opaque) {} + } // create device - let router = Device::new( - num_cpus::get(), - TunTest {}, - VoidBind::new(), - |t: &Opaque, size: usize, _data: bool, _sent: bool| { - t.fetch_add(size, Ordering::SeqCst); - }, - |t: &Opaque, _size: usize, _data: bool, _sent: bool| {}, - |t: &Opaque| (), - ); + let router: Device = + Device::new(num_cpus::get(), TunTest {}, VoidBind::new()); // add new peer let opaque = Arc::new(AtomicUsize::new(0)); @@ -328,15 +336,15 @@ mod tests { // add subnet to peer let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20"); let mask: IpAddr = mask.parse().unwrap(); - let ip: IpAddr = ip.parse().unwrap(); + let ip1: IpAddr = ip.parse().unwrap(); peer.add_subnet(mask, len); - // every iteration sends 10 MB + // every iteration sends 50 GB b.iter(|| { opaque.store(0, Ordering::SeqCst); - while opaque.load(Ordering::Acquire) < 10 * 1024 { - let msg = make_packet(1024, ip); - router.send(msg).unwrap(); + let msg = make_packet(1024, ip1); + while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { + router.send(msg.to_vec()).unwrap(); } }); } @@ -346,14 +354,7 @@ mod tests { init(); // create device - let router = Device::new( - 1, - TunTest {}, - VoidBind::new(), - callback_send, - callback_recv, - callback_need_key, - ); + let router: Device = Device::new(1, TunTest {}, VoidBind::new()); let tests = vec![ ("192.168.1.0", 24, "192.168.1.20", true), @@ -447,7 +448,7 @@ mod tests { } fn wait() { - thread::sleep(Duration::from_millis(10)); + thread::sleep(Duration::from_millis(20)); } #[test] @@ -472,23 +473,9 @@ mod tests { // create matching devices - let router1 = Device::new( - 1, - TunTest {}, - bind1.clone(), - callback_send, - callback_recv, - callback_need_key, - ); - - let router2 = Device::new( - 1, - TunTest {}, - bind2.clone(), - callback_send, - callback_recv, - callback_need_key, - ); + let router1: Device = Device::new(1, TunTest {}, bind1.clone()); + + let router2: Device = Device::new(1, TunTest {}, bind2.clone()); // prepare opaque values for tracing callbacks @@ -514,6 +501,7 @@ mod tests { let (_mask, _len, ip, _okay) = p2; let msg = make_packet(1024, ip.parse().unwrap()); router2.send(msg).expect("failed to sent staged packet"); + wait(); assert!(opaq2.recv().is_none()); assert!( @@ -537,7 +525,7 @@ mod tests { assert!(opaq2.recv().is_none()); assert!(opaq2.need_key().is_none()); assert!(opaq2.is_empty()); - assert!(opaq1.is_empty(), "nothing should happend on peer1"); + assert!(opaq1.is_empty(), "nothing should happened on peer1"); // read confirming message received by the other end ("across the internet") let mut buf = vec![0u8; 2048]; @@ -551,7 +539,7 @@ mod tests { assert!(opaq1.need_key().is_none()); assert!(opaq1.is_empty()); assert!(peer1.get_endpoint().is_some()); - assert!(opaq2.is_empty(), "nothing should happend on peer2"); + assert!(opaq2.is_empty(), "nothing should happened on peer2"); // how that peer1 has an endpoint // route packets : peer1 -> peer2 diff --git a/src/router/types.rs b/src/router/types.rs index 61f1fe7..f9f867a 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -22,34 +22,11 @@ pub trait KeyCallback: Fn(&T) -> () + Sync + Send + 'static {} impl KeyCallback for F where F: Fn(&T) -> () + Sync + Send + 'static {} -pub trait Endpoint: Send + Sync {} - pub trait Callbacks: Send + Sync + 'static { type Opaque: Opaque; - type CallbackRecv: Callback; - type CallbackSend: Callback; - type CallbackKey: KeyCallback; -} - -/* Concrete implementation of "Callbacks", - * used to hide the constituent type parameters. - * - * This type is never instantiated. - */ -pub struct PhantomCallbacks, S: Callback, K: KeyCallback> { - _phantom_opaque: PhantomData, - _phantom_recv: PhantomData, - _phantom_send: PhantomData, - _phantom_key: PhantomData, -} - -impl, S: Callback, K: KeyCallback> Callbacks - for PhantomCallbacks -{ - type Opaque = O; - type CallbackRecv = R; - type CallbackSend = S; - type CallbackKey = K; + fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {} + fn need_key(_opaque: &Self::Opaque) {} } #[derive(Debug)] @@ -57,7 +34,7 @@ pub enum RouterError { NoCryptKeyRoute, MalformedIPHeader, MalformedTransportMessage, - UnkownReceiverId, + UnknownReceiverId, NoEndpoint, SendError, } @@ -68,7 +45,7 @@ impl fmt::Display for RouterError { RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), - RouterError::UnkownReceiverId => { + RouterError::UnknownReceiverId => { write!(f, "No decryption state associated with receiver id") } RouterError::NoEndpoint => write!(f, "No endpoint for peer"), diff --git a/src/router/workers.rs b/src/router/workers.rs index 5415e8c..6710816 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -167,7 +167,7 @@ pub fn worker_inbound( } // trigger callback - (device.call_recv)(&peer.opaque, buf.msg.len(), length == 0, sent); + C::recv(&peer.opaque, buf.msg.len(), length == 0, sent); } else { debug!("inbound worker: authentication failure") } @@ -210,7 +210,7 @@ pub fn worker_outbound( }; // trigger callback - (device.call_send)( + C::send( &peer.opaque, buf.msg.len(), buf.msg.len() > SIZE_TAG + mem::size_of::(), diff --git a/src/types/tun.rs b/src/types/tun.rs index b36089e..fc8044a 100644 --- a/src/types/tun.rs +++ b/src/types/tun.rs @@ -1,13 +1,13 @@ use std::error; -pub trait Tun: Send + Sync + 'static { +pub trait Tun: Send + Sync + Clone + 'static { type Error: error::Error; /// Returns the MTU of the device /// /// This function needs to be efficient (called for every read). - /// The goto implementation stragtegy is to .load an atomic variable, - /// then use e.g. netlink to update the variable in a seperate thread. + /// The goto implementation strategy is to .load an atomic variable, + /// then use e.g. netlink to update the variable in a separate thread. /// /// # Returns /// diff --git a/src/types/udp.rs b/src/types/udp.rs index 71d5a79..943bf94 100644 --- a/src/types/udp.rs +++ b/src/types/udp.rs @@ -3,8 +3,8 @@ use std::error; /* Often times an a file descriptor in an atomic might suffice. */ -pub trait Bind: Send + Sync + 'static { - type Error: error::Error; +pub trait Bind: Send + Sync + Clone + 'static { + type Error: error::Error + Send; type Endpoint: Endpoint; fn new() -> Self; diff --git a/src/wireguard.rs b/src/wireguard.rs new file mode 100644 index 0000000..0bd5da7 --- /dev/null +++ b/src/wireguard.rs @@ -0,0 +1,75 @@ +use crate::handshake; +use crate::router; +use crate::types::{Bind, Tun}; + +use byteorder::{ByteOrder, LittleEndian}; + +use std::thread; + +use x25519_dalek::StaticSecret; + +pub struct Timers {} + +pub struct Events(); + +impl router::Callbacks for Events { + type Opaque = Timers; + + fn send(t: &Timers, size: usize, data: bool, sent: bool) {} + + fn recv(t: &Timers, size: usize, data: bool, sent: bool) {} + + fn need_key(t: &Timers) {} +} + +pub struct Wireguard { + router: router::Device, + handshake: Option>, +} + +impl Wireguard { + fn new(tun: T, bind: B) -> Wireguard { + let router = router::Device::new(num_cpus::get(), tun.clone(), bind.clone()); + + // start UDP read IO thread + { + let tun = tun.clone(); + thread::spawn(move || { + loop { + // read UDP packet into vector + let size = tun.mtu() + 148; // maximum message size + let mut msg: Vec = + Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX); + msg.resize(size, 0); + let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error + msg.truncate(size); + + // message type de-multiplexer + if msg.len() < 4 { + continue; + } + match LittleEndian::read_u32(&msg[..]) { + handshake::TYPE_COOKIE_REPLY + | handshake::TYPE_INITIATION + | handshake::TYPE_RESPONSE => { + // handshake message + } + router::TYPE_TRANSPORT => { + // transport message + } + _ => (), + } + } + }); + } + + // start TUN read IO thread + + thread::spawn(move || {}); + + Wireguard { + router, + handshake: None, + } + } +} -- cgit v1.2.3-59-g8ed1b