aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJake McGinty <me@jake.su>2018-05-17 19:43:29 -0700
committerJake McGinty <me@jake.su>2018-05-17 19:44:58 -0700
commitc8a8b6a568e2bb78bdd1470e0fba1fb360b769ee (patch)
treeecf545d13f9147e671e22890f5d4ce0ede7e2204 /src
parentglobal: don't directly rely on tokio 'meta' crate (diff)
downloadwireguard-rs-c8a8b6a568e2bb78bdd1470e0fba1fb360b769ee.tar.xz
wireguard-rs-c8a8b6a568e2bb78bdd1470e0fba1fb360b769ee.zip
peer_server: use unbounded channels, ratelimiter wip
unbounded channels are easier to deal with now, and bounded channels weren't actually doing anything useful.
Diffstat (limited to 'src')
-rw-r--r--src/interface/config.rs2
-rw-r--r--src/interface/mod.rs2
-rw-r--r--src/interface/peer_server.rs132
-rw-r--r--src/ratelimiter.rs2
-rw-r--r--src/udp/frame.rs6
5 files changed, 87 insertions, 57 deletions
diff --git a/src/interface/config.rs b/src/interface/config.rs
index 9ccebde..41762ea 100644
--- a/src/interface/config.rs
+++ b/src/interface/config.rs
@@ -145,7 +145,7 @@ pub struct ConfigurationService {
}
impl ConfigurationService {
- pub fn new(interface_name: &str, state: &SharedState, peer_server_tx: mpsc::Sender<ChannelMessage>, handle: &Handle) -> Result<Self, Error> {
+ pub fn new(interface_name: &str, state: &SharedState, peer_server_tx: mpsc::UnboundedSender<ChannelMessage>, handle: &Handle) -> Result<Self, Error> {
let config_path = Self::get_path(interface_name).unwrap();
let listener = UnixListener::bind(config_path.clone(), handle).unwrap();
diff --git a/src/interface/mod.rs b/src/interface/mod.rs
index 21984ce..d40061a 100644
--- a/src/interface/mod.rs
+++ b/src/interface/mod.rs
@@ -96,7 +96,7 @@ impl Interface {
pub fn start(&mut self) -> Result<(), Error> {
let mut core = Core::new()?;
- let (utun_tx, utun_rx) = unsync::mpsc::channel::<Vec<u8>>(1024);
+ let (utun_tx, utun_rx) = unsync::mpsc::unbounded::<Vec<u8>>();
let peer_server = PeerServer::new(core.handle(), self.state.clone(), utun_tx.clone())?;
let config_server = ConfigurationService::new(&self.name, &self.state, peer_server.tx(), &core.handle())?.map_err(|_|());
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs
index ebcb01b..fe4797c 100644
--- a/src/interface/peer_server.rs
+++ b/src/interface/peer_server.rs
@@ -4,16 +4,18 @@ use cookie;
use interface::{SharedPeer, SharedState, State, UtunPacket};
use message::{Message, Initiation, Response, CookieReply, Transport};
use peer::{Peer, SessionType, SessionTransition};
+use ratelimiter::RateLimiter;
use timestamp::Timestamp;
use timer::{Timer, TimerMessage};
use byteorder::{ByteOrder, LittleEndian};
use failure::{Error, err_msg};
-use futures::{Async, Future, Stream, Sink, Poll, unsync::mpsc};
+use futures::{Async, Future, Stream, Poll, unsync::mpsc, task};
use rand::{self, Rng, ThreadRng};
use udp::{Endpoint, UdpSocket, PeerServerMessage, UdpChannel};
use tokio_core::reactor::Handle;
+use std::collections::VecDeque;
use std::convert::TryInto;
use std::rc::Rc;
@@ -27,12 +29,12 @@ pub enum ChannelMessage {
}
struct Channel<T> {
- tx: mpsc::Sender<T>,
- rx: mpsc::Receiver<T>,
+ tx: mpsc::UnboundedSender<T>,
+ rx: mpsc::UnboundedReceiver<T>,
}
-impl<T> From<(mpsc::Sender<T>, mpsc::Receiver<T>)> for Channel<T> {
- fn from(pair: (mpsc::Sender<T>, mpsc::Receiver<T>)) -> Self {
+impl<T> From<(mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)> for Channel<T> {
+ fn from(pair: (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)) -> Self {
Self {
tx: pair.0,
rx: pair.1,
@@ -41,30 +43,36 @@ impl<T> From<(mpsc::Sender<T>, mpsc::Receiver<T>)> for Channel<T> {
}
pub struct PeerServer {
- handle : Handle,
- shared_state : SharedState,
- udp : Option<UdpChannel>,
- port : Option<u16>,
- outgoing : Channel<UtunPacket>,
- channel : Channel<ChannelMessage>,
- timer : Timer,
- tunnel_tx : mpsc::Sender<Vec<u8>>,
- cookie : cookie::Validator,
- rng : ThreadRng,
+ handle : Handle,
+ shared_state : SharedState,
+ udp : Option<UdpChannel>,
+ port : Option<u16>,
+ outgoing : Channel<UtunPacket>,
+ channel : Channel<ChannelMessage>,
+ handshakes : VecDeque<(Endpoint, Message)>,
+ timer : Timer,
+ tunnel_tx : mpsc ::UnboundedSender<Vec<u8>>,
+ cookie : cookie::Validator,
+ rate_limiter : RateLimiter,
+ under_load_until : Timestamp,
+ rng : ThreadRng,
}
impl PeerServer {
- pub fn new(handle: Handle, shared_state: SharedState, tunnel_tx: mpsc::Sender<Vec<u8>>) -> Result<Self, Error> {
+ pub fn new(handle: Handle, shared_state: SharedState, tunnel_tx: mpsc::UnboundedSender<Vec<u8>>) -> Result<Self, Error> {
Ok(PeerServer {
shared_state, tunnel_tx,
- handle : handle.clone(),
- timer : Timer::new(handle),
- udp : None,
- port : None,
- outgoing : mpsc::channel(1024).into(),
- channel : mpsc::channel(1024).into(),
- cookie : cookie::Validator::new(&[0u8; 32]),
- rng : rand::thread_rng()
+ handle : handle.clone(),
+ timer : Timer::new(handle.clone()),
+ udp : None,
+ port : None,
+ outgoing : mpsc::unbounded().into(),
+ channel : mpsc::unbounded().into(),
+ handshakes : VecDeque::new(),
+ cookie : cookie::Validator::new(&[0u8; 32]),
+ rate_limiter : RateLimiter::new(&handle)?,
+ under_load_until : Timestamp::default(),
+ rng : rand::thread_rng()
})
}
@@ -77,8 +85,8 @@ impl PeerServer {
return Ok(());
}
- let port = interface.listen_port.unwrap_or(0);
- let fwmark = interface.fwmark.unwrap_or(0);
+ let port = interface.listen_port.unwrap_or(0);
+ let fwmark = interface.fwmark.unwrap_or(0);
if self.port.is_some() && self.port.unwrap() == port {
debug!("skipping rebind, since we're already listening on the correct port.");
@@ -100,11 +108,11 @@ impl PeerServer {
Ok(())
}
- pub fn tunnel_tx(&self) -> mpsc::Sender<UtunPacket> {
+ pub fn tunnel_tx(&self) -> mpsc::UnboundedSender<UtunPacket> {
self.outgoing.tx.clone()
}
- pub fn tx(&self) -> mpsc::Sender<ChannelMessage> {
+ pub fn tx(&self) -> mpsc::UnboundedSender<ChannelMessage> {
self.channel.tx.clone()
}
@@ -114,8 +122,8 @@ impl PeerServer {
Ok(())
}
- fn send_to_tunnel(&self, packet: Vec<u8>) {
- self.handle.spawn(self.tunnel_tx.clone().send(packet).then(|_| Ok(())));
+ fn send_to_tunnel(&self, packet: Vec<u8>) -> Result<(), Error> {
+ self.tunnel_tx.unbounded_send(packet).map_err(|e| e.into())
}
fn unused_index(&mut self, state: &mut State) -> u32 {
@@ -130,12 +138,29 @@ impl PeerServer {
fn handle_ingress_packet(&mut self, addr: Endpoint, packet: Vec<u8>) -> Result<(), Error> {
trace!("got a UDP packet from {:?} of length {}, packet type {}", &addr, packet.len(), packet[0]);
- match packet.try_into()? {
- Message::Initiation(packet) => self.handle_ingress_handshake_init(addr, &packet),
- Message::Response(packet) => self.handle_ingress_handshake_resp(addr, &packet),
- Message::CookieReply(packet) => self.handle_ingress_cookie_reply(addr, &packet),
- Message::Transport(packet) => self.handle_ingress_transport(addr, &packet),
+ let message = packet.try_into()?;
+ if let Message::Transport(packet) = message {
+ self.handle_ingress_transport(addr, &packet)?;
+ } else {
+ self.queue_ingress_handshake(addr, message);
}
+ Ok(())
+ }
+
+ fn queue_ingress_handshake(&mut self, addr: Endpoint, message: Message) {
+ // TODO: max queue size management
+ self.handshakes.push_back((addr, message));
+ task::current().notify();
+ }
+
+ fn handle_ingress_handshake(&mut self, addr: Endpoint, message: &Message) -> Result<(), Error> {
+ match message {
+ Message::Initiation(ref packet) => self.handle_ingress_handshake_init(addr, packet)?,
+ Message::Response(ref packet) => self.handle_ingress_handshake_resp(addr, packet)?,
+ Message::CookieReply(ref packet) => self.handle_ingress_cookie_reply(addr, packet)?,
+ Message::Transport(_) => unreachable!("no transport packets allowed"),
+ }
+ Ok(())
}
fn handle_ingress_handshake_init(&mut self, addr: Endpoint, packet: &Initiation) -> Result<(), Error> {
@@ -253,7 +278,7 @@ impl PeerServer {
self.shared_state.borrow_mut().router.validate_source(&raw_packet, &peer_ref)?;
trace!("received transport packet");
- self.send_to_tunnel(raw_packet);
+ self.send_to_tunnel(raw_packet)?;
Ok(())
}
@@ -481,54 +506,59 @@ impl Future for PeerServer {
type Error = Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
- // Handle config events
+ // Poll inner Futures until at least one of them has returned a NotReady. It's not
+ // safe to return NotReady yourself unless at least one future has returned a NotReady.
loop {
+ let mut not_ready = false;
+ // Handle config events
match self.channel.rx.poll() {
Ok(Async::Ready(Some(event))) => {
let _ = self.handle_incoming_event(event);
},
- Ok(Async::NotReady) => break,
+ Ok(Async::NotReady) => { not_ready = true; },
Ok(Async::Ready(None)) => bail!("config stream ended unexpectedly"),
Err(e) => bail!("config stream error: {:?}", e),
}
- }
- // Handle pending state-changing timers
- loop {
+ // Handle pending state-changing timers
match self.timer.poll() {
Ok(Async::Ready(Some(message))) => {
let _ = self.handle_timer(message).map_err(|e| debug!("TIMER: {}", e));
},
- Ok(Async::NotReady) => break,
+ Ok(Async::NotReady) => { not_ready = true; },
Ok(Async::Ready(None)) => bail!("timer stream ended unexpectedly"),
Err(e) => bail!("timer stream error: {:?}", e),
}
- }
- // Handle UDP packets from the outside world
- if self.udp.is_some() {
- loop {
+ // Handle UDP packets from the outside world
+ if self.udp.is_some() {
match self.udp.as_mut().unwrap().ingress.poll() {
Ok(Async::Ready(Some((addr, packet)))) => {
let _ = self.handle_ingress_packet(addr, packet).map_err(|e| warn!("UDP ERR: {:?}", e));
},
- Ok(Async::NotReady) => break,
+ Ok(Async::NotReady) => { not_ready = true; },
Ok(Async::Ready(None)) => bail!("incoming udp stream ended unexpectedly"),
Err(e) => bail!("incoming udp stream error: {:?}", e)
}
}
- }
- // Handle packets coming from the local tunnel
- loop {
+ // Handle packets coming from the local tunnel
match self.outgoing.rx.poll() {
Ok(Async::Ready(Some(packet))) => {
let _ = self.handle_egress_packet(packet).map_err(|e| warn!("UDP ERR: {:?}", e));
},
- Ok(Async::NotReady) => break,
+ Ok(Async::NotReady) => { not_ready = true; },
Ok(Async::Ready(None)) => bail!("outgoing udp stream ended unexpectedly"),
Err(e) => bail!("outgoing udp stream error: {:?}", e),
}
+
+ if not_ready {
+ break;
+ }
+ }
+
+ if let Some((addr, message)) = self.handshakes.pop_front() {
+ let _ = self.handle_ingress_handshake(addr, &message).map_err(|e| warn!("handshake err: {:?}", e));
}
Ok(Async::NotReady)
diff --git a/src/ratelimiter.rs b/src/ratelimiter.rs
index ee916fc..bdc5c7c 100644
--- a/src/ratelimiter.rs
+++ b/src/ratelimiter.rs
@@ -24,7 +24,7 @@ struct Entry {
pub tokens : u64,
}
-struct RateLimiter {
+pub struct RateLimiter {
table : HashMap<IpAddr, Entry>,
rx : mpsc::Receiver<()>,
}
diff --git a/src/udp/frame.rs b/src/udp/frame.rs
index b0f1dd4..f2f5f00 100644
--- a/src/udp/frame.rs
+++ b/src/udp/frame.rs
@@ -153,7 +153,7 @@ impl VecUdpCodec {
pub struct UdpChannel {
pub ingress : stream::SplitStream<UdpFramed>,
- pub egress : mpsc::Sender<PeerServerMessage>,
+ pub egress : mpsc::UnboundedSender<PeerServerMessage>,
pub fd4 : RawFd,
pub fd6 : RawFd,
handle : Handle,
@@ -165,7 +165,7 @@ impl From<UdpFramed> for UdpChannel {
let fd6 = framed.socket.as_raw_fd_v6();
let handle = framed.socket.handle.clone();
let (udp_sink, ingress) = framed.split();
- let (egress, egress_rx) = mpsc::channel(1024);
+ let (egress, egress_rx) = mpsc::unbounded();
let udp_writethrough = udp_sink
.sink_map_err(|_| ())
.send_all(egress_rx.and_then(|(addr, packet)| {
@@ -183,7 +183,7 @@ impl From<UdpFramed> for UdpChannel {
impl UdpChannel {
pub fn send(&self, message: PeerServerMessage) {
- self.handle.spawn(self.egress.clone().send(message).then(|_| Ok(())));
+ self.egress.clone().unbounded_send(message);
}
#[cfg(target_os = "linux")]