use std::mem; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::mpsc::{sync_channel, SyncSender}; use std::sync::{Arc, Weak}; use std::thread; use log::debug; use spin::Mutex; use arraydeque::{ArrayDeque, Saturating, Wrapping}; use zerocopy::{AsBytes, LayoutVerified}; use treebitmap::address::Address; use treebitmap::IpLookupTable; use super::super::constants::*; use super::super::types::{Bind, KeyPair, Tun}; use super::anti_replay::AntiReplay; use super::device::DecryptionState; use super::device::DeviceInner; use super::device::EncryptionState; use super::messages::TransportHeader; use futures::*; use super::workers::Operation; use super::workers::{worker_inbound, worker_outbound}; use super::workers::{JobBuffer, JobInbound, JobOutbound, JobParallel}; use super::constants::*; use super::types::Callbacks; pub struct KeyWheel { next: Option>, // next key state (unconfirmed) current: Option>, // current key state (used for encryption) previous: Option>, // old key state (used for decryption) retired: Option, // retired id (previous id, after confirming key-pair) } pub struct PeerInner { pub device: Arc>, pub opaque: C::Opaque, pub outbound: Mutex>, pub inbound: Mutex>>, pub staged_packets: Mutex; MAX_STAGED_PACKETS], Wrapping>>, // packets awaiting handshake pub rx_bytes: AtomicU64, // received bytes pub tx_bytes: AtomicU64, // transmitted bytes pub keys: Mutex, // key-wheel pub ekey: Mutex>, // encryption state pub endpoint: Mutex>, } pub struct Peer { state: Arc>, thread_outbound: Option>, thread_inbound: Option>, } fn treebit_list( peer: &Arc>, table: &spin::RwLock>>>, callback: Box E>, ) -> Vec where A: Address, { let mut res = Vec::new(); for subnet in table.read().iter() { let (ip, masklen, p) = subnet; if Arc::ptr_eq(&p, &peer) { res.push(callback(ip, masklen)) } } res } fn treebit_remove( peer: &Peer, table: &spin::RwLock>>>, ) { let mut m = table.write(); // collect keys for value let mut subnets = vec![]; for subnet in m.iter() { let (ip, masklen, p) = subnet; if Arc::ptr_eq(&p, &peer.state) { subnets.push((ip, masklen)) } } // remove all key mappings for (ip, masklen) in subnets { let r = m.remove(ip, masklen); debug_assert!(r.is_some()); } } impl EncryptionState { fn new(keypair: &Arc) -> EncryptionState { EncryptionState { id: keypair.send.id, key: keypair.send.key, nonce: 0, death: keypair.birth + REJECT_AFTER_TIME, } } } impl DecryptionState { fn new(peer: &Arc>, keypair: &Arc) -> DecryptionState { DecryptionState { confirmed: AtomicBool::new(keypair.initiator), keypair: keypair.clone(), protector: spin::Mutex::new(AntiReplay::new()), peer: peer.clone(), death: keypair.birth + REJECT_AFTER_TIME, } } } impl Drop for Peer { fn drop(&mut self) { let peer = &self.state; // remove from cryptkey router treebit_remove(self, &peer.device.ipv4); treebit_remove(self, &peer.device.ipv6); // drop channels mem::replace(&mut *peer.inbound.lock(), sync_channel(0).0); mem::replace(&mut *peer.outbound.lock(), sync_channel(0).0); // join with workers mem::replace(&mut self.thread_inbound, None).map(|v| v.join()); mem::replace(&mut self.thread_outbound, None).map(|v| v.join()); // release ids from the receiver map let mut keys = peer.keys.lock(); let mut release = Vec::with_capacity(3); keys.next.as_ref().map(|k| release.push(k.recv.id)); keys.current.as_ref().map(|k| release.push(k.recv.id)); keys.previous.as_ref().map(|k| release.push(k.recv.id)); if release.len() > 0 { let mut recv = peer.device.recv.write(); for id in &release { recv.remove(id); } } // null key-material keys.next = None; keys.current = None; keys.previous = None; *peer.ekey.lock() = None; *peer.endpoint.lock() = None; debug!("peer dropped & removed from device"); } } pub fn new_peer( device: Arc>, opaque: C::Opaque, ) -> Peer { let (out_tx, out_rx) = sync_channel(128); let (in_tx, in_rx) = sync_channel(128); // allocate peer object let peer = { let device = device.clone(); Arc::new(PeerInner { opaque, device, inbound: Mutex::new(in_tx), outbound: Mutex::new(out_tx), ekey: spin::Mutex::new(None), endpoint: spin::Mutex::new(None), keys: spin::Mutex::new(KeyWheel { next: None, current: None, previous: None, retired: None, }), rx_bytes: AtomicU64::new(0), tx_bytes: AtomicU64::new(0), staged_packets: spin::Mutex::new(ArrayDeque::new()), }) }; // spawn outbound thread let thread_inbound = { let peer = peer.clone(); let device = device.clone(); thread::spawn(move || worker_outbound(device, peer, out_rx)) }; // spawn inbound thread let thread_outbound = { let peer = peer.clone(); let device = device.clone(); thread::spawn(move || worker_inbound(device, peer, in_rx)) }; Peer { state: peer, thread_inbound: Some(thread_inbound), thread_outbound: Some(thread_outbound), } } impl PeerInner { pub fn confirm_key(&self, keypair: &Arc) { // take lock and check keypair = keys.next let mut keys = self.keys.lock(); let next = match keys.next.as_ref() { Some(next) => next, None => { return; } }; if !Arc::ptr_eq(&next, keypair) { return; } // allocate new encryption state let ekey = Some(EncryptionState::new(&next)); // rotate key-wheel let mut swap = None; mem::swap(&mut keys.next, &mut swap); mem::swap(&mut keys.current, &mut swap); mem::swap(&mut keys.previous, &mut swap); // set new encryption key *self.ekey.lock() = ekey; } pub fn recv_job( &self, src: B::Endpoint, dec: Arc>, mut msg: Vec, ) -> Option { let (tx, rx) = oneshot(); let key = dec.keypair.send.key; match self.inbound.lock().try_send((dec, src, rx)) { Ok(_) => Some(( tx, JobBuffer { msg, key: key, okay: false, op: Operation::Decryption, }, )), Err(_) => None, } } pub fn send_job(&self, mut msg: Vec) -> Option { debug_assert!(msg.len() >= mem::size_of::()); // parse / cast let (header, _) = LayoutVerified::new_from_prefix(&mut msg[..]).unwrap(); let mut header: LayoutVerified<&mut [u8], TransportHeader> = header; // check if has key let key = match self.ekey.lock().as_mut() { None => { // add to staged packets (create no job) debug!("execute callback: call_need_key"); (self.device.call_need_key)(&self.opaque); self.staged_packets.lock().push_back(msg); return None; } Some(mut state) => { // avoid integer overflow in nonce if state.nonce >= REJECT_AFTER_MESSAGES - 1 { return None; } debug!("encryption state available, nonce = {}", state.nonce); // set transport message fields header.f_counter.set(state.nonce); header.f_receiver.set(state.id); state.nonce += 1; state.key } }; // add job to in-order queue and return sendeer to device for inclusion in worker pool let (tx, rx) = oneshot(); match self.outbound.lock().try_send(rx) { Ok(_) => Some(( tx, JobBuffer { msg, key, okay: false, op: Operation::Encryption, }, )), Err(_) => None, } } } impl Peer { pub fn set_endpoint(&self, endpoint: SocketAddr) { *self.state.endpoint.lock() = Some(endpoint.into()); } /// Add a new keypair /// /// # Arguments /// /// - new: The new confirmed/unconfirmed key pair /// /// # Returns /// /// A vector of ids which has been released. /// These should be released in the handshake module. pub fn add_keypair(&self, new: KeyPair) -> Vec { let mut keys = self.state.keys.lock(); let mut release = Vec::with_capacity(2); let new = Arc::new(new); // collect ids to be released keys.retired.map(|v| release.push(v)); keys.previous.as_ref().map(|k| release.push(k.recv.id)); // update key-wheel if new.initiator { // start using key for encryption *self.state.ekey.lock() = Some(EncryptionState::new(&new)); // move current into previous keys.previous = keys.current.as_ref().map(|v| v.clone()); keys.current = Some(new.clone()); } else { // store the key and await confirmation keys.previous = keys.next.as_ref().map(|v| v.clone()); keys.next = Some(new.clone()); }; // update incoming packet id map { let mut recv = self.state.device.recv.write(); // purge recv map of released ids for id in &release { recv.remove(&id); } // map new id to decryption state debug_assert!(!recv.contains_key(&new.recv.id)); recv.insert( new.recv.id, Arc::new(DecryptionState::new(&self.state, &new)), ); } // return the released id (for handshake state machine) release } pub fn rx_bytes(&self) -> u64 { self.state.rx_bytes.load(Ordering::Relaxed) } pub fn tx_bytes(&self) -> u64 { self.state.tx_bytes.load(Ordering::Relaxed) } pub fn add_subnet(&self, ip: IpAddr, masklen: u32) { match ip { IpAddr::V4(v4) => { self.state .device .ipv4 .write() .insert(v4, masklen, self.state.clone()) } IpAddr::V6(v6) => { self.state .device .ipv6 .write() .insert(v6, masklen, self.state.clone()) } }; } pub fn list_subnets(&self) -> Vec<(IpAddr, u32)> { let mut res = Vec::new(); res.append(&mut treebit_list( &self.state, &self.state.device.ipv4, Box::new(|ip, masklen| (IpAddr::V4(ip), masklen)), )); res.append(&mut treebit_list( &self.state, &self.state.device.ipv6, Box::new(|ip, masklen| (IpAddr::V6(ip), masklen)), )); res } pub fn remove_subnets(&self) { treebit_remove(self, &self.state.device.ipv4); treebit_remove(self, &self.state.device.ipv6); } fn send(&self, msg: Vec) {} }