aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuanhao Yin <sopium@mysterious.site>2017-03-25 20:06:03 +0800
committerGuanhao Yin <sopium@mysterious.site>2017-03-25 21:51:29 +0800
commita7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d (patch)
tree44f71d7b278744f9cdfb2a21a991819f70d29af1
parentAdd TimerHandle::dummy(), don't put dummy timers into the wheel (diff)
downloadwireguard-rs-a7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d.tar.xz
wireguard-rs-a7b5b48c96d6b2b6e5653ba4579ef824cf8c4e3d.zip
Manage REKEY_AFTER_TIME and REJECT_AFTER_TIME with timers
-rw-r--r--src/protocol/controller.rs92
-rw-r--r--src/protocol/timer.rs11
2 files changed, 65 insertions, 38 deletions
diff --git a/src/protocol/controller.rs b/src/protocol/controller.rs
index 8f71519..2243736 100644
--- a/src/protocol/controller.rs
+++ b/src/protocol/controller.rs
@@ -114,9 +114,9 @@ struct PeerState {
tx_bytes: AtomicU64,
// XXX: use a Vec? or ArrayVec?
- transport0: Option<Transport>,
- transport1: Option<Transport>,
- transport2: Option<Transport>,
+ transport0: Option<Arc<Transport>>,
+ transport1: Option<Arc<Transport>>,
+ transport2: Option<Arc<Transport>>,
// Rekey because of send but not recv in...
rekey_no_recv: TimerHandle,
@@ -143,6 +143,8 @@ struct Transport {
self_id: IdMapGuard,
peer_id: Id,
is_initiator: bool,
+ // Is set to true after REKEY_AFTER_TIME if `is_initiator`.
+ should_handshake: AtomicBool,
// If we are responder, should not send until received one packet.
is_initiator_or_has_received: AtomicBool,
// Also should not send after REJECT_AFTER_TIME,
@@ -155,6 +157,10 @@ struct Transport {
recv_key: SecretKey,
recv_ar: Mutex<AntiReplay>,
+
+ // Use mutex to make the compiler happy.
+ rekey_after_time: Mutex<TimerHandle>,
+ reject_after_time: Mutex<TimerHandle>,
}
// TODO determine / detect load.
@@ -744,7 +750,6 @@ pub fn wg_change_peer<F>(wg: Arc<WgState>, peer_pubkey: &X25519Pubkey, f: F) ->
/// The peer should not already exist.
pub fn wg_add_peer(wg: Arc<WgState>, peer: &PeerInfo, sock: Arc<UdpSocket>) {
let register = |a| CONTROLLER.register_delay(Duration::from_secs(0), a);
- let dummy_action = || Box::new(|| {});
// Lock pubkey_map.
let mut pubkey_map = wg.pubkey_map.write().unwrap();
@@ -760,10 +765,10 @@ pub fn wg_add_peer(wg: Arc<WgState>, peer: &PeerInfo, sock: Arc<UdpSocket>) {
transport0: None,
transport1: None,
transport2: None,
- rekey_no_recv: register(dummy_action()),
- keep_alive: register(dummy_action()),
- persistent_keep_alive: register(dummy_action()),
- clear: register(dummy_action()),
+ rekey_no_recv: TimerHandle::dummy(),
+ keep_alive: TimerHandle::dummy(),
+ persistent_keep_alive: TimerHandle::dummy(),
+ clear: TimerHandle::dummy(),
};
let ps = Arc::new(RwLock::new(ps));
@@ -854,7 +859,7 @@ impl WgState {
wg
}
- // These methods helps a lot in avoiding deadlocks.
+ // These methods help a lot in avoiding deadlocks.
fn find_peer_by_id(&self, id: Id) -> Option<SharedPeerState> {
self.id_map.read().unwrap().get(&id).cloned()
@@ -962,7 +967,7 @@ impl PeerState {
});
}
- fn push_transport(&mut self, t: Transport) {
+ fn push_transport(&mut self, t: Arc<Transport>) {
self.on_new_transport();
self.transport2 = self.transport1.take();
@@ -1026,7 +1031,7 @@ impl PeerState {
}
impl Transport {
- fn new_from_hs(self_id: IdMapGuard, peer_id: Id, hs: HS) -> Self {
+ fn new_from_hs(self_id: IdMapGuard, peer_id: Id, hs: HS) -> Arc<Transport> {
let (x, y) = hs.get_ciphers();
let (s, r) = if hs.get_is_initiator() {
(x, y)
@@ -1036,9 +1041,10 @@ impl Transport {
let sk = s.extract().0;
let rk = r.extract().0;
- Transport {
+ let t = Arc::new(Transport {
self_id: self_id,
peer_id: peer_id,
+ should_handshake: AtomicBool::new(false),
is_initiator: hs.get_is_initiator(),
is_initiator_or_has_received: AtomicBool::new(hs.get_is_initiator()),
not_too_old: AtomicBool::new(true),
@@ -1047,7 +1053,38 @@ impl Transport {
created: Instant::now(),
recv_ar: Mutex::new(AntiReplay::new()),
send_counter: AtomicU64::new(0),
+ rekey_after_time: Mutex::new(TimerHandle::dummy()),
+ reject_after_time: Mutex::new(TimerHandle::dummy()),
+ });
+
+ let w = Arc::downgrade(&t);
+
+ if t.is_initiator {
+ let w = w.clone();
+ let delay = Duration::from_secs(REKEY_AFTER_TIME);
+ let r = CONTROLLER.register_delay(delay, Box::new(move || {
+ debug!("Timer: mark should handshake.");
+ w.upgrade().map(|t| {
+ t.should_handshake.store(true, Relaxed);
+ });
+ }));
+ r.activate();
+ *t.rekey_after_time.lock().unwrap() = r;
}
+
+ {
+ let delay = Duration::from_secs(REJECT_AFTER_TIME);
+ let r = CONTROLLER.register_delay(delay, Box::new(move || {
+ debug!("Timer: mark too old.");
+ w.upgrade().map(|t| {
+ t.not_too_old.store(false, Relaxed);
+ });
+ }));
+ r.activate();
+ *t.reject_after_time.lock().unwrap() = r;
+ }
+
+ t
}
fn get_should_send(&self) -> bool {
@@ -1065,23 +1102,16 @@ impl Transport {
/// Length: out.len() = msg.len() + 32.
fn encrypt(&self, msg: &[u8], out: &mut [u8]) -> (Result<(), ()>, bool) {
let c = self.send_counter.fetch_add(1, Relaxed);
- let mut should_rekey = false;
- if self.is_initiator && c >= REKEY_AFTER_MESSAGES {
- should_rekey = true;
- }
- if c >= REJECT_AFTER_MESSAGES {
- self.not_too_old.store(false, Relaxed);
- return (Err(()), should_rekey);
- }
-
- let age = self.created.elapsed();
+ let mut should_rekey = self.should_handshake.load(Relaxed);
- if age >= Duration::from_secs(REKEY_AFTER_TIME) {
+ // This is REALLY REALLY unlikely...
+ if c >= REKEY_AFTER_MESSAGES {
should_rekey = true;
- }
- if age >= Duration::from_secs(REJECT_AFTER_TIME) {
- self.not_too_old.store(false, Relaxed);
- return (Err(()), should_rekey);
+ // Even more unlikely...
+ if c >= REJECT_AFTER_MESSAGES {
+ self.not_too_old.store(false, Relaxed);
+ return (Err(()), should_rekey);
+ }
}
out[0..4].copy_from_slice(&[4, 0, 0, 0]);
@@ -1097,19 +1127,15 @@ impl Transport {
///
/// Length: out.len() + 32 = msg.len().
fn decrypt(&self, msg: &[u8], out: &mut [u8]) -> Result<(), ()> {
- if self.created.elapsed() >= Duration::from_secs(REJECT_AFTER_TIME) {
- return Err(());
- }
-
if msg.len() < 32 {
return Err(());
}
- if msg[0..4] != [4, 0, 0, 0] {
+ if !self.not_too_old.load(Relaxed) {
return Err(());
}
- if self.created.elapsed() >= Duration::from_secs(REJECT_AFTER_TIME) {
+ if msg[0..4] != [4, 0, 0, 0] {
return Err(());
}
diff --git a/src/protocol/timer.rs b/src/protocol/timer.rs
index 5b8cf13..3991503 100644
--- a/src/protocol/timer.rs
+++ b/src/protocol/timer.rs
@@ -28,6 +28,11 @@ type Action = Box<Fn() + Send + Sync>;
lazy_static! {
/// Global timer controller.
pub static ref CONTROLLER: TimerController = TimerController::new();
+ static ref DUMMY: ArcTimer = ArcTimer(Arc::new(Timer {
+ activated: AtomicBool::new(false),
+ rounds: AtomicUsize::new(0),
+ action: Box::new(|| {}),
+ }));
}
struct Timer {
@@ -139,11 +144,7 @@ impl TimerHandle {
pub fn dummy() -> Self {
TimerHandle {
pos: AtomicUsize::new(0),
- timer: ArcTimer(Arc::new(Timer {
- activated: AtomicBool::new(false),
- rounds: AtomicUsize::new(0),
- action: Box::new(|| {}),
- })),
+ timer: DUMMY.clone(),
}
}