aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-07 18:38:19 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2019-09-07 18:38:19 +0200
commit7b61ee4c2db87e195f5291fb1a3927648d38a2a4 (patch)
tree410c0609c3f4d1afbd0d87791b9156a538f59398 /src
parentAdded outbound benchmark (diff)
downloadwireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.tar.xz
wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.zip
Write inbound packets to TUN device
Diffstat (limited to 'src')
-rw-r--r--src/router/constants.rs5
-rw-r--r--src/router/device.rs169
-rw-r--r--src/router/ip.rs37
-rw-r--r--src/router/mod.rs1
-rw-r--r--src/router/peer.rs114
-rw-r--r--src/router/tests.rs2
-rw-r--r--src/router/types.rs4
-rw-r--r--src/router/workers.rs108
-rw-r--r--src/types/endpoint.rs4
9 files changed, 306 insertions, 138 deletions
diff --git a/src/router/constants.rs b/src/router/constants.rs
index b3015ed..0ca824a 100644
--- a/src/router/constants.rs
+++ b/src/router/constants.rs
@@ -1,2 +1,7 @@
+// WireGuard semantics constants
+
pub const MAX_STAGED_PACKETS: usize = 128;
+
+// performance constants
+
pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS;
diff --git a/src/router/device.rs b/src/router/device.rs
index 2196dd1..69304d8 100644
--- a/src/router/device.rs
+++ b/src/router/device.rs
@@ -10,8 +10,9 @@ use std::time::Instant;
use log::debug;
-use spin;
+use spin::{Mutex, RwLock};
use treebitmap::IpLookupTable;
+use zerocopy::LayoutVerified;
use super::super::types::{Bind, KeyPair, Tun};
@@ -20,23 +21,15 @@ use super::peer;
use super::peer::{Peer, PeerInner};
use super::SIZE_MESSAGE_PREFIX;
-use super::constants::WORKER_QUEUE_SIZE;
-use super::messages::TYPE_TRANSPORT;
-use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
-use super::workers::{worker_parallel, JobParallel};
-
-// minimum sizes for IP headers
-const SIZE_IP4_HEADER: usize = 16;
-const SIZE_IP6_HEADER: usize = 36;
+use super::constants::*;
+use super::ip::*;
-const VERSION_IP4: u8 = 4;
-const VERSION_IP6: u8 = 6;
-
-const OFFSET_IP4_DST: usize = 16;
-const OFFSET_IP6_DST: usize = 24;
+use super::messages::{TransportHeader, TYPE_TRANSPORT};
+use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError};
+use super::workers::{worker_parallel, JobParallel, Operation};
pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
- // IO & timer generics
+ // IO & timer callbacks
pub tun: T,
pub bind: B,
pub call_recv: C::CallbackRecv,
@@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> {
pub call_need_key: C::CallbackKey,
// routing
- pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state
- pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
- pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
+ pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state
+ pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing
+ pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing
}
pub struct EncryptionState {
@@ -57,19 +50,18 @@ pub struct EncryptionState {
}
pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> {
- pub key: [u8; 32],
- pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference
+ pub keypair: Arc<KeyPair>,
pub confirmed: AtomicBool,
- pub protector: spin::Mutex<AntiReplay>,
- pub peer: Weak<PeerInner<C, T, B>>,
+ pub protector: Mutex<AntiReplay>,
+ pub peer: Arc<PeerInner<C, T, B>>,
pub death: Instant, // time when the key can no longer be used for decryption
}
pub struct Device<C: Callbacks, T: Tun, B: Bind> {
- pub state: Arc<DeviceInner<C, T, B>>, // reference to device state
- pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers
- pub queue_next: AtomicUsize, // next round-robin index
- pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
+ state: Arc<DeviceInner<C, T, B>>, // reference to device state
+ handles: Vec<thread::JoinHandle<()>>, // join handles for workers
+ queue_next: AtomicUsize, // next round-robin index
+ queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread)
}
impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> {
@@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
call_recv,
call_send,
call_need_key,
- recv: spin::RwLock::new(HashMap::new()),
- ipv4: spin::RwLock::new(IpLookupTable::new()),
- ipv6: spin::RwLock::new(IpLookupTable::new()),
+ recv: RwLock::new(HashMap::new()),
+ ipv4: RwLock::new(IpLookupTable::new()),
+ ipv6: RwLock::new(IpLookupTable::new()),
});
// start worker threads
@@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
let mut threads = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE);
- queues.push(spin::Mutex::new(tx));
+ queues.push(Mutex::new(tx));
threads.push(thread::spawn(move || worker_parallel(rx)));
}
@@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi
}
}
+#[inline(always)]
+fn get_route<C: Callbacks, T: Tun, B: Bind>(
+ device: &Arc<DeviceInner<C, T, B>>,
+ packet: &[u8],
+) -> Option<Arc<PeerInner<C, T, B>>> {
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _) = LayoutVerified::new_from_prefix(packet)?;
+ let header: LayoutVerified<&[u8], IPv4Header> = header;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
+ let header: LayoutVerified<&[u8], IPv6Header> = header;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ _ => None,
+ }
+}
+
impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
/// Adds a new peer to the device
///
@@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
let packet = &msg[SIZE_MESSAGE_PREFIX..];
// lookup peer based on IP packet destination address
- let peer = match packet[0] >> 4 {
- VERSION_IP4 => {
- if msg.len() >= SIZE_IP4_HEADER {
- // extract IPv4 destination address
- let mut dst = [0u8; 4];
- dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]);
- let dst = Ipv4Addr::from(dst);
-
- // lookup peer (project unto and clone "value" field)
- self.state
- .ipv4
- .read()
- .longest_match(dst)
- .and_then(|(_, _, p)| p.upgrade())
- .ok_or(RouterError::NoCryptKeyRoute)
- } else {
- Err(RouterError::MalformedIPHeader)
- }
- }
- VERSION_IP6 => {
- if msg.len() >= SIZE_IP6_HEADER {
- // extract IPv6 destination address
- let mut dst = [0u8; 16];
- dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]);
- let dst = Ipv6Addr::from(dst);
-
- // lookup peer (project unto and clone "value" field)
- self.state
- .ipv6
- .read()
- .longest_match(dst)
- .and_then(|(_, _, p)| p.upgrade())
- .ok_or(RouterError::NoCryptKeyRoute)
- } else {
- Err(RouterError::MalformedIPHeader)
- }
- }
- _ => Err(RouterError::MalformedIPHeader),
- }?;
+ let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?;
// schedule for encryption and transmission to peer
if let Some(job) = peer.send_job(msg) {
+ debug_assert_eq!(job.1.op, Operation::Encryption);
+
// add job to worker queue
let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
self.queues[idx % self.queues.len()]
@@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> {
///
/// # Arguments
///
+ /// - src: Source address of the packet
/// - msg: Encrypted transport message
- pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> {
- // ensure that the type field access is within bounds
- if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT {
- return Err(RouterError::MalformedTransportMessage);
- }
-
+ ///
+ /// # Returns
+ ///
+ ///
+ pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> {
// parse / cast
+ let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) {
+ Some(v) => v,
+ None => {
+ return Err(RouterError::MalformedTransportMessage);
+ }
+ };
+ let header: LayoutVerified<&[u8], TransportHeader> = header;
+ debug_assert!(
+ header.f_type.get() == TYPE_TRANSPORT as u32,
+ "this should be checked by the message type multiplexer"
+ );
// lookup peer based on receiver id
+ let dec = self.state.recv.read();
+ let dec = dec
+ .get(&header.f_receiver.get())
+ .ok_or(RouterError::UnkownReceiverId)?;
+
+ // schedule for decryption and TUN write
+ if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) {
+ debug_assert_eq!(job.1.op, Operation::Decryption);
- unimplemented!();
+ // add job to worker queue
+ let idx = self.queue_next.fetch_add(1, Ordering::SeqCst);
+ self.queues[idx % self.queues.len()]
+ .lock()
+ .send(job)
+ .unwrap();
+ }
+
+ Ok(())
}
}
diff --git a/src/router/ip.rs b/src/router/ip.rs
new file mode 100644
index 0000000..6eb303c
--- /dev/null
+++ b/src/router/ip.rs
@@ -0,0 +1,37 @@
+use byteorder::BigEndian;
+use zerocopy::byteorder::U16;
+use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified};
+
+pub const SIZE_IP4_HEADER: usize = 16;
+pub const SIZE_IP6_HEADER: usize = 36;
+
+pub const VERSION_IP4: u8 = 4;
+pub const VERSION_IP6: u8 = 6;
+
+pub const OFFSET_IP4_SRC: usize = 12;
+pub const OFFSET_IP6_SRC: usize = 8;
+
+pub const OFFSET_IP4_DST: usize = 16;
+pub const OFFSET_IP6_DST: usize = 24;
+
+pub const TYPE_TRANSPORT: u8 = 4;
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv4Header {
+ _f_space1: [u8; 2],
+ pub f_total_len: U16<BigEndian>,
+ _f_space2: [u8; 8],
+ pub f_source: [u8; 4],
+ pub f_destination: [u8; 4],
+}
+
+#[repr(packed)]
+#[derive(Copy, Clone, FromBytes, AsBytes)]
+pub struct IPv6Header {
+ _f_pre: [u8; 4],
+ pub f_len: U16<BigEndian>,
+ _f_space2: [u8; 2],
+ pub f_source: [u8; 16],
+ pub f_destination: [u8; 16],
+}
diff --git a/src/router/mod.rs b/src/router/mod.rs
index ec560b4..883c875 100644
--- a/src/router/mod.rs
+++ b/src/router/mod.rs
@@ -1,6 +1,7 @@
mod anti_replay;
mod constants;
mod device;
+mod ip;
mod messages;
mod peer;
mod types;
diff --git a/src/router/peer.rs b/src/router/peer.rs
index 634f980..0cd588d 100644
--- a/src/router/peer.rs
+++ b/src/router/peer.rs
@@ -30,7 +30,7 @@ use super::workers::Operation;
use super::workers::{worker_inbound, worker_outbound};
use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel};
-use super::constants::MAX_STAGED_PACKETS;
+use super::constants::*;
use super::types::Callbacks;
pub struct KeyWheel {
@@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> {
pub tx_bytes: AtomicU64, // transmitted bytes
pub keys: Mutex<KeyWheel>, // key-wheel
pub ekey: Mutex<Option<EncryptionState>>, // encryption state
- pub endpoint: Mutex<Option<Arc<SocketAddr>>>,
+ pub endpoint: Mutex<Option<B::Endpoint>>,
}
pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
@@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> {
fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>(
peer: &Arc<PeerInner<C, T, B>>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
callback: Box<dyn Fn(A, u32) -> E>,
) -> Vec<E>
where
@@ -70,10 +70,8 @@ where
let mut res = Vec::new();
for subnet in table.read().iter() {
let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer) {
- res.push(callback(ip, masklen))
- }
+ if Arc::ptr_eq(&p, &peer) {
+ res.push(callback(ip, masklen))
}
}
res
@@ -81,7 +79,7 @@ where
fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
peer: &Peer<C, T, B>,
- table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>,
+ table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>,
) {
let mut m = table.write();
@@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
let mut subnets = vec![];
for subnet in m.iter() {
let (ip, masklen, p) = subnet;
- if let Some(p) = p.upgrade() {
- if Arc::ptr_eq(&p, &peer.state) {
- subnets.push((ip, masklen))
- }
+ if Arc::ptr_eq(&p, &peer.state) {
+ subnets.push((ip, masklen))
}
}
@@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>(
}
}
+impl EncryptionState {
+ fn new(keypair: &Arc<KeyPair>) -> EncryptionState {
+ EncryptionState {
+ id: keypair.send.id,
+ key: keypair.send.key,
+ nonce: 0,
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
+impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> {
+ fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> {
+ DecryptionState {
+ confirmed: AtomicBool::new(keypair.initiator),
+ keypair: keypair.clone(),
+ protector: spin::Mutex::new(AntiReplay::new()),
+ peer: peer.clone(),
+ death: keypair.birth + REJECT_AFTER_TIME,
+ }
+ }
+}
+
impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> {
fn drop(&mut self) {
let peer = &self.state;
@@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>(
}
impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
- pub fn confirm_key(&self, kp: Weak<KeyPair>) {
- // upgrade key-pair to strong reference
+ pub fn confirm_key(&self, keypair: &Arc<KeyPair>) {
+ // take lock and check keypair = keys.next
+ let mut keys = self.keys.lock();
+ let next = match keys.next.as_ref() {
+ Some(next) => next,
+ None => {
+ return;
+ }
+ };
+ if !Arc::ptr_eq(&next, keypair) {
+ return;
+ }
- // check it is the new unconfirmed key
+ // allocate new encryption state
+ let ekey = Some(EncryptionState::new(&next));
// rotate key-wheel
+ let mut swap = None;
+ mem::swap(&mut keys.next, &mut swap);
+ mem::swap(&mut keys.current, &mut swap);
+ mem::swap(&mut keys.previous, &mut swap);
+
+ // set new encryption key
+ *self.ekey.lock() = ekey;
+ }
+
+ pub fn recv_job(
+ &self,
+ src: B::Endpoint,
+ dec: Arc<DecryptionState<C, T, B>>,
+ mut msg: Vec<u8>,
+ ) -> Option<JobParallel> {
+ let (tx, rx) = oneshot();
+ let key = dec.keypair.send.key;
+ match self.inbound.lock().try_send((dec, src, rx)) {
+ Ok(_) => Some((
+ tx,
+ JobBuffer {
+ msg,
+ key: key,
+ okay: false,
+ op: Operation::Decryption,
+ },
+ )),
+ Err(_) => None,
+ }
}
pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> {
@@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> {
impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
pub fn set_endpoint(&self, endpoint: SocketAddr) {
- *self.state.endpoint.lock() = Some(Arc::new(endpoint))
+ *self.state.endpoint.lock() = Some(endpoint.into());
}
/// Add a new keypair
@@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
// update key-wheel
if new.initiator {
// start using key for encryption
- *self.state.ekey.lock() = Some(EncryptionState {
- id: new.send.id,
- key: new.send.key,
- nonce: 0,
- death: new.birth + REJECT_AFTER_TIME,
- });
+ *self.state.ekey.lock() = Some(EncryptionState::new(&new));
// move current into previous
keys.previous = keys.current.as_ref().map(|v| v.clone());
@@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
recv.remove(&id);
}
- // map new id to keypair
+ // map new id to decryption state
debug_assert!(!recv.contains_key(&new.recv.id));
-
recv.insert(
new.recv.id,
- DecryptionState {
- confirmed: AtomicBool::new(new.initiator),
- keypair: Arc::downgrade(&new),
- key: new.recv.key,
- protector: spin::Mutex::new(AntiReplay::new()),
- peer: Arc::downgrade(&self.state),
- death: new.birth + REJECT_AFTER_TIME,
- },
+ Arc::new(DecryptionState::new(&self.state, &new)),
);
}
@@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> {
.device
.ipv4
.write()
- .insert(v4, masklen, Arc::downgrade(&self.state))
+ .insert(v4, masklen, self.state.clone())
}
IpAddr::V6(v6) => {
self.state
.device
.ipv6
.write()
- .insert(v6, masklen, Arc::downgrade(&self.state))
+ .insert(v6, masklen, self.state.clone())
}
};
}
diff --git a/src/router/tests.rs b/src/router/tests.rs
index 5463532..7fe2b7a 100644
--- a/src/router/tests.rs
+++ b/src/router/tests.rs
@@ -156,6 +156,8 @@ mod tests {
#[bench]
fn bench_outbound(b: &mut Bencher) {
+ init();
+
// type for tracking number of packets
type Opaque = Arc<AtomicU64>;
diff --git a/src/router/types.rs b/src/router/types.rs
index 336f56b..7706997 100644
--- a/src/router/types.rs
+++ b/src/router/types.rs
@@ -57,6 +57,7 @@ pub enum RouterError {
NoCryptKeyRoute,
MalformedIPHeader,
MalformedTransportMessage,
+ UnkownReceiverId,
}
impl fmt::Display for RouterError {
@@ -65,6 +66,9 @@ impl fmt::Display for RouterError {
RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"),
RouterError::MalformedIPHeader => write!(f, "IP header is malformed"),
RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"),
+ RouterError::UnkownReceiverId => {
+ write!(f, "No decryption state associated with receiver id")
+ }
}
}
}
diff --git a/src/router/workers.rs b/src/router/workers.rs
index b18b038..45e1058 100644
--- a/src/router/workers.rs
+++ b/src/router/workers.rs
@@ -1,6 +1,6 @@
use std::mem;
use std::sync::mpsc::Receiver;
-use std::sync::{Arc, Weak};
+use std::sync::Arc;
use futures::sync::oneshot;
use futures::*;
@@ -8,15 +8,17 @@ use futures::*;
use log::debug;
use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305};
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::net::{Ipv4Addr, Ipv6Addr};
+use std::sync::atomic::Ordering;
use zerocopy::{AsBytes, LayoutVerified};
-use super::device::DecryptionState;
-use super::device::DeviceInner;
+use super::device::{DecryptionState, DeviceInner};
use super::messages::TransportHeader;
use super::peer::PeerInner;
use super::types::Callbacks;
+use super::ip::*;
+
use super::super::types::{Bind, Tun};
#[derive(PartialEq, Debug)]
@@ -33,9 +35,60 @@ pub struct JobBuffer {
}
pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer);
-pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, oneshot::Receiver<JobBuffer>);
+pub type JobInbound<C, T, B: Bind> = (
+ Arc<DecryptionState<C, T, B>>,
+ B::Endpoint,
+ oneshot::Receiver<JobBuffer>,
+);
pub type JobOutbound = oneshot::Receiver<JobBuffer>;
+#[inline(always)]
+fn check_route<C: Callbacks, T: Tun, B: Bind>(
+ device: &Arc<DeviceInner<C, T, B>>,
+ peer: &Arc<PeerInner<C, T, B>>,
+ packet: &[u8],
+) -> Option<usize> {
+ match packet[0] >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _) = LayoutVerified::new_from_prefix(packet)?;
+ let header: LayoutVerified<&[u8], IPv4Header> = header;
+
+ // check IPv4 source address
+ device
+ .ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ Some(header.f_total_len.get() as usize)
+ } else {
+ None
+ }
+ })
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, packet) = LayoutVerified::new_from_prefix(packet)?;
+ let header: LayoutVerified<&[u8], IPv6Header> = header;
+
+ // check IPv6 source address
+ device
+ .ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, &peer) {
+ Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
+ } else {
+ None
+ }
+ })
+ }
+ _ => None,
+ }
+}
+
pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
device: Arc<DeviceInner<C, T, B>>, // related device
peer: Arc<PeerInner<C, T, B>>, // related peer
@@ -43,7 +96,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
) {
loop {
// fetch job
- let (state, rx) = match receiver.recv() {
+ let (state, endpoint, rx) = match receiver.recv() {
Ok(v) => v,
_ => {
return;
@@ -62,13 +115,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
}
};
let header: LayoutVerified<&[u8], TransportHeader> = header;
-
- // obtain strong reference to decryption state
- let state = if let Some(state) = state.upgrade() {
- state
- } else {
- return;
- };
+ debug_assert!(
+ packet.len() >= 16,
+ "this should be checked earlier in the pipeline"
+ );
// check for replay
if !state.protector.lock().update(header.f_counter.get()) {
@@ -77,23 +127,29 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>(
// check for confirms key
if !state.confirmed.swap(true, Ordering::SeqCst) {
- peer.confirm_key(state.keypair.clone());
+ peer.confirm_key(&state.keypair);
}
- // update endpoint, TODO
-
- // write packet to TUN device, TODO
+ // update endpoint
+ *peer.endpoint.lock() = Some(endpoint);
+
+ // calculate length of IP packet + padding
+ let length = packet.len() - CHACHA20_POLY1305.nonce_len();
+
+ // check if should be written to TUN
+ let mut sent = false;
+ if length > 0 {
+ if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) {
+ debug_assert!(inner_len <= length, "should be validated");
+ if inner_len <= length {
+ sent = true;
+ let _ = device.tun.write(&packet[..inner_len]);
+ }
+ }
+ }
// trigger callback
- debug_assert!(
- packet.len() >= CHACHA20_POLY1305.nonce_len(),
- "this should be checked earlier in the pipeline"
- );
- (device.call_recv)(
- &peer.opaque,
- packet.len() > CHACHA20_POLY1305.nonce_len(),
- true,
- );
+ (device.call_recv)(&peer.opaque, length == 0, sent);
}
})
.wait();
diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs
index 6bc99b9..8033080 100644
--- a/src/types/endpoint.rs
+++ b/src/types/endpoint.rs
@@ -1,5 +1,5 @@
use std::net::SocketAddr;
-pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> {}
+pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> + Send {}
-impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> {}
+impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> + Send {}