diff options
author | Guanhao Yin <sopium@mysterious.site> | 2017-03-25 20:06:03 +0800 |
---|---|---|
committer | Guanhao Yin <sopium@mysterious.site> | 2017-03-25 21:51:29 +0800 |
commit | a7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d (patch) | |
tree | 44f71d7b278744f9cdfb2a21a991819f70d29af1 | |
parent | Add TimerHandle::dummy(), don't put dummy timers into the wheel (diff) | |
download | wireguard-rs-a7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d.tar.xz wireguard-rs-a7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d.zip |
Manage REKEY_AFTER_TIME and REJECT_AFTER_TIME with timers
-rw-r--r-- | src/protocol/controller.rs | 92 | ||||
-rw-r--r-- | src/protocol/timer.rs | 11 |
2 files changed, 65 insertions, 38 deletions
diff --git a/src/protocol/controller.rs b/src/protocol/controller.rs index 8f71519..2243736 100644 --- a/src/protocol/controller.rs +++ b/src/protocol/controller.rs @@ -114,9 +114,9 @@ struct PeerState { tx_bytes: AtomicU64, // XXX: use a Vec? or ArrayVec? - transport0: Option<Transport>, - transport1: Option<Transport>, - transport2: Option<Transport>, + transport0: Option<Arc<Transport>>, + transport1: Option<Arc<Transport>>, + transport2: Option<Arc<Transport>>, // Rekey because of send but not recv in... rekey_no_recv: TimerHandle, @@ -143,6 +143,8 @@ struct Transport { self_id: IdMapGuard, peer_id: Id, is_initiator: bool, + // Is set to true after REKEY_AFTER_TIME if `is_initiator`. + should_handshake: AtomicBool, // If we are responder, should not send until received one packet. is_initiator_or_has_received: AtomicBool, // Also should not send after REJECT_AFTER_TIME, @@ -155,6 +157,10 @@ struct Transport { recv_key: SecretKey, recv_ar: Mutex<AntiReplay>, + + // Use mutex to make the compiler happy. + rekey_after_time: Mutex<TimerHandle>, + reject_after_time: Mutex<TimerHandle>, } // TODO determine / detect load. @@ -744,7 +750,6 @@ pub fn wg_change_peer<F>(wg: Arc<WgState>, peer_pubkey: &X25519Pubkey, f: F) -> /// The peer should not already exist. pub fn wg_add_peer(wg: Arc<WgState>, peer: &PeerInfo, sock: Arc<UdpSocket>) { let register = |a| CONTROLLER.register_delay(Duration::from_secs(0), a); - let dummy_action = || Box::new(|| {}); // Lock pubkey_map. let mut pubkey_map = wg.pubkey_map.write().unwrap(); @@ -760,10 +765,10 @@ pub fn wg_add_peer(wg: Arc<WgState>, peer: &PeerInfo, sock: Arc<UdpSocket>) { transport0: None, transport1: None, transport2: None, - rekey_no_recv: register(dummy_action()), - keep_alive: register(dummy_action()), - persistent_keep_alive: register(dummy_action()), - clear: register(dummy_action()), + rekey_no_recv: TimerHandle::dummy(), + keep_alive: TimerHandle::dummy(), + persistent_keep_alive: TimerHandle::dummy(), + clear: TimerHandle::dummy(), }; let ps = Arc::new(RwLock::new(ps)); @@ -854,7 +859,7 @@ impl WgState { wg } - // These methods helps a lot in avoiding deadlocks. + // These methods help a lot in avoiding deadlocks. fn find_peer_by_id(&self, id: Id) -> Option<SharedPeerState> { self.id_map.read().unwrap().get(&id).cloned() @@ -962,7 +967,7 @@ impl PeerState { }); } - fn push_transport(&mut self, t: Transport) { + fn push_transport(&mut self, t: Arc<Transport>) { self.on_new_transport(); self.transport2 = self.transport1.take(); @@ -1026,7 +1031,7 @@ impl PeerState { } impl Transport { - fn new_from_hs(self_id: IdMapGuard, peer_id: Id, hs: HS) -> Self { + fn new_from_hs(self_id: IdMapGuard, peer_id: Id, hs: HS) -> Arc<Transport> { let (x, y) = hs.get_ciphers(); let (s, r) = if hs.get_is_initiator() { (x, y) @@ -1036,9 +1041,10 @@ impl Transport { let sk = s.extract().0; let rk = r.extract().0; - Transport { + let t = Arc::new(Transport { self_id: self_id, peer_id: peer_id, + should_handshake: AtomicBool::new(false), is_initiator: hs.get_is_initiator(), is_initiator_or_has_received: AtomicBool::new(hs.get_is_initiator()), not_too_old: AtomicBool::new(true), @@ -1047,7 +1053,38 @@ impl Transport { created: Instant::now(), recv_ar: Mutex::new(AntiReplay::new()), send_counter: AtomicU64::new(0), + rekey_after_time: Mutex::new(TimerHandle::dummy()), + reject_after_time: Mutex::new(TimerHandle::dummy()), + }); + + let w = Arc::downgrade(&t); + + if t.is_initiator { + let w = w.clone(); + let delay = Duration::from_secs(REKEY_AFTER_TIME); + let r = CONTROLLER.register_delay(delay, Box::new(move || { + debug!("Timer: mark should handshake."); + w.upgrade().map(|t| { + t.should_handshake.store(true, Relaxed); + }); + })); + r.activate(); + *t.rekey_after_time.lock().unwrap() = r; } + + { + let delay = Duration::from_secs(REJECT_AFTER_TIME); + let r = CONTROLLER.register_delay(delay, Box::new(move || { + debug!("Timer: mark too old."); + w.upgrade().map(|t| { + t.not_too_old.store(false, Relaxed); + }); + })); + r.activate(); + *t.reject_after_time.lock().unwrap() = r; + } + + t } fn get_should_send(&self) -> bool { @@ -1065,23 +1102,16 @@ impl Transport { /// Length: out.len() = msg.len() + 32. fn encrypt(&self, msg: &[u8], out: &mut [u8]) -> (Result<(), ()>, bool) { let c = self.send_counter.fetch_add(1, Relaxed); - let mut should_rekey = false; - if self.is_initiator && c >= REKEY_AFTER_MESSAGES { - should_rekey = true; - } - if c >= REJECT_AFTER_MESSAGES { - self.not_too_old.store(false, Relaxed); - return (Err(()), should_rekey); - } - - let age = self.created.elapsed(); + let mut should_rekey = self.should_handshake.load(Relaxed); - if age >= Duration::from_secs(REKEY_AFTER_TIME) { + // This is REALLY REALLY unlikely... + if c >= REKEY_AFTER_MESSAGES { should_rekey = true; - } - if age >= Duration::from_secs(REJECT_AFTER_TIME) { - self.not_too_old.store(false, Relaxed); - return (Err(()), should_rekey); + // Even more unlikely... + if c >= REJECT_AFTER_MESSAGES { + self.not_too_old.store(false, Relaxed); + return (Err(()), should_rekey); + } } out[0..4].copy_from_slice(&[4, 0, 0, 0]); @@ -1097,19 +1127,15 @@ impl Transport { /// /// Length: out.len() + 32 = msg.len(). fn decrypt(&self, msg: &[u8], out: &mut [u8]) -> Result<(), ()> { - if self.created.elapsed() >= Duration::from_secs(REJECT_AFTER_TIME) { - return Err(()); - } - if msg.len() < 32 { return Err(()); } - if msg[0..4] != [4, 0, 0, 0] { + if !self.not_too_old.load(Relaxed) { return Err(()); } - if self.created.elapsed() >= Duration::from_secs(REJECT_AFTER_TIME) { + if msg[0..4] != [4, 0, 0, 0] { return Err(()); } diff --git a/src/protocol/timer.rs b/src/protocol/timer.rs index 5b8cf13..3991503 100644 --- a/src/protocol/timer.rs +++ b/src/protocol/timer.rs @@ -28,6 +28,11 @@ type Action = Box<Fn() + Send + Sync>; lazy_static! { /// Global timer controller. pub static ref CONTROLLER: TimerController = TimerController::new(); + static ref DUMMY: ArcTimer = ArcTimer(Arc::new(Timer { + activated: AtomicBool::new(false), + rounds: AtomicUsize::new(0), + action: Box::new(|| {}), + })); } struct Timer { @@ -139,11 +144,7 @@ impl TimerHandle { pub fn dummy() -> Self { TimerHandle { pos: AtomicUsize::new(0), - timer: ArcTimer(Arc::new(Timer { - activated: AtomicBool::new(false), - rounds: AtomicUsize::new(0), - action: Box::new(|| {}), - })), + timer: DUMMY.clone(), } } |