diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-07 18:38:19 +0200 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-09-07 18:38:19 +0200 |
commit | 7b61ee4c2db87e195f5291fb1a3927648d38a2a4 (patch) | |
tree | 410c0609c3f4d1afbd0d87791b9156a538f59398 /src | |
parent | Added outbound benchmark (diff) | |
download | wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.tar.xz wireguard-rs-7b61ee4c2db87e195f5291fb1a3927648d38a2a4.zip |
Write inbound packets to TUN device
Diffstat (limited to 'src')
-rw-r--r-- | src/router/constants.rs | 5 | ||||
-rw-r--r-- | src/router/device.rs | 169 | ||||
-rw-r--r-- | src/router/ip.rs | 37 | ||||
-rw-r--r-- | src/router/mod.rs | 1 | ||||
-rw-r--r-- | src/router/peer.rs | 114 | ||||
-rw-r--r-- | src/router/tests.rs | 2 | ||||
-rw-r--r-- | src/router/types.rs | 4 | ||||
-rw-r--r-- | src/router/workers.rs | 108 | ||||
-rw-r--r-- | src/types/endpoint.rs | 4 |
9 files changed, 306 insertions, 138 deletions
diff --git a/src/router/constants.rs b/src/router/constants.rs index b3015ed..0ca824a 100644 --- a/src/router/constants.rs +++ b/src/router/constants.rs @@ -1,2 +1,7 @@ +// WireGuard semantics constants + pub const MAX_STAGED_PACKETS: usize = 128; + +// performance constants + pub const WORKER_QUEUE_SIZE: usize = MAX_STAGED_PACKETS; diff --git a/src/router/device.rs b/src/router/device.rs index 2196dd1..69304d8 100644 --- a/src/router/device.rs +++ b/src/router/device.rs @@ -10,8 +10,9 @@ use std::time::Instant; use log::debug; -use spin; +use spin::{Mutex, RwLock}; use treebitmap::IpLookupTable; +use zerocopy::LayoutVerified; use super::super::types::{Bind, KeyPair, Tun}; @@ -20,23 +21,15 @@ use super::peer; use super::peer::{Peer, PeerInner}; use super::SIZE_MESSAGE_PREFIX; -use super::constants::WORKER_QUEUE_SIZE; -use super::messages::TYPE_TRANSPORT; -use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; -use super::workers::{worker_parallel, JobParallel}; - -// minimum sizes for IP headers -const SIZE_IP4_HEADER: usize = 16; -const SIZE_IP6_HEADER: usize = 36; +use super::constants::*; +use super::ip::*; -const VERSION_IP4: u8 = 4; -const VERSION_IP6: u8 = 6; - -const OFFSET_IP4_DST: usize = 16; -const OFFSET_IP6_DST: usize = 24; +use super::messages::{TransportHeader, TYPE_TRANSPORT}; +use super::types::{Callback, Callbacks, KeyCallback, Opaque, PhantomCallbacks, RouterError}; +use super::workers::{worker_parallel, JobParallel, Operation}; pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { - // IO & timer generics + // IO & timer callbacks pub tun: T, pub bind: B, pub call_recv: C::CallbackRecv, @@ -44,9 +37,9 @@ pub struct DeviceInner<C: Callbacks, T: Tun, B: Bind> { pub call_need_key: C::CallbackKey, // routing - pub recv: spin::RwLock<HashMap<u32, DecryptionState<C, T, B>>>, // receiver id -> decryption state - pub ipv4: spin::RwLock<IpLookupTable<Ipv4Addr, Weak<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing - pub ipv6: spin::RwLock<IpLookupTable<Ipv6Addr, Weak<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing + pub recv: RwLock<HashMap<u32, Arc<DecryptionState<C, T, B>>>>, // receiver id -> decryption state + pub ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<PeerInner<C, T, B>>>>, // ipv4 cryptkey routing + pub ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<PeerInner<C, T, B>>>>, // ipv6 cryptkey routing } pub struct EncryptionState { @@ -57,19 +50,18 @@ pub struct EncryptionState { } pub struct DecryptionState<C: Callbacks, T: Tun, B: Bind> { - pub key: [u8; 32], - pub keypair: Weak<KeyPair>, // only the key-wheel has a strong reference + pub keypair: Arc<KeyPair>, pub confirmed: AtomicBool, - pub protector: spin::Mutex<AntiReplay>, - pub peer: Weak<PeerInner<C, T, B>>, + pub protector: Mutex<AntiReplay>, + pub peer: Arc<PeerInner<C, T, B>>, pub death: Instant, // time when the key can no longer be used for decryption } pub struct Device<C: Callbacks, T: Tun, B: Bind> { - pub state: Arc<DeviceInner<C, T, B>>, // reference to device state - pub handles: Vec<thread::JoinHandle<()>>, // join handles for workers - pub queue_next: AtomicUsize, // next round-robin index - pub queues: Vec<spin::Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread) + state: Arc<DeviceInner<C, T, B>>, // reference to device state + handles: Vec<thread::JoinHandle<()>>, // join handles for workers + queue_next: AtomicUsize, // next round-robin index + queues: Vec<Mutex<SyncSender<JobParallel>>>, // work queues (1 per thread) } impl<C: Callbacks, T: Tun, B: Bind> Drop for Device<C, T, B> { @@ -109,9 +101,9 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi call_recv, call_send, call_need_key, - recv: spin::RwLock::new(HashMap::new()), - ipv4: spin::RwLock::new(IpLookupTable::new()), - ipv6: spin::RwLock::new(IpLookupTable::new()), + recv: RwLock::new(HashMap::new()), + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), }); // start worker threads @@ -119,7 +111,7 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi let mut threads = Vec::with_capacity(num_workers); for _ in 0..num_workers { let (tx, rx) = sync_channel(WORKER_QUEUE_SIZE); - queues.push(spin::Mutex::new(tx)); + queues.push(Mutex::new(tx)); threads.push(thread::spawn(move || worker_parallel(rx))); } @@ -133,6 +125,40 @@ impl<O: Opaque, R: Callback<O>, S: Callback<O>, K: KeyCallback<O>, T: Tun, B: Bi } } +#[inline(always)] +fn get_route<C: Callbacks, T: Tun, B: Bind>( + device: &Arc<DeviceInner<C, T, B>>, + packet: &[u8], +) -> Option<Arc<PeerInner<C, T, B>>> { + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _) = LayoutVerified::new_from_prefix(packet)?; + let header: LayoutVerified<&[u8], IPv4Header> = header; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, packet) = LayoutVerified::new_from_prefix(packet)?; + let header: LayoutVerified<&[u8], IPv6Header> = header; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } +} + impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// Adds a new peer to the device /// @@ -159,48 +185,12 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { let packet = &msg[SIZE_MESSAGE_PREFIX..]; // lookup peer based on IP packet destination address - let peer = match packet[0] >> 4 { - VERSION_IP4 => { - if msg.len() >= SIZE_IP4_HEADER { - // extract IPv4 destination address - let mut dst = [0u8; 4]; - dst.copy_from_slice(&packet[OFFSET_IP4_DST..OFFSET_IP4_DST + 4]); - let dst = Ipv4Addr::from(dst); - - // lookup peer (project unto and clone "value" field) - self.state - .ipv4 - .read() - .longest_match(dst) - .and_then(|(_, _, p)| p.upgrade()) - .ok_or(RouterError::NoCryptKeyRoute) - } else { - Err(RouterError::MalformedIPHeader) - } - } - VERSION_IP6 => { - if msg.len() >= SIZE_IP6_HEADER { - // extract IPv6 destination address - let mut dst = [0u8; 16]; - dst.copy_from_slice(&packet[OFFSET_IP6_DST..OFFSET_IP6_DST + 16]); - let dst = Ipv6Addr::from(dst); - - // lookup peer (project unto and clone "value" field) - self.state - .ipv6 - .read() - .longest_match(dst) - .and_then(|(_, _, p)| p.upgrade()) - .ok_or(RouterError::NoCryptKeyRoute) - } else { - Err(RouterError::MalformedIPHeader) - } - } - _ => Err(RouterError::MalformedIPHeader), - }?; + let peer = get_route(&self.state, packet).ok_or(RouterError::NoCryptKeyRoute)?; // schedule for encryption and transmission to peer if let Some(job) = peer.send_job(msg) { + debug_assert_eq!(job.1.op, Operation::Encryption); + // add job to worker queue let idx = self.queue_next.fetch_add(1, Ordering::SeqCst); self.queues[idx % self.queues.len()] @@ -216,17 +206,44 @@ impl<C: Callbacks, T: Tun, B: Bind> Device<C, T, B> { /// /// # Arguments /// + /// - src: Source address of the packet /// - msg: Encrypted transport message - pub fn recv(&self, msg: Vec<u8>) -> Result<(), RouterError> { - // ensure that the type field access is within bounds - if msg.len() < SIZE_MESSAGE_PREFIX || msg[0] != TYPE_TRANSPORT { - return Err(RouterError::MalformedTransportMessage); - } - + /// + /// # Returns + /// + /// + pub fn recv(&self, src: B::Endpoint, msg: Vec<u8>) -> Result<(), RouterError> { // parse / cast + let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { + Some(v) => v, + None => { + return Err(RouterError::MalformedTransportMessage); + } + }; + let header: LayoutVerified<&[u8], TransportHeader> = header; + debug_assert!( + header.f_type.get() == TYPE_TRANSPORT as u32, + "this should be checked by the message type multiplexer" + ); // lookup peer based on receiver id + let dec = self.state.recv.read(); + let dec = dec + .get(&header.f_receiver.get()) + .ok_or(RouterError::UnkownReceiverId)?; + + // schedule for decryption and TUN write + if let Some(job) = dec.peer.recv_job(src, dec.clone(), msg) { + debug_assert_eq!(job.1.op, Operation::Decryption); - unimplemented!(); + // add job to worker queue + let idx = self.queue_next.fetch_add(1, Ordering::SeqCst); + self.queues[idx % self.queues.len()] + .lock() + .send(job) + .unwrap(); + } + + Ok(()) } } diff --git a/src/router/ip.rs b/src/router/ip.rs new file mode 100644 index 0000000..6eb303c --- /dev/null +++ b/src/router/ip.rs @@ -0,0 +1,37 @@ +use byteorder::BigEndian; +use zerocopy::byteorder::U16; +use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; + +pub const SIZE_IP4_HEADER: usize = 16; +pub const SIZE_IP6_HEADER: usize = 36; + +pub const VERSION_IP4: u8 = 4; +pub const VERSION_IP6: u8 = 6; + +pub const OFFSET_IP4_SRC: usize = 12; +pub const OFFSET_IP6_SRC: usize = 8; + +pub const OFFSET_IP4_DST: usize = 16; +pub const OFFSET_IP6_DST: usize = 24; + +pub const TYPE_TRANSPORT: u8 = 4; + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv4Header { + _f_space1: [u8; 2], + pub f_total_len: U16<BigEndian>, + _f_space2: [u8; 8], + pub f_source: [u8; 4], + pub f_destination: [u8; 4], +} + +#[repr(packed)] +#[derive(Copy, Clone, FromBytes, AsBytes)] +pub struct IPv6Header { + _f_pre: [u8; 4], + pub f_len: U16<BigEndian>, + _f_space2: [u8; 2], + pub f_source: [u8; 16], + pub f_destination: [u8; 16], +} diff --git a/src/router/mod.rs b/src/router/mod.rs index ec560b4..883c875 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -1,6 +1,7 @@ mod anti_replay; mod constants; mod device; +mod ip; mod messages; mod peer; mod types; diff --git a/src/router/peer.rs b/src/router/peer.rs index 634f980..0cd588d 100644 --- a/src/router/peer.rs +++ b/src/router/peer.rs @@ -30,7 +30,7 @@ use super::workers::Operation; use super::workers::{worker_inbound, worker_outbound}; use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; -use super::constants::MAX_STAGED_PACKETS; +use super::constants::*; use super::types::Callbacks; pub struct KeyWheel { @@ -50,7 +50,7 @@ pub struct PeerInner<C: Callbacks, T: Tun, B: Bind> { pub tx_bytes: AtomicU64, // transmitted bytes pub keys: Mutex<KeyWheel>, // key-wheel pub ekey: Mutex<Option<EncryptionState>>, // encryption state - pub endpoint: Mutex<Option<Arc<SocketAddr>>>, + pub endpoint: Mutex<Option<B::Endpoint>>, } pub struct Peer<C: Callbacks, T: Tun, B: Bind> { @@ -61,7 +61,7 @@ pub struct Peer<C: Callbacks, T: Tun, B: Bind> { fn treebit_list<A, E, C: Callbacks, T: Tun, B: Bind>( peer: &Arc<PeerInner<C, T, B>>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, callback: Box<dyn Fn(A, u32) -> E>, ) -> Vec<E> where @@ -70,10 +70,8 @@ where let mut res = Vec::new(); for subnet in table.read().iter() { let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer) { - res.push(callback(ip, masklen)) - } + if Arc::ptr_eq(&p, &peer) { + res.push(callback(ip, masklen)) } } res @@ -81,7 +79,7 @@ where fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( peer: &Peer<C, T, B>, - table: &spin::RwLock<IpLookupTable<A, Weak<PeerInner<C, T, B>>>>, + table: &spin::RwLock<IpLookupTable<A, Arc<PeerInner<C, T, B>>>>, ) { let mut m = table.write(); @@ -89,10 +87,8 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( let mut subnets = vec![]; for subnet in m.iter() { let (ip, masklen, p) = subnet; - if let Some(p) = p.upgrade() { - if Arc::ptr_eq(&p, &peer.state) { - subnets.push((ip, masklen)) - } + if Arc::ptr_eq(&p, &peer.state) { + subnets.push((ip, masklen)) } } @@ -103,6 +99,29 @@ fn treebit_remove<A: Address, C: Callbacks, T: Tun, B: Bind>( } } +impl EncryptionState { + fn new(keypair: &Arc<KeyPair>) -> EncryptionState { + EncryptionState { + id: keypair.send.id, + key: keypair.send.key, + nonce: 0, + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + +impl<C: Callbacks, T: Tun, B: Bind> DecryptionState<C, T, B> { + fn new(peer: &Arc<PeerInner<C, T, B>>, keypair: &Arc<KeyPair>) -> DecryptionState<C, T, B> { + DecryptionState { + confirmed: AtomicBool::new(keypair.initiator), + keypair: keypair.clone(), + protector: spin::Mutex::new(AntiReplay::new()), + peer: peer.clone(), + death: keypair.birth + REJECT_AFTER_TIME, + } + } +} + impl<C: Callbacks, T: Tun, B: Bind> Drop for Peer<C, T, B> { fn drop(&mut self) { let peer = &self.state; @@ -202,12 +221,52 @@ pub fn new_peer<C: Callbacks, T: Tun, B: Bind>( } impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { - pub fn confirm_key(&self, kp: Weak<KeyPair>) { - // upgrade key-pair to strong reference + pub fn confirm_key(&self, keypair: &Arc<KeyPair>) { + // take lock and check keypair = keys.next + let mut keys = self.keys.lock(); + let next = match keys.next.as_ref() { + Some(next) => next, + None => { + return; + } + }; + if !Arc::ptr_eq(&next, keypair) { + return; + } - // check it is the new unconfirmed key + // allocate new encryption state + let ekey = Some(EncryptionState::new(&next)); // rotate key-wheel + let mut swap = None; + mem::swap(&mut keys.next, &mut swap); + mem::swap(&mut keys.current, &mut swap); + mem::swap(&mut keys.previous, &mut swap); + + // set new encryption key + *self.ekey.lock() = ekey; + } + + pub fn recv_job( + &self, + src: B::Endpoint, + dec: Arc<DecryptionState<C, T, B>>, + mut msg: Vec<u8>, + ) -> Option<JobParallel> { + let (tx, rx) = oneshot(); + let key = dec.keypair.send.key; + match self.inbound.lock().try_send((dec, src, rx)) { + Ok(_) => Some(( + tx, + JobBuffer { + msg, + key: key, + okay: false, + op: Operation::Decryption, + }, + )), + Err(_) => None, + } } pub fn send_job(&self, mut msg: Vec<u8>) -> Option<JobParallel> { @@ -260,7 +319,7 @@ impl<C: Callbacks, T: Tun, B: Bind> PeerInner<C, T, B> { impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { pub fn set_endpoint(&self, endpoint: SocketAddr) { - *self.state.endpoint.lock() = Some(Arc::new(endpoint)) + *self.state.endpoint.lock() = Some(endpoint.into()); } /// Add a new keypair @@ -285,12 +344,7 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { // update key-wheel if new.initiator { // start using key for encryption - *self.state.ekey.lock() = Some(EncryptionState { - id: new.send.id, - key: new.send.key, - nonce: 0, - death: new.birth + REJECT_AFTER_TIME, - }); + *self.state.ekey.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); @@ -310,19 +364,11 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { recv.remove(&id); } - // map new id to keypair + // map new id to decryption state debug_assert!(!recv.contains_key(&new.recv.id)); - recv.insert( new.recv.id, - DecryptionState { - confirmed: AtomicBool::new(new.initiator), - keypair: Arc::downgrade(&new), - key: new.recv.key, - protector: spin::Mutex::new(AntiReplay::new()), - peer: Arc::downgrade(&self.state), - death: new.birth + REJECT_AFTER_TIME, - }, + Arc::new(DecryptionState::new(&self.state, &new)), ); } @@ -345,14 +391,14 @@ impl<C: Callbacks, T: Tun, B: Bind> Peer<C, T, B> { .device .ipv4 .write() - .insert(v4, masklen, Arc::downgrade(&self.state)) + .insert(v4, masklen, self.state.clone()) } IpAddr::V6(v6) => { self.state .device .ipv6 .write() - .insert(v6, masklen, Arc::downgrade(&self.state)) + .insert(v6, masklen, self.state.clone()) } }; } diff --git a/src/router/tests.rs b/src/router/tests.rs index 5463532..7fe2b7a 100644 --- a/src/router/tests.rs +++ b/src/router/tests.rs @@ -156,6 +156,8 @@ mod tests { #[bench] fn bench_outbound(b: &mut Bencher) { + init(); + // type for tracking number of packets type Opaque = Arc<AtomicU64>; diff --git a/src/router/types.rs b/src/router/types.rs index 336f56b..7706997 100644 --- a/src/router/types.rs +++ b/src/router/types.rs @@ -57,6 +57,7 @@ pub enum RouterError { NoCryptKeyRoute, MalformedIPHeader, MalformedTransportMessage, + UnkownReceiverId, } impl fmt::Display for RouterError { @@ -65,6 +66,9 @@ impl fmt::Display for RouterError { RouterError::NoCryptKeyRoute => write!(f, "No cryptkey route configured for subnet"), RouterError::MalformedIPHeader => write!(f, "IP header is malformed"), RouterError::MalformedTransportMessage => write!(f, "IP header is malformed"), + RouterError::UnkownReceiverId => { + write!(f, "No decryption state associated with receiver id") + } } } } diff --git a/src/router/workers.rs b/src/router/workers.rs index b18b038..45e1058 100644 --- a/src/router/workers.rs +++ b/src/router/workers.rs @@ -1,6 +1,6 @@ use std::mem; use std::sync::mpsc::Receiver; -use std::sync::{Arc, Weak}; +use std::sync::Arc; use futures::sync::oneshot; use futures::*; @@ -8,15 +8,17 @@ use futures::*; use log::debug; use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::net::{Ipv4Addr, Ipv6Addr}; +use std::sync::atomic::Ordering; use zerocopy::{AsBytes, LayoutVerified}; -use super::device::DecryptionState; -use super::device::DeviceInner; +use super::device::{DecryptionState, DeviceInner}; use super::messages::TransportHeader; use super::peer::PeerInner; use super::types::Callbacks; +use super::ip::*; + use super::super::types::{Bind, Tun}; #[derive(PartialEq, Debug)] @@ -33,9 +35,60 @@ pub struct JobBuffer { } pub type JobParallel = (oneshot::Sender<JobBuffer>, JobBuffer); -pub type JobInbound<C, T, B> = (Weak<DecryptionState<C, T, B>>, oneshot::Receiver<JobBuffer>); +pub type JobInbound<C, T, B: Bind> = ( + Arc<DecryptionState<C, T, B>>, + B::Endpoint, + oneshot::Receiver<JobBuffer>, +); pub type JobOutbound = oneshot::Receiver<JobBuffer>; +#[inline(always)] +fn check_route<C: Callbacks, T: Tun, B: Bind>( + device: &Arc<DeviceInner<C, T, B>>, + peer: &Arc<PeerInner<C, T, B>>, + packet: &[u8], +) -> Option<usize> { + match packet[0] >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _) = LayoutVerified::new_from_prefix(packet)?; + let header: LayoutVerified<&[u8], IPv4Header> = header; + + // check IPv4 source address + device + .ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, packet) = LayoutVerified::new_from_prefix(packet)?; + let header: LayoutVerified<&[u8], IPv6Header> = header; + + // check IPv6 source address + device + .ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, &peer) { + Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) + } else { + None + } + }) + } + _ => None, + } +} + pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( device: Arc<DeviceInner<C, T, B>>, // related device peer: Arc<PeerInner<C, T, B>>, // related peer @@ -43,7 +96,7 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( ) { loop { // fetch job - let (state, rx) = match receiver.recv() { + let (state, endpoint, rx) = match receiver.recv() { Ok(v) => v, _ => { return; @@ -62,13 +115,10 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( } }; let header: LayoutVerified<&[u8], TransportHeader> = header; - - // obtain strong reference to decryption state - let state = if let Some(state) = state.upgrade() { - state - } else { - return; - }; + debug_assert!( + packet.len() >= 16, + "this should be checked earlier in the pipeline" + ); // check for replay if !state.protector.lock().update(header.f_counter.get()) { @@ -77,23 +127,29 @@ pub fn worker_inbound<C: Callbacks, T: Tun, B: Bind>( // check for confirms key if !state.confirmed.swap(true, Ordering::SeqCst) { - peer.confirm_key(state.keypair.clone()); + peer.confirm_key(&state.keypair); } - // update endpoint, TODO - - // write packet to TUN device, TODO + // update endpoint + *peer.endpoint.lock() = Some(endpoint); + + // calculate length of IP packet + padding + let length = packet.len() - CHACHA20_POLY1305.nonce_len(); + + // check if should be written to TUN + let mut sent = false; + if length > 0 { + if let Some(inner_len) = check_route(&device, &peer, &packet[..length]) { + debug_assert!(inner_len <= length, "should be validated"); + if inner_len <= length { + sent = true; + let _ = device.tun.write(&packet[..inner_len]); + } + } + } // trigger callback - debug_assert!( - packet.len() >= CHACHA20_POLY1305.nonce_len(), - "this should be checked earlier in the pipeline" - ); - (device.call_recv)( - &peer.opaque, - packet.len() > CHACHA20_POLY1305.nonce_len(), - true, - ); + (device.call_recv)(&peer.opaque, length == 0, sent); } }) .wait(); diff --git a/src/types/endpoint.rs b/src/types/endpoint.rs index 6bc99b9..8033080 100644 --- a/src/types/endpoint.rs +++ b/src/types/endpoint.rs @@ -1,5 +1,5 @@ use std::net::SocketAddr; -pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> {} +pub trait Endpoint: Into<SocketAddr> + From<SocketAddr> + Send {} -impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> {} +impl<T> Endpoint for T where T: Into<SocketAddr> + From<SocketAddr> + Send {} |