summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-17 16:31:08 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-08-17 16:31:08 +0200
commit78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a (patch)
tree75106e1ff89a03a6869184994b902a70315dfc30
parentBegin drafting cross-platform interface (diff)
downloadwireguard-rs-78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a.tar.xz
wireguard-rs-78ab1a93e6d519bf404fbe61fc7ec3c3ab35a72a.zip
Remove peer from cryptkey router on drop
-rw-r--r--src/constants.rs11
-rw-r--r--src/main.rs2
-rw-r--r--src/platform/tun.rs10
-rw-r--r--src/platform/udp.rs11
-rw-r--r--src/router/device.rs182
-rw-r--r--src/types/keys.rs26
-rw-r--r--src/types/mod.rs31
-rw-r--r--src/types/tun.rs43
-rw-r--r--src/types/udp.rs26
9 files changed, 242 insertions, 100 deletions
diff --git a/src/constants.rs b/src/constants.rs
new file mode 100644
index 0000000..829deac
--- /dev/null
+++ b/src/constants.rs
@@ -0,0 +1,11 @@
+use std::time::Duration;
+use std::u64;
+
+pub const REKEY_AFTER_MESSAGES: u64 = u64::MAX - (1 << 16);
+pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
+
+pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
+pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
+pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
+pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
+pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);
diff --git a/src/main.rs b/src/main.rs
index 22e1585..82a4b0c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,7 @@
#![feature(test)]
+mod constants;
mod handshake;
-mod platform;
mod router;
mod types;
diff --git a/src/platform/tun.rs b/src/platform/tun.rs
deleted file mode 100644
index 45fd591..0000000
--- a/src/platform/tun.rs
+++ /dev/null
@@ -1,10 +0,0 @@
-use std::sync::atomic::AtomicUsize;
-use std::sync::Arc;
-
-pub trait Tun: Send + Sync {
- type Error;
-
- fn new(mtu: Arc<AtomicUsize>) -> Self;
- fn read(&self, dst: &mut [u8]) -> Result<usize, Self::Error>;
- fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
-}
diff --git a/src/platform/udp.rs b/src/platform/udp.rs
deleted file mode 100644
index f21a3d3..0000000
--- a/src/platform/udp.rs
+++ /dev/null
@@ -1,11 +0,0 @@
-/* Often times an a file descriptor in an atomic might suffice.
- */
-pub trait Bind<Endpoint>: Send + Sync {
- type Error;
-
- fn new() -> Self;
- fn set_port(&self, port: u16) -> Result<(), Self::Error>;
- fn get_port(&self) -> u16;
- fn recv(&self, dst: &mut [u8]) -> Endpoint;
- fn send(&self, src: &[u8], dst: &Endpoint);
-}
diff --git a/src/router/device.rs b/src/router/device.rs
index 5dfd22c..4dd6539 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -1,37 +1,38 @@
use arraydeque::{ArrayDeque, Wrapping};
+use treebitmap::address::Address;
use treebitmap::IpLookupTable;
use crossbeam_deque::{Injector, Steal};
use std::collections::HashMap;
-use std::mem;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
-use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
+use std::sync::mpsc::SyncSender;
use std::sync::{Arc, Mutex, Weak};
use std::thread;
-use std::time::{Duration, Instant};
+use std::time::Instant;
use spin;
+use super::super::constants::*;
use super::super::types::KeyPair;
use super::anti_replay::AntiReplay;
use std::u64;
-const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);
const MAX_STAGED_PACKETS: usize = 128;
struct DeviceInner {
stopped: AtomicBool,
- injector: Injector<()>, // parallel enc/dec task injector
- threads: Vec<thread::JoinHandle<()>>,
- recv: spin::RwLock<HashMap<u32, DecryptionState>>,
- ipv4: IpLookupTable<Ipv4Addr, Weak<PeerInner>>,
- ipv6: IpLookupTable<Ipv6Addr, Weak<PeerInner>>,
+ injector: Injector<()>, // parallel enc/dec task injector
+ threads: Vec<thread::JoinHandle<()>>, // join handles of worker threads
+ recv: spin::RwLock<HashMap<u32, DecryptionState>>, // receiver id -> decryption state
+ ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner>>>, // ipv4 cryptkey routing
+ ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner>>>, // ipv6 cryptkey routing
}
struct PeerInner {
stopped: AtomicBool,
+ device: Arc<DeviceInner>,
thread_outbound: spin::Mutex<thread::JoinHandle<()>>,
thread_inbound: spin::Mutex<thread::JoinHandle<()>>,
inorder_outbound: SyncSender<()>,
@@ -40,7 +41,7 @@ struct PeerInner {
rx_bytes: AtomicU64, // received bytes
tx_bytes: AtomicU64, // transmitted bytes
keys: spin::Mutex<KeyWheel>, // key-wheel
- ekey: spin::Mutex<EncryptionState>, // encryption state
+ ekey: spin::Mutex<Option<EncryptionState>>, // encryption state
endpoint: spin::Mutex<Option<Arc<SocketAddr>>>,
}
@@ -68,26 +69,104 @@ struct KeyWheel {
pub struct Peer(Arc<PeerInner>);
pub struct Device(DeviceInner);
+fn treebit_list<A, R>(
+ peer: &Peer,
+ table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>,
+ callback: Box<dyn Fn(A, u32) -> R>,
+) -> Vec<R>
+where
+ A: Address,
+{
+ let mut res = Vec::new();
+ for subnet in table.read().iter() {
+ let (ip, masklen, p) = subnet;
+ if let Some(p) = p.upgrade() {
+ if Arc::ptr_eq(&p, &peer.0) {
+ res.push(callback(ip, masklen))
+ }
+ }
+ }
+ res
+}
+
+fn treebit_remove<A>(peer: &Peer, table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner>>>)
+where
+ A: Address,
+{
+ let mut m = table.write();
+
+ // collect keys for value
+ let mut subnets = vec![];
+ for subnet in m.iter() {
+ let (ip, masklen, p) = subnet;
+ if let Some(p) = p.upgrade() {
+ if Arc::ptr_eq(&p, &peer.0) {
+ subnets.push((ip, masklen))
+ }
+ }
+ }
+
+ // remove all key mappings
+ for subnet in subnets {
+ let r = m.remove(subnet.0, subnet.1);
+ debug_assert!(r.is_some());
+ }
+}
+
impl Drop for Peer {
fn drop(&mut self) {
// mark peer as stopped
- let inner = &self.0;
- inner.stopped.store(true, Ordering::SeqCst);
+ let peer = &self.0;
+ peer.stopped.store(true, Ordering::SeqCst);
+
+ // remove from cryptkey router
+ treebit_remove(self, &peer.device.ipv4);
+ treebit_remove(self, &peer.device.ipv6);
+
+ // unpark threads
+
+ peer.thread_inbound.lock().thread().unpark();
+ peer.thread_outbound.lock().thread().unpark();
+ // collect ids to release
+ let mut keys = peer.keys.lock();
+ let mut release = Vec::with_capacity(3);
+
+ keys.next.map(|k| release.push(k.recv.id));
+ keys.current.map(|k| release.push(k.recv.id));
+ keys.previous.map(|k| release.push(k.recv.id));
+
+ // remove from receive id map
+ if release.len() > 0 {
+ let mut recv = peer.device.recv.write();
+ for id in &release {
+ recv.remove(id);
+ }
+ }
+
+ // null key-material (TODO: extend)
- // unpark threads to stop
- inner.thread_inbound.lock().thread().unpark();
- inner.thread_outbound.lock().thread().unpark();
+ keys.next = None;
+ keys.current = None;
+ keys.previous = None;
+
+ *peer.ekey.lock() = None;
+ *peer.endpoint.lock() = None;
}
}
impl Drop for Device {
fn drop(&mut self) {
// mark device as stopped
- let inner = &self.0;
- inner.stopped.store(true, Ordering::SeqCst);
+ let device = &self.0;
+ device.stopped.store(true, Ordering::SeqCst);
// eat all parallel jobs
- while inner.injector.steal() != Steal::Empty {}
+ while device.injector.steal() != Steal::Empty {}
+
+ // unpark all threads
+ for handle in &device.threads {
+ handle.thread().unpark();
+ }
}
}
@@ -97,12 +176,12 @@ impl Peer {
}
pub fn keypair_confirm(&self, ks: Arc<KeyPair>) {
- *self.0.ekey.lock() = EncryptionState {
+ *self.0.ekey.lock() = Some(EncryptionState {
id: ks.send.id,
key: ks.send.key,
nonce: 0,
- death: ks.birth + Duration::from_millis(1337), // todo
- };
+ death: ks.birth + REJECT_AFTER_TIME,
+ });
}
fn keypair_add(&self, new: KeyPair) -> Option<u32> {
@@ -112,12 +191,12 @@ impl Peer {
// update key-wheel
if new.confirmed {
// start using key for encryption
- *self.0.ekey.lock() = EncryptionState {
+ *self.0.ekey.lock() = Some(EncryptionState {
id: new.send.id,
key: new.send.key,
nonce: 0,
- death: new.birth + Duration::from_millis(1337), // todo
- };
+ death: new.birth + REJECT_AFTER_TIME,
+ });
// move current into previous
keys.previous = keys.current;
@@ -148,42 +227,39 @@ impl Device {
stopped: AtomicBool::new(false),
injector: Injector::new(),
recv: spin::RwLock::new(HashMap::new()),
- ipv4: IpLookupTable::new(),
- ipv6: IpLookupTable::new(),
+ ipv4: spin::RwLock::new(IpLookupTable::new()),
+ ipv6: spin::RwLock::new(IpLookupTable::new()),
})
}
pub fn add_subnet(&mut self, ip: IpAddr, masklen: u32, peer: Peer) {
match ip {
- IpAddr::V4(v4) => self.0.ipv4.insert(v4, masklen, Arc::downgrade(&peer.0)),
- IpAddr::V6(v6) => self.0.ipv6.insert(v6, masklen, Arc::downgrade(&peer.0)),
+ IpAddr::V4(v4) => self
+ .0
+ .ipv4
+ .write()
+ .insert(v4, masklen, Arc::downgrade(&peer.0)),
+ IpAddr::V6(v6) => self
+ .0
+ .ipv6
+ .write()
+ .insert(v6, masklen, Arc::downgrade(&peer.0)),
};
}
- pub fn subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> {
- let mut subnets = Vec::new();
-
- // extract ipv4 entries
- for subnet in self.0.ipv4.iter() {
- let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer.0) {
- subnets.push((IpAddr::V4(ip), masklen))
- }
- }
- }
-
- // extract ipv6 entries
- for subnet in self.0.ipv6.iter() {
- let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer.0) {
- subnets.push((IpAddr::V6(ip), masklen))
- }
- }
- }
-
- subnets
+ pub fn list_subnets(&self, peer: Peer) -> Vec<(IpAddr, u32)> {
+ let mut res = Vec::new();
+ res.append(&mut treebit_list(
+ &peer,
+ &self.0.ipv4,
+ Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)),
+ ));
+ res.append(&mut treebit_list(
+ &peer,
+ &self.0.ipv6,
+ Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)),
+ ));
+ res
}
pub fn keypair_add(&self, peer: Peer, new: KeyPair) -> Option<u32> {
@@ -208,7 +284,7 @@ impl Device {
key: new.recv.key,
protector: Arc::new(spin::Mutex::new(AntiReplay::new())),
peer: Arc::downgrade(&peer.0),
- death: new.birth + Duration::from_millis(2600), // todo
+ death: new.birth + REJECT_AFTER_TIME,
},
);
diff --git a/src/types/keys.rs b/src/types/keys.rs
new file mode 100644
index 0000000..0b52d18
--- /dev/null
+++ b/src/types/keys.rs
@@ -0,0 +1,26 @@
+use std::time::Instant;
+
+/* This file holds types passed between components.
+ * Whenever a type cannot be held local to a single module.
+ */
+
+#[derive(Debug, Clone, Copy)]
+pub struct Key {
+ pub key: [u8; 32],
+ pub id: u32,
+}
+
+#[cfg(test)]
+impl PartialEq for Key {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id && self.key[..] == other.key[..]
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+pub struct KeyPair {
+ pub birth: Instant, // when was the key-pair created
+ pub confirmed: bool, // has the key-pair been confirmed?
+ pub send: Key, // key for outbound messages
+ pub recv: Key, // key for inbound messages
+} \ No newline at end of file
diff --git a/src/types/mod.rs b/src/types/mod.rs
index ea7c570..868fb71 100644
--- a/src/types/mod.rs
+++ b/src/types/mod.rs
@@ -1,26 +1,7 @@
-use std::time::Instant;
+mod keys;
+mod tun;
+mod udp;
-/* This file holds types passed between components.
- * Whenever a type cannot be held local to a single module.
- */
-
-#[derive(Debug, Clone, Copy)]
-pub struct Key {
- pub key: [u8; 32],
- pub id: u32,
-}
-
-#[cfg(test)]
-impl PartialEq for Key {
- fn eq(&self, other: &Self) -> bool {
- self.id == other.id && self.key[..] == other.key[..]
- }
-}
-
-#[derive(Debug, Clone, Copy)]
-pub struct KeyPair {
- pub birth: Instant, // when was the key-pair created
- pub confirmed: bool, // has the key-pair been confirmed?
- pub send: Key, // key for outbound messages
- pub recv: Key, // key for inbound messages
-}
+pub use keys::{Key, KeyPair};
+pub use tun::Tun;
+pub use udp::Bind; \ No newline at end of file
diff --git a/src/types/tun.rs b/src/types/tun.rs
new file mode 100644
index 0000000..72caa71
--- /dev/null
+++ b/src/types/tun.rs
@@ -0,0 +1,43 @@
+use std::error;
+
+pub trait Tun: Send + Sync {
+ type Error: error::Error;
+
+ /// Returns the MTU of the device
+ ///
+ /// This function needs to be efficient (called for every read).
+ /// The goto implementation stragtegy is to .load an atomic variable,
+ /// then use e.g. netlink to update the variable in a seperate thread.
+ ///
+ /// # Returns
+ ///
+ /// The MTU of the interface in bytes
+ fn mtu(&self) -> usize;
+
+ /// Reads an IP packet into dst[offset:] from the tunnel device
+ ///
+ /// The reason for providing space for a prefix
+ /// is to efficiently accommodate platforms on which the packet is prefaced by a header.
+ /// This space is later used to construct the transport message inplace.
+ ///
+ /// # Arguments
+ ///
+ /// - dst: Destination buffer (enough space for MTU bytes + header)
+ /// - offset: Offset for the beginning of the IP packet
+ ///
+ /// # Returns
+ ///
+ /// The size of the IP packet (ignoring the header) or an std::error::Error instance:
+ fn read(&self, dst: &mut [u8], offset: usize) -> Result<usize, Self::Error>;
+
+ /// Writes an IP packet to the tunnel device
+ ///
+ /// # Arguments
+ ///
+ /// - src: Buffer containing the IP packet to be written
+ ///
+ /// # Returns
+ ///
+ /// Unit type or an error
+ fn write(&self, src: &[u8]) -> Result<(), Self::Error>;
+}
diff --git a/src/types/udp.rs b/src/types/udp.rs
new file mode 100644
index 0000000..f45cf85
--- /dev/null
+++ b/src/types/udp.rs
@@ -0,0 +1,26 @@
+use std::error;
+
+/* Often times an a file descriptor in an atomic might suffice.
+ */
+pub trait Bind<Endpoint>: Send + Sync {
+ type Error : error::Error;
+
+ fn new() -> Self;
+
+ /// Updates the port of the Bind
+ ///
+ /// # Arguments
+ ///
+ /// - port, The new port to bind to. 0 means any available port.
+ ///
+ /// # Returns
+ ///
+ /// The unit type or an error, if binding fails
+ fn set_port(&self, port: u16) -> Result<(), Self::Error>;
+
+ /// Returns the current port of the bind
+ fn get_port(&self) -> u16;
+
+ fn recv(&self, dst: &mut [u8]) -> Endpoint;
+ fn send(&self, src: &[u8], dst: &Endpoint);
+}