use std::mem; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::{Arc, Weak}; use std::thread; use spin::Mutex; use arraydeque::{ArrayDeque, Saturating, Wrapping}; use zerocopy::{AsBytes, LayoutVerified}; use treebitmap::address::Address; use treebitmap::IpLookupTable; use super::super::constants::*; use super::super::types::{Bind, KeyPair, Tun}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::DeviceInner; use super::device::EncryptionState; use super::messages::TransportHeader; use super::workers::{worker_inbound, worker_outbound}; use super::workers::{JobBuffer, JobInbound, JobInner, JobOutbound}; use super::workers::{Operation, Status}; use super::types::Callbacks; const MAX_STAGED_PACKETS: usize = 128; pub struct KeyWheel { next: Option>, // next key state (unconfirmed) current: Option>, // current key state (used for encryption) previous: Option>, // old key state (used for decryption) retired: Option, // retired id (previous id, after confirming key-pair) } pub struct PeerInner { pub stopped: AtomicBool, pub opaque: C::Opaque, pub outbound: Mutex>, pub inbound: Mutex; MAX_STAGED_PACKETS], Saturating>>, pub device: Arc>, pub thread_outbound: Mutex>>, pub thread_inbound: Mutex>>, pub staged_packets: Mutex; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes pub keys: Mutex, // key-wheel pub ekey: Mutex>, // encryption state pub endpoint: Mutex>>, } pub struct Peer(Arc>); fn treebit_list( peer: &Arc>, table: &spin::RwLock>>>, callback: Box E>, ) -> Vec where A: Address, { let mut res = Vec::new(); for subnet in table.read().iter() { let (ip, masklen, p) = subnet; if let Some(p) = p.upgrade() { if Arc::ptr_eq(&p, &peer) { res.push(callback(ip, masklen)) } } } res } fn treebit_remove( peer: &Peer, table: &spin::RwLock>>>, ) { let mut m = table.write(); // collect keys for value let mut subnets = vec![]; for subnet in m.iter() { let (ip, masklen, p) = subnet; if let Some(p) = p.upgrade() { if Arc::ptr_eq(&p, &peer.0) { subnets.push((ip, masklen)) } } } // remove all key mappings for subnet in subnets { let r = m.remove(subnet.0, subnet.1); debug_assert!(r.is_some()); } } impl Drop for Peer { fn drop(&mut self) { // mark peer as stopped let peer = &self.0; peer.stopped.store(true, Ordering::SeqCst); // remove from cryptkey router treebit_remove(self, &peer.device.ipv4); treebit_remove(self, &peer.device.ipv6); // unpark threads peer.thread_inbound .lock() .as_ref() .unwrap() .thread() .unpark(); peer.thread_outbound .lock() .as_ref() .unwrap() .thread() .unpark(); // release ids from the receiver map let mut keys = peer.keys.lock(); let mut release = Vec::with_capacity(3); keys.next.as_ref().map(|k| release.push(k.recv.id)); keys.current.as_ref().map(|k| release.push(k.recv.id)); keys.previous.as_ref().map(|k| release.push(k.recv.id)); if release.len() > 0 { let mut recv = peer.device.recv.write(); for id in &release { recv.remove(id); } } // null key-material (TODO: extend) keys.next = None; keys.current = None; keys.previous = None; *peer.ekey.lock() = None; *peer.endpoint.lock() = None; } } pub fn new_peer( device: Arc>, opaque: C::Opaque, ) -> Peer { // allocate peer object let peer = { let device = device.clone(); Arc::new(PeerInner { opaque, inbound: Mutex::new(ArrayDeque::new()), outbound: Mutex::new(ArrayDeque::new()), stopped: AtomicBool::new(false), device: device, ekey: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, current: None, previous: None, retired: None, }), rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), staged_packets: spin::Mutex::new(ArrayDeque::new()), thread_inbound: spin::Mutex::new(None), thread_outbound: spin::Mutex::new(None), }) }; // spawn outbound thread *peer.thread_inbound.lock() = { let peer = peer.clone(); let device = device.clone(); Some(thread::spawn(move || worker_outbound(device, peer))) }; // spawn inbound thread *peer.thread_outbound.lock() = { let peer = peer.clone(); let device = device.clone(); Some(thread::spawn(move || worker_inbound(device, peer))) }; Peer(peer) } impl PeerInner { pub fn confirm_key(&self, kp: Weak) { // upgrade key-pair to strong reference // check it is the new unconfirmed key // rotate key-wheel } pub fn send_job(&self, mut msg: Vec) -> Option { debug_assert!(msg.len() >= mem::size_of::()); // parse / cast let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap(); let mut header: LayoutVerified<&mut [u8], TransportHeader> = header; // check if has key let key = match self.ekey.lock().as_mut() { None => { // add to staged packets (create no job) (self.device.call_need_key)(&self.opaque); self.staged_packets.lock().push_back(msg); return None; } Some(mut state) => { // allocate nonce state.nonce += 1; if state.nonce >= REJECT_AFTER_MESSAGES { state.nonce -= 1; return None; } // set transport message fields header.f_counter.set(state.nonce); header.f_receiver.set(state.id); state.key } }; // create job let job = Arc::new(spin::Mutex::new(JobInner { msg, key, status: Status::Waiting, op: Operation::Encryption, })); // add job to in-order queue and return to device for inclusion in worker pool match self.outbound.lock().push_back(job.clone()) { Ok(_) => Some(job), Err(_) => None, } } } impl Peer { fn new(inner: PeerInner) -> Peer { Peer(Arc::new(inner)) } pub fn set_endpoint(&self, endpoint: SocketAddr) { *self.0.endpoint.lock() = Some(Arc::new(endpoint)) } /// Add a new keypair /// /// # Arguments /// /// - new: The new confirmed/unconfirmed key pair /// /// # Returns /// /// A vector of ids which has been released. /// These should be released in the handshake module. pub fn add_keypair(&self, new: KeyPair) -> Vec { let mut keys = self.0.keys.lock(); let mut release = Vec::with_capacity(2); let new = Arc::new(new); // collect ids to be released keys.retired.map(|v| release.push(v)); keys.previous.as_ref().map(|k| release.push(k.recv.id)); // update key-wheel if new.initiator { // start using key for encryption *self.0.ekey.lock() = Some(EncryptionState { id: new.send.id, key: new.send.key, nonce: 0, death: new.birth + REJECT_AFTER_TIME, }); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone());; keys.current = Some(new.clone()); } else { // store the key and await confirmation keys.previous = keys.next.as_ref().map(|v| v.clone());; keys.next = Some(new.clone()); }; // update incoming packet id map { let mut recv = self.0.device.recv.write(); // purge recv map of released ids for id in &release { recv.remove(&id); } // map new id to keypair debug_assert!(!recv.contains_key(&new.recv.id)); recv.insert( new.recv.id, DecryptionState { confirmed: AtomicBool::new(new.initiator), keypair: Arc::downgrade(&new), key: new.recv.key, protector: spin::Mutex::new(AntiReplay::new()), peer: Arc::downgrade(&self.0), death: new.birth + REJECT_AFTER_TIME, }, ); } // return the released id (for handshake state machine) release } pub fn rx_bytes(&self) -> u64 { self.0.rx_bytes.load(Ordering::Relaxed) } pub fn tx_bytes(&self) -> u64 { self.0.tx_bytes.load(Ordering::Relaxed) } pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { match ip { IpAddr::V4(v4) => { self.0 .device .ipv4 .write() .insert(v4, masklen, Arc::downgrade(&self.0)) } IpAddr::V6(v6) => { self.0 .device .ipv6 .write() .insert(v6, masklen, Arc::downgrade(&self.0)) } }; } pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { let mut res = Vec::new(); res.append(&mut treebit_list( &self.0, &self.0.device.ipv4, Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), )); res.append(&mut treebit_list( &self.0, &self.0.device.ipv6, Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), )); res } pub fn remove_subnets(&self) { treebit_remove(self, &self.0.device.ipv4); treebit_remove(self, &self.0.device.ipv6); } fn send(&self, msg: Vec) {} }