aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/handshake/device.rs16
-rw-r--r--src/handshake/messages.rs6
-rw-r--r--src/handshake/mod.rs1
-rw-r--r--src/main.rs1
-rw-r--r--src/router/device.rs76
-rw-r--r--src/router/mod.rs4
-rw-r--r--src/router/peer.rs2
-rw-r--r--src/router/tests.rs98
-rw-r--r--src/router/types.rs33
-rw-r--r--src/router/workers.rs4
-rw-r--r--src/types/tun.rs6
-rw-r--r--src/types/udp.rs4
-rw-r--r--src/wireguard.rs75
13 files changed, 182 insertions, 144 deletions
diff --git a/src/handshake/device.rs b/src/handshake/device.rs
index cf88303..5396854 100644
--- a/src/handshake/device.rs
+++ b/src/handshake/device.rs
@@ -4,6 +4,8 @@ use std::net::SocketAddr;
use std::sync::Mutex;
use zerocopy::AsBytes;
+use byteorder::{LittleEndian, ByteOrder};
+
use rand::prelude::*;
use x25519_dalek::PublicKey;
@@ -206,8 +208,14 @@ where
where
&'a S: Into<&'a SocketAddr>,
{
- match msg.get(0) {
- Some(&TYPE_INITIATION) => {
+ // ensure type read in-range
+ if msg.len() < 4 {
+ return Err(HandshakeError::InvalidMessageFormat);
+ }
+
+ // de-multiplex the message type field
+ match LittleEndian::read_u32(msg) {
+ TYPE_INITIATION => {
// parse message
let msg = Initiation::parse(msg)?;
@@ -267,7 +275,7 @@ where
Some(keys),
))
}
- Some(&TYPE_RESPONSE) => {
+ TYPE_RESPONSE => {
let msg = Response::parse(msg)?;
// check mac1 field
@@ -300,7 +308,7 @@ where
// consume inner playload
noise::consume_response(self, &msg.noise)
}
- Some(&TYPE_COOKIE_REPLY) => {
+ TYPE_COOKIE_REPLY => {
let msg = CookieReply::parse(msg)?;
// lookup peer
diff --git a/src/handshake/messages.rs b/src/handshake/messages.rs
index 07c2b1a..8611609 100644
--- a/src/handshake/messages.rs
+++ b/src/handshake/messages.rs
@@ -17,9 +17,9 @@ const SIZE_COOKIE: usize = 16; //
const SIZE_X25519_POINT: usize = 32; // x25519 public key
const SIZE_TIMESTAMP: usize = 12;
-pub const TYPE_INITIATION: u8 = 1;
-pub const TYPE_RESPONSE: u8 = 2;
-pub const TYPE_COOKIE_REPLY: u8 = 3;
+pub const TYPE_INITIATION: u32 = 1;
+pub const TYPE_RESPONSE: u32 = 2;
+pub const TYPE_COOKIE_REPLY: u32 = 3;
/* Handshake messsages */
diff --git a/src/handshake/mod.rs b/src/handshake/mod.rs
index 8095147..8452de8 100644
--- a/src/handshake/mod.rs
+++ b/src/handshake/mod.rs
@@ -18,3 +18,4 @@ mod types;
// publicly exposed interface
pub use device::Device;
+pub use messages::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE };
diff --git a/src/main.rs b/src/main.rs
index 53b2a51..103bc65 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -9,5 +9,6 @@ mod constants;
mod handshake;
mod router;
mod types;
+mod wireguard;
fn main() {}
diff --git a/src/router/device.rs b/src/router/device.rs
index 703fa55..e9e0fb3 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -17,7 +17,7 @@ use super::constants::*;
use super::ip::*;
use super::messages::{TransportHeader, TYPE_TRANSPORT};
use super::peer::{new_peer, Peer, PeerInner};
-use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
+use super::types::{Callbacks, Opaque, RouterError};
use super::workers::{worker_parallel, JobParallel, Operation};
use super::SIZE_MESSAGE_PREFIX;
@@ -27,9 +27,6 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
// IO & timer callbacks
pub tun: T,
pub bind: B,
- pub call_recv: C::CallbackRecv,
- pub call_send: C::CallbackSend,
- pub call_need_key: C::CallbackKey,
// routing
pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
@@ -83,47 +80,6 @@ impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
}
}
-impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bind>
- Device<PhantomCallbacks<O, R, S, K>, T, B>
-{
- pub fn new(
- num_workers: usize,
- tun: T,
- bind: B,
- call_send: S,
- call_recv: R,
- call_need_key: K,
- ) -> Device<PhantomCallbacks<O, R, S, K>, T, B> {
- // allocate shared device state
- let mut inner = DeviceInner {
- tun,
- bind,
- call_recv,
- call_send,
- queues: Mutex::new(Vec::with_capacity(num_workers)),
- queue_next: AtomicUsize::new(0),
- call_need_key,
- recv: RwLock::new(HashMap::new()),
- ipv4: RwLock::new(IpLookupTable::new()),
- ipv6: RwLock::new(IpLookupTable::new()),
- };
-
- // start worker threads
- let mut threads = Vec::with_capacity(num_workers);
- for _ in 0..num_workers {
- let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- inner.queues.lock().push(tx);
- threads.push(thread::spawn(move || worker_parallel(rx)));
- }
-
- // return exported device handle
- Device {
- state: Arc::new(inner),
- handles: threads,
- }
- }
-}
-
#[inline(always)]
fn get_route<C: Callbacks, T: Tun, B: Bind>(
device: &Arc<DeviceInner<C, T, B>>,
@@ -165,6 +121,34 @@ fn get_route<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
+
+ pub fn new(num_workers: usize, tun: T, bind: B) -> Device<C, T, B> {
+ // allocate shared device state
+ let mut inner = DeviceInner {
+ tun,
+ bind,
+ queues: Mutex::new(Vec::with_capacity(num_workers)),
+ queue_next: AtomicUsize::new(0),
+ recv: RwLock::new(HashMap::new()),
+ ipv4: RwLock::new(IpLookupTable::new()),
+ ipv6: RwLock::new(IpLookupTable::new()),
+ };
+
+ // start worker threads
+ let mut threads = Vec::with_capacity(num_workers);
+ for _ in 0..num_workers {
+ let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
+ inner.queues.lock().push(tx);
+ threads.push(thread::spawn(move || worker_parallel(rx)));
+ }
+
+ // return exported device handle
+ Device {
+ state: Arc::new(inner),
+ handles: threads,
+ }
+ }
+
/// Adds a new peer to the device
///
/// # Returns
@@ -228,7 +212,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
let dec = self.state.recv.read();
let dec = dec
.get(&header.f_receiver.get())
- .ok_or(RouterError::UnkownReceiverId)?;
+ .ok_or(RouterError::UnknownReceiverId)?;
// schedule for decryption and TUN write
if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
diff --git a/src/router/mod.rs b/src/router/mod.rs
index 8cd0d3b..7a29cd9 100644
--- a/src/router/mod.rs
+++ b/src/router/mod.rs
@@ -14,5 +14,9 @@ use messages::TransportHeader;
use std::mem;
pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::<TransportHeader>();
+pub const CAPACITY_MESSAGE_POSTFIX: usize = 16;
+
+pub use messages::TYPE_TRANSPORT;
pub use device::Device;
pub use peer::Peer;
+pub use types::Callbacks;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 43317cc..f032f45 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -280,7 +280,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
None => {
// add to staged packets (create no job)
debug!("execute callback: call_need_key");
- (self.device.call_need_key)(&self.opaque);
+ C::need_key(&self.opaque);
self.staged_packets.lock().push_back(msg);
return None;
}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index 80c3ea9..c2ea225 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -13,7 +13,7 @@ use pnet::packet::ipv4::MutableIpv4Packet;
use pnet::packet::ipv6::MutableIpv6Packet;
use super::super::types::{Bind, Key, KeyPair, Tun};
-use super::{Device, SIZE_MESSAGE_PREFIX};
+use super::{Callbacks, Device, SIZE_MESSAGE_PREFIX};
extern crate test;
@@ -82,6 +82,7 @@ impl Into<SocketAddr> for UnitEndpoint {
}
}
+#[derive(Clone, Copy)]
struct TunTest {}
impl Tun for TunTest {
@@ -102,6 +103,7 @@ impl Tun for TunTest {
/* Bind implemenentations */
+#[derive(Clone, Copy)]
struct VoidBind {}
impl Bind for VoidBind {
@@ -166,7 +168,7 @@ impl Bind for PairBind {
Ok((vec.len(), UnitEndpoint {}))
}
- fn send(&self, buf: &[u8], dst: &Self::Endpoint) -> Result<(), Self::Error> {
+ fn send(&self, buf: &[u8], _dst: &Self::Endpoint) -> Result<(), Self::Error> {
let owned = buf.to_owned();
match self.send.lock().unwrap().send(owned) {
Err(_) => Err(BindError::Disconnected),
@@ -221,7 +223,7 @@ mod tests {
use super::*;
use env_logger;
use log::debug;
- use std::sync::atomic::{AtomicU64, AtomicUsize};
+ use std::sync::atomic::AtomicUsize;
use test::Bencher;
// type for tracking events inside the router module
@@ -234,6 +236,8 @@ mod tests {
#[derive(Clone)]
struct Opaque(Arc<Flags>);
+ struct TestCallbacks();
+
impl Opaque {
fn new() -> Opaque {
Opaque(Arc::new(Flags {
@@ -269,16 +273,20 @@ mod tests {
}
}
- fn callback_send(t: &Opaque, size: usize, data: bool, sent: bool) {
- t.0.send.lock().unwrap().push((size, data, sent))
- }
+ impl Callbacks for TestCallbacks {
+ type Opaque = Opaque;
- fn callback_recv(t: &Opaque, size: usize, data: bool, sent: bool) {
- t.0.recv.lock().unwrap().push((size, data, sent))
- }
+ fn send(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.send.lock().unwrap().push((size, data, sent))
+ }
+
+ fn recv(t: &Self::Opaque, size: usize, data: bool, sent: bool) {
+ t.0.recv.lock().unwrap().push((size, data, sent))
+ }
- fn callback_need_key(t: &Opaque) {
- t.0.need_key.lock().unwrap().push(());
+ fn need_key(t: &Self::Opaque) {
+ t.0.need_key.lock().unwrap().push(());
+ }
}
fn init() {
@@ -306,19 +314,19 @@ mod tests {
#[bench]
fn bench_outbound(b: &mut Bencher) {
- type Opaque = Arc<AtomicUsize>;
+ struct BencherCallbacks {}
+ impl Callbacks for BencherCallbacks {
+ type Opaque = Arc<AtomicUsize>;
+ fn send(t: &Self::Opaque, size: usize, _data: bool, _sent: bool) {
+ t.fetch_add(size, Ordering::SeqCst);
+ }
+ fn recv(_: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn need_key(_: &Self::Opaque) {}
+ }
// create device
- let router = Device::new(
- num_cpus::get(),
- TunTest {},
- VoidBind::new(),
- |t: &Opaque, size: usize, _data: bool, _sent: bool| {
- t.fetch_add(size, Ordering::SeqCst);
- },
- |t: &Opaque, _size: usize, _data: bool, _sent: bool| {},
- |t: &Opaque| (),
- );
+ let router: Device<BencherCallbacks, TunTest, VoidBind> =
+ Device::new(num_cpus::get(), TunTest {}, VoidBind::new());
// add new peer
let opaque = Arc::new(AtomicUsize::new(0));
@@ -328,15 +336,15 @@ mod tests {
// add subnet to peer
let (mask, len, ip) = ("192.168.1.0", 24, "192.168.1.20");
let mask: IpAddr = mask.parse().unwrap();
- let ip: IpAddr = ip.parse().unwrap();
+ let ip1: IpAddr = ip.parse().unwrap();
peer.add_subnet(mask, len);
- // every iteration sends 10 MB
+ // every iteration sends 50 GB
b.iter(|| {
opaque.store(0, Ordering::SeqCst);
- while opaque.load(Ordering::Acquire) < 10 * 1024 {
- let msg = make_packet(1024, ip);
- router.send(msg).unwrap();
+ let msg = make_packet(1024, ip1);
+ while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 {
+ router.send(msg.to_vec()).unwrap();
}
});
}
@@ -346,14 +354,7 @@ mod tests {
init();
// create device
- let router = Device::new(
- 1,
- TunTest {},
- VoidBind::new(),
- callback_send,
- callback_recv,
- callback_need_key,
- );
+ let router: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, VoidBind::new());
let tests = vec![
("192.168.1.0", 24, "192.168.1.20", true),
@@ -447,7 +448,7 @@ mod tests {
}
fn wait() {
- thread::sleep(Duration::from_millis(10));
+ thread::sleep(Duration::from_millis(20));
}
#[test]
@@ -472,23 +473,9 @@ mod tests {
// create matching devices
- let router1 = Device::new(
- 1,
- TunTest {},
- bind1.clone(),
- callback_send,
- callback_recv,
- callback_need_key,
- );
-
- let router2 = Device::new(
- 1,
- TunTest {},
- bind2.clone(),
- callback_send,
- callback_recv,
- callback_need_key,
- );
+ let router1: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind1.clone());
+
+ let router2: Device<TestCallbacks, _, _> = Device::new(1, TunTest {}, bind2.clone());
// prepare opaque values for tracing callbacks
@@ -514,6 +501,7 @@ mod tests {
let (_mask, _len, ip, _okay) = p2;
let msg = make_packet(1024, ip.parse().unwrap());
router2.send(msg).expect("failed to sent staged packet");
+
wait();
assert!(opaq2.recv().is_none());
assert!(
@@ -537,7 +525,7 @@ mod tests {
assert!(opaq2.recv().is_none());
assert!(opaq2.need_key().is_none());
assert!(opaq2.is_empty());
- assert!(opaq1.is_empty(), "nothing should happend on peer1");
+ assert!(opaq1.is_empty(), "nothing should happened on peer1");
// read confirming message received by the other end ("across the internet")
let mut buf = vec![0u8; 2048];
@@ -551,7 +539,7 @@ mod tests {
assert!(opaq1.need_key().is_none());
assert!(opaq1.is_empty());
assert!(peer1.get_endpoint().is_some());
- assert!(opaq2.is_empty(), "nothing should happend on peer2");
+ assert!(opaq2.is_empty(), "nothing should happened on peer2");
// how that peer1 has an endpoint
// route packets : peer1 -> peer2
diff --git a/src/router/types.rs b/src/router/types.rs
index 61f1fe7..f9f867a 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -22,34 +22,11 @@ pub trait KeyCallback<T>: Fn(&T) -> () + Sync + Send + 'static {}
impl<T, F> KeyCallback<T> for F where F: Fn(&T) -> () + Sync + Send + 'static {}
-pub trait Endpoint: Send + Sync {}
-
pub trait Callbacks: Send + Sync + 'static {
type Opaque: Opaque;
- type CallbackRecv: Callback<Self::Opaque>;
- type CallbackSend: Callback<Self::Opaque>;
- type CallbackKey: KeyCallback<Self::Opaque>;
-}
-
-/* Concrete implementation of "Callbacks",
- * used to hide the constituent type parameters.
- *
- * This type is never instantiated.
- */
-pub struct PhantomCallbacks<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> {
- _phantom_opaque: PhantomData<O>,
- _phantom_recv: PhantomData<R>,
- _phantom_send: PhantomData<S>,
- _phantom_key: PhantomData<K>,
-}
-
-impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>> Callbacks
- for PhantomCallbacks<O, R, S, K>
-{
- type Opaque = O;
- type CallbackRecv = R;
- type CallbackSend = S;
- type CallbackKey = K;
+ fn send(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn recv(_opaque: &Self::Opaque, _size: usize, _data: bool, _sent: bool) {}
+ fn need_key(_opaque: &Self::Opaque) {}
}
#[derive(Debug)]
@@ -57,7 +34,7 @@ pub enum RouterError {
NoCryptKeyRoute,
MalformedIPHeader,
MalformedTransportMessage,
- UnkownReceiverId,
+ UnknownReceiverId,
NoEndpoint,
SendError,
}
@@ -68,7 +45,7 @@ impl fmt::Display for RouterError {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
- RouterError::UnkownReceiverId => {
+ RouterError::UnknownReceiverId => {
write!(f, "No decryption state associated with receiver id")
}
RouterError::NoEndpoint => write!(f, "No endpoint for peer"),
diff --git a/src/router/workers.rs b/src/router/workers.rs
index 5415e8c..6710816 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -167,7 +167,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
// trigger callback
- (device.call_recv)(&peer.opaque, buf.msg.len(), length == 0, sent);
+ C::recv(&peer.opaque, buf.msg.len(), length == 0, sent);
} else {
debug!("inbound worker: authentication failure")
}
@@ -210,7 +210,7 @@ pub fn worker_outbound<C: Callbacks, T: Tun, B: Bind>(
};
// trigger callback
- (device.call_send)(
+ C::send(
&peer.opaque,
buf.msg.len(),
buf.msg.len() > SIZE_TAG + mem::size_of::<TransportHeader>(),
diff --git a/src/types/tun.rs b/src/types/tun.rs
index b36089e..fc8044a 100644
--- a/src/types/tun.rs
+++ b/src/types/tun.rs
@@ -1,13 +1,13 @@
use std::error;
-pub trait Tun: Send + Sync + 'static {
+pub trait Tun: Send + Sync + Clone + 'static {
type Error: error::Error;
/// Returns the MTU of the device
///
/// This function needs to be efficient (called for every read).
- /// The goto implementation stragtegy is to .load an atomic variable,
- /// then use e.g. netlink to update the variable in a seperate thread.
+ /// The goto implementation strategy is to .load an atomic variable,
+ /// then use e.g. netlink to update the variable in a separate thread.
///
/// # Returns
///
diff --git a/src/types/udp.rs b/src/types/udp.rs
index 71d5a79..943bf94 100644
--- a/src/types/udp.rs
+++ b/src/types/udp.rs
@@ -3,8 +3,8 @@ use std::error;
/* Often times an a file descriptor in an atomic might suffice.
*/
-pub trait Bind: Send + Sync + 'static {
- type Error: error::Error;
+pub trait Bind: Send + Sync + Clone + 'static {
+ type Error: error::Error + Send;
type Endpoint: Endpoint;
fn new() -> Self;
diff --git a/src/wireguard.rs b/src/wireguard.rs
new file mode 100644
index 0000000..0bd5da7
--- /dev/null
+++ b/src/wireguard.rs
@@ -0,0 +1,75 @@
+use crate::handshake;
+use crate::router;
+use crate::types::{Bind, Tun};
+
+use byteorder::{ByteOrder, LittleEndian};
+
+use std::thread;
+
+use x25519_dalek::StaticSecret;
+
+pub struct Timers {}
+
+pub struct Events();
+
+impl router::Callbacks for Events {
+ type Opaque = Timers;
+
+ fn send(t: &Timers, size: usize, data: bool, sent: bool) {}
+
+ fn recv(t: &Timers, size: usize, data: bool, sent: bool) {}
+
+ fn need_key(t: &Timers) {}
+}
+
+pub struct Wireguard<T: Tun, B: Bind> {
+ router: router::Device<Events, T, B>,
+ handshake: Option<handshake::Device<()>>,
+}
+
+impl<T: Tun, B: Bind> Wireguard<T, B> {
+ fn new(tun: T, bind: B) -> Wireguard<T, B> {
+ let router = router::Device::new(num_cpus::get(), tun.clone(), bind.clone());
+
+ // start UDP read IO thread
+ {
+ let tun = tun.clone();
+ thread::spawn(move || {
+ loop {
+ // read UDP packet into vector
+ let size = tun.mtu() + 148; // maximum message size
+ let mut msg: Vec<u8> =
+ Vec::with_capacity(size + router::CAPACITY_MESSAGE_POSTFIX);
+ msg.resize(size, 0);
+ let (size, src) = bind.recv(&mut msg).unwrap(); // TODO handle error
+ msg.truncate(size);
+
+ // message type de-multiplexer
+ if msg.len() < 4 {
+ continue;
+ }
+ match LittleEndian::read_u32(&msg[..]) {
+ handshake::TYPE_COOKIE_REPLY
+ | handshake::TYPE_INITIATION
+ | handshake::TYPE_RESPONSE => {
+ // handshake message
+ }
+ router::TYPE_TRANSPORT => {
+ // transport message
+ }
+ _ => (),
+ }
+ }
+ });
+ }
+
+ // start TUN read IO thread
+
+ thread::spawn(move || {});
+
+ Wireguard {
+ router,
+ handshake: None,
+ }
+ }
+}