author     Jake McGinty <me@jake.su>    2018-05-03 23:42:29 -0700
committer  Jake McGinty <me@jake.su>    2018-05-03 23:42:38 -0700
commit     9f5b12d3b8967bee22491515731d950d8d5220e4 (patch)
tree       c0ce4edf81218863f7b4566d5a61e1035ef749ee
parent     timers: rewrite persistent keepalive code (diff)
download   wireguard-rs-9f5b12d3b8967bee22491515731d950d8d5220e4.tar.xz
           wireguard-rs-9f5b12d3b8967bee22491515731d950d8d5220e4.zip

timers: more corrections to persistent keepalive
-rw-r--r--   src/interface/config.rs      | 23
-rw-r--r--   src/interface/mod.rs         |  3
-rw-r--r--   src/interface/peer_server.rs | 75
-rw-r--r--   src/peer.rs                  |  2
-rw-r--r--   src/timer.rs                 | 34
5 files changed, 74 insertions(+), 63 deletions(-)
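At a high level, this change removes the per-session index from the keepalive timer messages, moves keepalive scheduling out of the handshake handlers and into new NewPeer / NewPersistentKeepalive channel messages, and makes persistent keepalive timers cancellable so they can be replaced when a peer's setting changes. A minimal sketch of the resulting message shape, with stand-in types so it compiles on its own (Peer here is illustrative; SharedPeer in the crate is the shared, ref-counted peer):

    use std::cell::RefCell;
    use std::rc::Rc;

    // Stand-ins so the sketch is self-contained.
    struct Peer;
    type SharedPeer = Rc<RefCell<Peer>>;

    // After this commit the keepalive variants carry only the peer handle;
    // the u32 session index remains only on Rekey, where a specific session matters.
    enum TimerMessage {
        PersistentKeepAlive(SharedPeer),
        PassiveKeepAlive(SharedPeer),
        Rekey(SharedPeer, u32),
        Wipe(SharedPeer),
    }

    fn describe(msg: &TimerMessage) -> &'static str {
        match msg {
            TimerMessage::PersistentKeepAlive(_) => "persistent keepalive tick",
            TimerMessage::PassiveKeepAlive(_)    => "passive keepalive tick",
            TimerMessage::Rekey(..)              => "rekey a specific session",
            TimerMessage::Wipe(_)                => "wipe expired sessions",
        }
    }

    fn main() {
        let peer: SharedPeer = Rc::new(RefCell::new(Peer));
        println!("{}", describe(&TimerMessage::PassiveKeepAlive(peer)));
    }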
diff --git a/src/interface/config.rs b/src/interface/config.rs
index 882b6cb..9ccebde 100644
--- a/src/interface/config.rs
+++ b/src/interface/config.rs
@@ -142,14 +142,12 @@ pub struct ConfigurationService {
interface_name: String,
config_server: Box<Future<Item = (), Error = ()>>,
reaper: Box<Future<Item = (), Error = ()>>,
- rx: mpsc::Receiver<ChannelMessage>
}
impl ConfigurationService {
- pub fn new(interface_name: &str, state: &SharedState, handle: &Handle) -> Result<Self, Error> {
+ pub fn new(interface_name: &str, state: &SharedState, peer_server_tx: mpsc::Sender<ChannelMessage>, handle: &Handle) -> Result<Self, Error> {
let config_path = Self::get_path(interface_name).unwrap();
let listener = UnixListener::bind(config_path.clone(), handle).unwrap();
- let (tx, rx) = mpsc::channel::<ChannelMessage>(1024);
// TODO only listen for own socket, verify behavior from `notify` crate
let reaper = GrimReaper::spawn(handle, &config_path).unwrap();
@@ -163,7 +161,7 @@ impl ConfigurationService {
let handle = handle.clone();
let responses = stream.and_then({
- let tx = tx.clone();
+ let tx = peer_server_tx.clone();
let state = state.clone();
move |command| {
let mut state = state.borrow_mut();
@@ -211,7 +209,6 @@ impl ConfigurationService {
interface_name: interface_name.to_owned(),
config_server: Box::new(config_server),
reaper: Box::new(reaper),
- rx
})
}
@@ -265,7 +262,7 @@ impl ConfigurationService {
info.allowed_ips.extend_from_slice(&peer.info.allowed_ips);
}
let ret = if info.keepalive.is_some() && peer.info.keepalive != info.keepalive {
- Some(ChannelMessage::NewPersistentKeepalive(info.keepalive.unwrap()))
+ Some(ChannelMessage::NewPersistentKeepalive(peer_ref.clone()))
} else {
None
};
@@ -292,7 +289,7 @@ impl ConfigurationService {
let peer_ref = Rc::new(RefCell::new(peer));
let _ = state.pubkey_map.insert(info.pub_key, peer_ref.clone());
state.router.add_allowed_ips(&info.allowed_ips, &peer_ref);
- Ok(None) // TODO: notify specifically on details of these new peers
+ Ok(Some(ChannelMessage::NewPeer(peer_ref)))
}
},
UpdateEvent::RemoveAllPeers => {
@@ -355,11 +352,11 @@ impl ConfigurationService {
}
}
-impl Stream for ConfigurationService {
- type Item = ChannelMessage;
+impl Future for ConfigurationService {
+ type Item = ();
type Error = Error;
- fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+ fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
match self.config_server.poll() {
Ok(Async::NotReady) => {},
_ => return Err(err_msg("config_server broken")),
@@ -374,11 +371,7 @@ impl Stream for ConfigurationService {
},
}
- match self.rx.poll() {
- Ok(Async::Ready(None)) | Err(_) => Err(err_msg("err in config rx channel")),
- Ok(Async::Ready(msg)) => Ok(Async::Ready(msg)),
- Ok(Async::NotReady) => Ok(Async::NotReady)
- }
+ Ok(Async::NotReady)
}
}
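With the internal rx channel gone and the peer server's Sender passed in at construction (see the mod.rs hunk below), the service no longer yields messages of its own, so it implements Future rather than Stream: poll() just drives its sub-futures and otherwise reports NotReady. A minimal sketch of that shape, assuming the futures 0.1 API this crate uses (the Service struct and its fields are illustrative stand-ins, and std::io::Error stands in for the crate's error type):

    extern crate futures; // futures = "0.1"

    use futures::{Async, Future, Poll};
    use std::io::{Error, ErrorKind};

    // Stand-in for ConfigurationService: it owns sub-futures that are expected
    // to run forever, and only resolves by erroring when one of them stops.
    struct Service {
        config_server: Box<Future<Item = (), Error = ()>>,
        reaper: Box<Future<Item = (), Error = ()>>,
    }

    impl Future for Service {
        type Item = ();
        type Error = Error;

        fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
            match self.config_server.poll() {
                Ok(Async::NotReady) => {}
                _ => return Err(Error::new(ErrorKind::Other, "config_server broken")),
            }
            match self.reaper.poll() {
                Ok(Async::NotReady) => {}
                _ => return Err(Error::new(ErrorKind::Other, "reaper broken")),
            }
            // Nothing to yield: this future stays pending for the life of the interface.
            Ok(Async::NotReady)
        }
    }

    fn main() {
        let mut svc = Service {
            config_server: Box::new(futures::future::empty::<(), ()>()),
            reaper: Box::new(futures::future::empty::<(), ()>()),
        };
        match svc.poll() {
            Ok(Async::NotReady) => println!("service idle; sub-futures still running"),
            _ => println!("a sub-future stopped unexpectedly"),
        }
    }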
diff --git a/src/interface/mod.rs b/src/interface/mod.rs
index 0d12048..8d939f6 100644
--- a/src/interface/mod.rs
+++ b/src/interface/mod.rs
@@ -98,8 +98,7 @@ impl Interface {
let (utun_tx, utun_rx) = unsync::mpsc::channel::<Vec<u8>>(1024);
let peer_server = PeerServer::new(core.handle(), self.state.clone(), utun_tx.clone())?;
- let config_server = ConfigurationService::new(&self.name, &self.state, &core.handle())?;
- let config_server = config_server.forward(peer_server.tx()).map_err(|_|()); // TODO: don't just forward, this is so hacky.
+ let config_server = ConfigurationService::new(&self.name, &self.state, peer_server.tx(), &core.handle())?.map_err(|_|());
let utun_stream = UtunStream::connect(&self.name, &core.handle())?.framed(VecUtunCodec{});
let (utun_writer, utun_reader) = utun_stream.split();
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs
index 7ad6d15..31b004b 100644
--- a/src/interface/peer_server.rs
+++ b/src/interface/peer_server.rs
@@ -14,14 +14,15 @@ use rand::{self, Rng};
use udp::{Endpoint, UdpSocket, PeerServerMessage, UdpChannel};
use tokio_core::reactor::Handle;
-use std::{convert::TryInto, time::Duration};
+use std::convert::TryInto;
pub enum ChannelMessage {
ClearPrivateKey,
NewPrivateKey,
NewListenPort(u16),
NewFwmark(u32),
- NewPersistentKeepalive(u16),
+ NewPersistentKeepalive(SharedPeer),
+ NewPeer(SharedPeer),
}
struct Channel<T> {
@@ -165,7 +166,6 @@ impl PeerServer {
Ok(())
}
- // TODO use the address to update endpoint if it changes i suppose
fn handle_ingress_handshake_resp(&mut self, addr: Endpoint, packet: &Response) -> Result<(), Error> {
ensure!(packet.len() == 92, "handshake resp packet length is incorrect");
let mut state = self.shared_state.borrow_mut();
@@ -200,17 +200,6 @@ impl PeerServer {
info!("handshake response received, current session now {}", our_index);
self.timer.send_after(*WIPE_AFTER_TIME, TimerMessage::Wipe(peer_ref.clone()));
-
- match peer.info.keepalive {
- Some(keepalive) if keepalive > 0 => {
- self.timer.send_after(Duration::from_secs(u64::from(keepalive)),
- TimerMessage::PersistentKeepAlive(peer_ref.clone(), our_index));
- },
- _ => {
- self.timer.send_after(*KEEPALIVE_TIMEOUT,
- TimerMessage::PassiveKeepAlive(peer_ref.clone(), our_index));
- }
- }
Ok(())
}
@@ -245,8 +234,6 @@ impl PeerServer {
}
}
- let our_new_index = peer.sessions.current.as_ref().unwrap().our_index;
- self.timer.send_after(*KEEPALIVE_TIMEOUT, TimerMessage::PassiveKeepAlive(peer_ref.clone(), our_new_index));
self.timer.send_after(*WIPE_AFTER_TIME, TimerMessage::Wipe(peer_ref.clone()));
}
(raw_packet, peer.needs_new_handshake(false))
@@ -355,29 +342,29 @@ impl PeerServer {
let new_index = self.send_handshake_init(&peer_ref)?;
debug!("sent handshake init (Rekey timer) ({} -> {})", our_index, new_index);
},
- PassiveKeepAlive(peer_ref, our_index) => {
+ PassiveKeepAlive(peer_ref) => {
let mut peer = peer_ref.borrow_mut();
{
if peer.sessions.current.is_none() {
- self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone(), our_index));
- bail!("no active session. waiting until there is one.");
+ self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone()));
+ bail!("passive keepalive skip: no active session. waiting until there is one.");
} else if peer.info.keepalive.is_some() {
- self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone(), our_index));
- bail!("persistent keepalive set, no passive keepalive needed.");
+ self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone()));
+ bail!("passive keepalive skip: persistent keepalive set.");
}
let since_last_recv = peer.timers.data_received.elapsed();
let since_last_send = peer.timers.data_sent.elapsed();
if since_last_recv < *KEEPALIVE_TIMEOUT {
let wait = *KEEPALIVE_TIMEOUT - since_last_recv;
- self.timer.send_after(wait, PassiveKeepAlive(peer_ref.clone(), our_index));
+ self.timer.send_after(wait, PassiveKeepAlive(peer_ref.clone()));
bail!("passive keepalive tick (waiting ~{}s due to last recv time)", wait.as_secs());
} else if since_last_send < *KEEPALIVE_TIMEOUT {
let wait = *KEEPALIVE_TIMEOUT - since_last_send;
- self.timer.send_after(wait, PassiveKeepAlive(peer_ref.clone(), our_index));
+ self.timer.send_after(wait, PassiveKeepAlive(peer_ref.clone()));
bail!("passive keepalive tick (waiting ~{}s due to last send time)", wait.as_secs());
} else if peer.timers.keepalive_sent {
- self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone(), our_index));
+ self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone()));
bail!("passive keepalive already sent (waiting ~{}s to see if session survives)", KEEPALIVE_TIMEOUT.as_secs());
} else {
peer.timers.keepalive_sent = true;
@@ -385,24 +372,26 @@ impl PeerServer {
}
self.send_to_peer(peer.handle_outgoing_transport(&[])?)?;
- debug!("sent passive keepalive packet ({})", our_index);
+ debug!("sent passive keepalive packet");
- self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone(), our_index));
+ self.timer.send_after(*KEEPALIVE_TIMEOUT, PassiveKeepAlive(peer_ref.clone()));
},
- PersistentKeepAlive(peer_ref, our_index) => {
+ PersistentKeepAlive(peer_ref) => {
let mut peer = peer_ref.borrow_mut();
if let Some(persistent_keepalive) = peer.info.persistent_keepalive() {
let since_last_auth_any = peer.timers.authenticated_traversed.elapsed();
if since_last_auth_any < persistent_keepalive {
let wait = persistent_keepalive - since_last_auth_any;
- self.timer.send_after(wait, PersistentKeepAlive(peer_ref.clone(), our_index));
+ let handle = self.timer.send_after(wait, PersistentKeepAlive(peer_ref.clone()));
+ peer.timers.persistent_timer = Some(handle);
bail!("persistent keepalive tick (waiting ~{}s due to last authenticated packet time)", wait.as_secs());
}
self.send_to_peer(peer.handle_outgoing_transport(&[])?)?;
- self.timer.send_after(persistent_keepalive, PersistentKeepAlive(peer_ref.clone(), our_index));
- debug!("sent persistent keepalive packet ({})", our_index);
+ let handle = self.timer.send_after(persistent_keepalive, PersistentKeepAlive(peer_ref.clone()));
+ peer.timers.persistent_timer = Some(handle);
+ debug!("sent persistent keepalive packet");
} else {
bail!("no persistent keepalive set for peer (likely unset between the time the timer was started and now).");
}
@@ -438,10 +427,32 @@ impl PeerServer {
self.port = None;
}
},
- NewListenPort(_) => self.rebind().unwrap(),
+ NewPeer(peer_ref) => {
+ let mut peer = peer_ref.borrow_mut();
+ self.timer.send_after(*KEEPALIVE_TIMEOUT, TimerMessage::PassiveKeepAlive(peer_ref.clone()));
+ if let Some(keepalive) = peer.info.persistent_keepalive() {
+ let handle = self.timer.send_after(keepalive, TimerMessage::PersistentKeepAlive(peer_ref.clone()));
+ peer.timers.persistent_timer = Some(handle);
+ }
+ },
+ NewPersistentKeepalive(peer_ref) => {
+ let mut peer = peer_ref.borrow_mut();
+ if let Some(ref mut handle) = peer.timers.persistent_timer {
+ handle.cancel();
+ debug!("sent cancel signal to old persistent_timer.");
+ }
+
+ if let Some(keepalive) = peer.info.persistent_keepalive() {
+ let handle = self.timer.send_after(keepalive, TimerMessage::PersistentKeepAlive(peer_ref.clone()));
+ peer.timers.persistent_timer = Some(handle);
+ self.send_to_peer(peer.handle_outgoing_transport(&[])?)?;
+ debug!("set new keepalive timer and immediately sent new keepalive packet.");
+ }
+ }
+ NewListenPort(_) => self.rebind()?,
NewFwmark(mark) => {
if let Some(ref udp) = self.udp {
- udp.set_mark(mark).unwrap();
+ udp.set_mark(mark)?;
}
}
_ => {}
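The two new messages take over the scheduling that used to happen in the handshake path: NewPeer kicks off the passive keepalive loop and, if the peer has a persistent interval configured, the persistent one; NewPersistentKeepalive cancels whatever handle is stored in peer.timers.persistent_timer before scheduling a replacement and sending one keepalive right away. A minimal std-only sketch of the cancel-then-replace bookkeeping (Peer, Timer, and the field names are stand-ins for the crate's types):

    use std::cell::RefCell;
    use std::rc::Rc;
    use std::time::Duration;

    // Stand-in for the crate's cancellable timer handle.
    struct TimerHandle { canceled: Rc<RefCell<bool>> }
    impl TimerHandle {
        fn cancel(&mut self) { *self.canceled.borrow_mut() = true; }
    }

    // Stand-in timer that just records what was scheduled.
    struct Timer;
    impl Timer {
        fn send_after(&mut self, delay: Duration) -> TimerHandle {
            println!("scheduled persistent keepalive in {}s", delay.as_secs());
            TimerHandle { canceled: Rc::new(RefCell::new(false)) }
        }
    }

    struct Peer {
        keepalive_secs: Option<u64>,
        persistent_timer: Option<TimerHandle>,
    }

    // On NewPersistentKeepalive: cancel the stored handle (if any), then
    // schedule and store a fresh one for the new interval. The real handler
    // also sends a keepalive packet to the peer immediately.
    fn reset_persistent_keepalive(timer: &mut Timer, peer: &mut Peer) {
        if let Some(ref mut handle) = peer.persistent_timer {
            handle.cancel();
        }
        if let Some(secs) = peer.keepalive_secs {
            peer.persistent_timer = Some(timer.send_after(Duration::from_secs(secs)));
        }
    }

    fn main() {
        let mut timer = Timer;
        let mut peer = Peer { keepalive_secs: Some(25), persistent_timer: None };
        reset_persistent_keepalive(&mut timer, &mut peer); // initial schedule (NewPeer)
        peer.keepalive_secs = Some(10);
        reset_persistent_keepalive(&mut timer, &mut peer); // cancel old, reschedule
    }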
diff --git a/src/peer.rs b/src/peer.rs
index 7b5107b..977c192 100644
--- a/src/peer.rs
+++ b/src/peer.rs
@@ -14,6 +14,7 @@ use std::collections::VecDeque;
use std::fmt::{self, Debug, Display, Formatter};
use std::time::{SystemTime, UNIX_EPOCH};
use hex;
+use timer::TimerHandle;
use timestamp::{Tai64n, Timestamp};
use snow;
use types::PeerInfo;
@@ -55,6 +56,7 @@ pub struct Timers {
pub egress_queued : Timestamp,
pub handshake_completed : Timestamp,
pub handshake_initialized : Timestamp,
+ pub persistent_timer : Option<TimerHandle>,
pub keepalive_sent : bool
}
diff --git a/src/timer.rs b/src/timer.rs
index ca00e29..5afbdc0 100644
--- a/src/timer.rs
+++ b/src/timer.rs
@@ -1,5 +1,6 @@
use consts::TIMER_RESOLUTION;
-use futures::{Async, Future, Stream, Sink, Poll, unsync};
+use futures::{Future, Stream, Sink, Poll, unsync};
+use std::{cell::RefCell, rc::Rc};
use std::time::{Instant, Duration};
use tokio::timer::Delay;
use tokio_core::reactor::Handle;
@@ -7,19 +8,19 @@ use interface::SharedPeer;
#[derive(Debug)]
pub enum TimerMessage {
- PersistentKeepAlive(SharedPeer, u32),
- PassiveKeepAlive(SharedPeer, u32),
+ PersistentKeepAlive(SharedPeer),
+ PassiveKeepAlive(SharedPeer),
Rekey(SharedPeer, u32),
Wipe(SharedPeer),
}
pub struct TimerHandle {
- tx: unsync::oneshot::Sender<()>
+ canceled: Rc<RefCell<bool>>
}
impl TimerHandle {
- pub fn cancel(self) -> Result<(), ()> {
- self.tx.send(())
+ pub fn cancel(&mut self) {
+ *self.canceled.borrow_mut() = true;
}
}
@@ -37,18 +38,23 @@ impl Timer {
pub fn send_after(&mut self, delay: Duration, message: TimerMessage) -> TimerHandle {
trace!("queuing timer message {:?}", &message);
- let (cancel_tx, mut cancel_rx) = unsync::oneshot::channel();
+ let canceled = Rc::new(RefCell::new(false));
+ let handle = self.handle.clone();
let tx = self.tx.clone();
let future = Delay::new(Instant::now() + delay + (*TIMER_RESOLUTION * 2))
.map_err(|e| panic!("timer failed; err={:?}", e))
- .and_then(move |_| {
- if let Ok(Async::Ready(())) = cancel_rx.poll() {
- trace!("timer cancel signal sent, won't send message.");
- }
- tx.send(message).then(|_| Ok(()))
- });
+ .and_then({
+ let canceled = canceled.clone();
+ move |_| {
+ if !*canceled.borrow() {
+ handle.spawn(tx.send(message).then(|_| Ok(())))
+ } else {
+ debug!("timer cancel signal sent, won't send message.");
+ }
+ Ok(())
+ }});
self.handle.spawn(future);
- TimerHandle { tx: cancel_tx }
+ TimerHandle { canceled }
}
}
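The oneshot-based handle had two problems visible in the removed code: cancel(self) consumed the handle, and the delayed closure only logged the cancel signal but still forwarded the message. The replacement is a shared boolean: cancel(&mut self) flips the flag, and the delayed closure checks it before spawning the send. A minimal std-only sketch of that flag pattern (send_after here returns the deferred action directly instead of spawning a tokio Delay, purely for illustration):

    use std::cell::RefCell;
    use std::rc::Rc;

    // Stand-in for the crate's TimerHandle: cancelling flips a shared flag.
    pub struct TimerHandle {
        canceled: Rc<RefCell<bool>>,
    }

    impl TimerHandle {
        // &mut self, so a handle stored in an Option field can be cancelled in place.
        pub fn cancel(&mut self) {
            *self.canceled.borrow_mut() = true;
        }
    }

    // Illustrative send_after: returns the handle plus the deferred action that
    // would normally run once the delay elapses.
    fn send_after(message: &'static str) -> (TimerHandle, Box<dyn FnOnce()>) {
        let canceled = Rc::new(RefCell::new(false));
        let handle = TimerHandle { canceled: canceled.clone() };
        let deferred: Box<dyn FnOnce()> = Box::new(move || {
            // Only forward the message if nobody cancelled the timer in the meantime.
            if !*canceled.borrow() {
                println!("delivering timer message: {}", message);
            } else {
                println!("timer cancelled, dropping message: {}", message);
            }
        });
        (handle, deferred)
    }

    fn main() {
        let (mut handle, fire) = send_after("PersistentKeepAlive");
        handle.cancel(); // e.g. the peer's keepalive interval changed
        fire();          // the elapsed timer now delivers nothing
    }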