diff options
author | Jake McGinty <me@jake.su> | 2018-02-19 22:05:51 +0000 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2018-02-19 22:07:10 +0000 |
commit | afb8840cb87186b856fcbbae6baf46cca9d349aa (patch) | |
tree | a016fb6f28594f7bf49e4ddc57f8d52944d77aef /src/interface/peer_server.rs | |
parent | give each peer their own packet queue (diff) | |
download | wireguard-rs-afb8840cb87186b856fcbbae6baf46cca9d349aa.tar.xz wireguard-rs-afb8840cb87186b856fcbbae6baf46cca9d349aa.zip |
finish per-peer egress queues + method refactor
Diffstat (limited to 'src/interface/peer_server.rs')
-rw-r--r-- | src/interface/peer_server.rs | 284 |
1 files changed, 148 insertions, 136 deletions
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index 2e2b08f..0f5edbd 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -1,4 +1,4 @@ -use super::{SharedState, UtunPacket, trace_packet}; +use super::{SharedState, UtunPacket}; use consts::{REKEY_TIMEOUT, REKEY_AFTER_TIME, REJECT_AFTER_TIME, REKEY_ATTEMPT_TIME, KEEPALIVE_TIMEOUT, MAX_CONTENT_SIZE, TIMER_TICK_DURATION}; use cookie; use interface::SharedPeer; @@ -11,7 +11,7 @@ use std::time::{Duration, Instant}; use byteorder::{ByteOrder, LittleEndian}; use failure::{Error, err_msg}; -use futures::{self, Async, Future, Stream, Sink, Poll, unsync, stream}; +use futures::{Async, Future, Stream, Sink, Poll, unsync::mpsc, stream}; use socket2::{Socket, Domain, Type, Protocol}; use tokio_core::net::{UdpSocket, UdpCodec, UdpFramed}; use tokio_core::reactor::Handle; @@ -50,29 +50,28 @@ impl UdpCodec for VecUdpCodec { } pub struct PeerServer { - handle: Handle, - shared_state: SharedState, - udp_stream: stream::SplitStream<UdpFramed<VecUdpCodec>>, - timer: Timer, - outgoing_tx: unsync::mpsc::Sender<UtunPacket>, - outgoing_rx: futures::stream::Peekable<unsync::mpsc::Receiver<UtunPacket>>, - udp_tx: unsync::mpsc::Sender<(SocketAddr, Vec<u8>)>, - tunnel_tx: unsync::mpsc::Sender<Vec<u8>>, + handle : Handle, + shared_state : SharedState, + udp_stream : stream::SplitStream<UdpFramed<VecUdpCodec>>, + timer : Timer, + outgoing_tx : mpsc::Sender<UtunPacket>, + outgoing_rx : mpsc::Receiver<UtunPacket>, + udp_tx : mpsc::Sender<(SocketAddr, Vec<u8>)>, + tunnel_tx : mpsc::Sender<Vec<u8>>, } impl PeerServer { - pub fn bind(handle: Handle, shared_state: SharedState, tunnel_tx: unsync::mpsc::Sender<Vec<u8>>) -> Result<Self, Error> { - let port = shared_state.borrow().interface_info.listen_port.unwrap_or(0); + pub fn bind(handle: Handle, shared_state: SharedState, tunnel_tx: mpsc::Sender<Vec<u8>>) -> Result<Self, Error> { + let timer = Timer::default(); + let port = shared_state.borrow().interface_info.listen_port.unwrap_or(0); let socket = Socket::new(Domain::ipv6(), Type::dgram(), Some(Protocol::udp()))?; socket.set_only_v6(false)?; socket.set_nonblocking(true)?; socket.bind(&SocketAddr::from((Ipv6Addr::unspecified(), port)).into())?; - let timer = Timer::default(); let socket = UdpSocket::from_socket(socket.into_udp_socket(), &handle.clone())?; let (udp_sink, udp_stream) = socket.framed(VecUdpCodec{}).split(); - let (udp_tx, udp_rx) = unsync::mpsc::channel::<(SocketAddr, Vec<u8>)>(1024); - let (outgoing_tx, outgoing_rx) = unsync::mpsc::channel::<UtunPacket>(1024); - let outgoing_rx = outgoing_rx.peekable(); + let (udp_tx, udp_rx) = mpsc::channel::<(SocketAddr, Vec<u8>)>(1024); + let (outgoing_tx, outgoing_rx) = mpsc::channel::<UtunPacket>(1024); let udp_write_passthrough = udp_sink.sink_map_err(|_| ()).send_all( udp_rx.map(|(addr, packet)| { @@ -87,7 +86,7 @@ impl PeerServer { }) } - pub fn tx(&self) -> unsync::mpsc::Sender<UtunPacket> { + pub fn tx(&self) -> mpsc::Sender<UtunPacket> { self.outgoing_tx.clone() } @@ -99,101 +98,130 @@ impl PeerServer { self.handle.spawn(self.tunnel_tx.clone().send(packet).then(|_| Ok(()))); } - fn handle_incoming_packet(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), Error> { + fn handle_ingress_packet(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), Error> { trace!("got a UDP packet from {:?} of length {}, packet type {}", &addr, packet.len(), packet[0]); - let mut state = self.shared_state.borrow_mut(); match packet[0] { - 1 => { - ensure!(packet.len() == 148, "handshake init packet length is incorrect"); - { - let pubkey = state.interface_info.pub_key.as_ref() - .ok_or_else(|| err_msg("must have local interface key"))?; - let (mac_in, mac_out) = packet.split_at(116); - cookie::verify_mac1(pubkey, mac_in, &mac_out[..16])?; - } + 1 => self.handle_ingress_handshake_init(addr, packet), + 2 => self.handle_ingress_handshake_resp(addr, packet), + 3 => bail!("cookie messages not yet supported."), + 4 => self.handle_ingress_transport(addr, packet), + _ => bail!("unknown wireguard message type") + } + } + + fn handle_ingress_handshake_init(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), Error> { + ensure!(packet.len() == 148, "handshake init packet length is incorrect"); + let mut state = self.shared_state.borrow_mut(); + { + let pubkey = state.interface_info.pub_key.as_ref() + .ok_or_else(|| err_msg("must have local interface key"))?; + let (mac_in, mac_out) = packet.split_at(116); + cookie::verify_mac1(pubkey, mac_in, &mac_out[..16])?; + } - debug!("got handshake initiation request (0x01)"); + debug!("got handshake initiation request (0x01)"); - let handshake = Peer::process_incoming_handshake( - &state.interface_info.private_key.ok_or_else(|| err_msg("no private key!"))?, - packet)?; + let handshake = Peer::process_incoming_handshake( + &state.interface_info.private_key.ok_or_else(|| err_msg("no private key!"))?, + packet)?; - let peer_ref = state.pubkey_map.get(handshake.their_pubkey()) - .ok_or_else(|| err_msg("unknown peer pubkey"))?.clone(); + let peer_ref = state.pubkey_map.get(handshake.their_pubkey()) + .ok_or_else(|| err_msg("unknown peer pubkey"))?.clone(); - let mut peer = peer_ref.borrow_mut(); - let (response, next_index) = peer.complete_incoming_handshake(addr, handshake)?; - let _ = state.index_map.insert(next_index, peer_ref.clone()); + let mut peer = peer_ref.borrow_mut(); + let (response, next_index) = peer.complete_incoming_handshake(addr, handshake)?; + let _ = state.index_map.insert(next_index, peer_ref.clone()); - self.send_to_peer((addr, response)); - info!("sent handshake response, ratcheted session (index {}).", next_index); - }, - 2 => { - ensure!(packet.len() == 92, "handshake resp packet length is incorrect"); - { - let pubkey = state.interface_info.pub_key.as_ref() - .ok_or_else(|| err_msg("must have local interface key"))?; - let (mac_in, mac_out) = packet.split_at(60); - cookie::verify_mac1(pubkey, mac_in, &mac_out[..16])?; - } - debug!("got handshake response (0x02)"); + self.send_to_peer((addr, response)); + info!("sent handshake response, ratcheted session (index {}).", next_index); - let our_index = LittleEndian::read_u32(&packet[8..]); - let peer_ref = state.index_map.get(&our_index) - .ok_or_else(|| format_err!("unknown our_index ({})", our_index))? - .clone(); - let mut peer = peer_ref.borrow_mut(); - let dead_index = peer.process_incoming_handshake_response(packet)?; - if let Some(index) = dead_index { - let _ = state.index_map.remove(&index); - } - info!("handshake response received, current session now {}", our_index); + Ok(()) + } - // send empty packet to unblock peer from establishing secure session - // TODO: only do this if the tun queue is empty + // TODO use the address to update endpoint if it changes i suppose + fn handle_ingress_handshake_resp(&mut self, _addr: SocketAddr, packet: &[u8]) -> Result<(), Error> { + ensure!(packet.len() == 92, "handshake resp packet length is incorrect"); + let mut state = self.shared_state.borrow_mut(); + { + let pubkey = state.interface_info.pub_key.as_ref() + .ok_or_else(|| err_msg("must have local interface key"))?; + let (mac_in, mac_out) = packet.split_at(60); + cookie::verify_mac1(pubkey, mac_in, &mac_out[..16])?; + } + debug!("got handshake response (0x02)"); + + let our_index = LittleEndian::read_u32(&packet[8..]); + let peer_ref = state.index_map.get(&our_index) + .ok_or_else(|| format_err!("unknown our_index ({})", our_index))? + .clone(); + let mut peer = peer_ref.borrow_mut(); + let dead_index = peer.process_incoming_handshake_response(packet)?; + if let Some(index) = dead_index { + let _ = state.index_map.remove(&index); + } + if peer.ready_for_transport() { + if !peer.outgoing_queue.is_empty() { + debug!("sending {} queued egress packets", peer.outgoing_queue.len()); + while let Some(packet) = peer.outgoing_queue.pop_front() { + self.send_to_peer(peer.handle_outgoing_transport(packet.payload())?); + } + } else { self.send_to_peer(peer.handle_outgoing_transport(&[])?); + } + } else { + error!("peer not ready for transport after processing handshake response. this shouldn't happen."); + } + info!("handshake response received, current session now {}", our_index); - self.timer.spawn_delayed(&self.handle, - *KEEPALIVE_TIMEOUT, - TimerMessage::PassiveKeepAlive(peer_ref.clone(), our_index)); + self.timer.spawn_delayed(&self.handle, + *KEEPALIVE_TIMEOUT, + TimerMessage::PassiveKeepAlive(peer_ref.clone(), our_index)); - self.timer.spawn_delayed(&self.handle, - *REJECT_AFTER_TIME, - TimerMessage::Reject(peer_ref.clone(), our_index)); + self.timer.spawn_delayed(&self.handle, + *REJECT_AFTER_TIME, + TimerMessage::Reject(peer_ref.clone(), our_index)); - if let Some(persistent_keep_alive) = peer.info.keep_alive_interval { - self.timer.spawn_delayed(&self.handle, - Duration::from_secs(u64::from(persistent_keep_alive)), - TimerMessage::PersistentKeepAlive(peer_ref.clone(), our_index)); - } - }, - 3 => { - warn!("cookie messages not yet implemented."); - }, - 4 => { - let our_index_received = LittleEndian::read_u32(&packet[4..]); - let peer_ref = state.index_map.get(&our_index_received) - .ok_or_else(|| err_msg("unknown our_index"))? - .clone(); + if let Some(persistent_keep_alive) = peer.info.keep_alive_interval { + self.timer.spawn_delayed(&self.handle, + Duration::from_secs(u64::from(persistent_keep_alive)), + TimerMessage::PersistentKeepAlive(peer_ref.clone(), our_index)); + } + Ok(()) + } - let (raw_packet, dead_index) = peer_ref.borrow_mut().handle_incoming_transport(addr, packet)?; + fn handle_ingress_transport(&mut self, addr: SocketAddr, packet: &[u8]) -> Result<(), Error> { + let mut state = self.shared_state.borrow_mut(); + let our_index = LittleEndian::read_u32(&packet[4..]); + let peer_ref = state.index_map.get(&our_index).ok_or_else(|| err_msg("unknown our_index"))?.clone(); + let raw_packet = { + let mut peer = peer_ref.borrow_mut(); + let (raw_packet, transition) = peer.handle_incoming_transport(addr, packet)?; - if let Some(index) = dead_index { + if let Some(possible_dead_index) = transition { + if let Some(index) = possible_dead_index { let _ = state.index_map.remove(&index); } - if raw_packet.is_empty() { - debug!("received keepalive."); - return Ok(()) // short-circuit on keep-alives - } + let outgoing: Vec<UtunPacket> = peer.outgoing_queue.drain(..).collect(); - state.router.validate_source(&raw_packet, &peer_ref)?; + for packet in outgoing { + match peer.handle_outgoing_transport(packet.payload()) { + Ok(message) => self.send_to_peer(message), + Err(e) => warn!("failed to encrypt packet: {}", e) + } + } + } + raw_packet + }; - trace_packet("received TRANSPORT: ", &raw_packet); - self.send_to_tunnel(raw_packet); - }, - _ => bail!("unknown wireguard message type") + if raw_packet.is_empty() { + debug!("received keepalive."); + return Ok(()) // short-circuit on keep-alives } + + state.router.validate_source(&raw_packet, &peer_ref)?; + trace!("received transport packet"); + self.send_to_tunnel(raw_packet); Ok(()) } @@ -327,52 +355,33 @@ impl PeerServer { } // Just this way to avoid a double-mutable-borrow while peeking. - fn peek_from_tun_and_handle(&mut self) -> Result<bool, Error> { - enum Decision { Drop, Wait, Handshake(SharedPeer), Transport((SocketAddr, Vec<u8>))} - let decision = { - let packet = match self.outgoing_rx.peek() { - Ok(Async::Ready(Some(packet))) => packet, - Ok(Async::NotReady) => return Ok(false), - Ok(Async::Ready(None)) | Err(_) => bail!("channel failure"), - }; - trace_packet("received UTUN packet: ", packet.payload()); - - let mut state = self.shared_state.borrow_mut(); - let peer_ref = state.router.route_to_peer(packet.payload()).ok_or_else(|| err_msg("no route to peer"))?; - let mut peer = peer_ref.borrow_mut(); - - if packet.payload().is_empty() || packet.payload().len() > MAX_CONTENT_SIZE { - Decision::Drop - } else if peer.sessions.current.is_none() { - if peer.sessions.next.is_some() { - Decision::Wait - } else { - Decision::Handshake(peer_ref.clone()) + fn handle_egress_packet(&mut self, packet: UtunPacket) -> Result<(), Error> { + ensure!(!packet.payload().is_empty() && packet.payload().len() <= MAX_CONTENT_SIZE, "egress packet outside of size bounds"); + + let peer_ref = self.shared_state.borrow_mut().router.route_to_peer(packet.payload()) + .ok_or_else(|| err_msg("no route to peer"))?; + + let needs_handshake = { + let mut peer = peer_ref.borrow_mut(); + peer.outgoing_queue.push_back(packet); + + if peer.ready_for_transport() { + if peer.outgoing_queue.len() > 1 { + debug!("sending {} queued egress packets", peer.outgoing_queue.len()); + } + + while let Some(packet) = peer.outgoing_queue.pop_front() { + self.send_to_peer(peer.handle_outgoing_transport(packet.payload())?); } - } else { - Decision::Transport(peer.handle_outgoing_transport(packet.payload())?) } + peer.needs_new_handshake() }; - match decision { - Decision::Transport(outgoing) => { - self.send_to_peer(outgoing); - let _ = self.outgoing_rx.poll(); - Ok(true) - }, - Decision::Handshake(peer_ref) => { - debug!("kicking off handshake because there are pending outgoing packets"); - self.send_handshake_init(&peer_ref)?; - Ok(false) - }, - Decision::Drop => { - let _ = self.outgoing_rx.poll(); - Ok(true) - }, - Decision::Wait => { - Ok(false) - } + if needs_handshake { + debug!("sending handshake init because peer needs it"); + self.send_handshake_init(&peer_ref)?; } + Ok(()) } } @@ -396,7 +405,7 @@ impl Future for PeerServer { loop { match self.udp_stream.poll() { Ok(Async::Ready(Some((addr, packet)))) => { - let _ = self.handle_incoming_packet(addr, &packet).map_err(|e| warn!("UDP ERR: {:?}", e)); + let _ = self.handle_ingress_packet(addr, &packet).map_err(|e| warn!("UDP ERR: {:?}", e)); }, Ok(Async::NotReady) => break, Ok(Async::Ready(None)) | Err(_) => return Err(()), @@ -405,9 +414,12 @@ impl Future for PeerServer { // Handle packets coming from the local tunnel loop { - match self.peek_from_tun_and_handle().map_err(|e| { warn!("TUN ERR: {:?}", e); e }) { - Ok(false) | Err(_) => break, - _ => {} + match self.outgoing_rx.poll() { + Ok(Async::Ready(Some(packet))) => { + let _ = self.handle_egress_packet(packet).map_err(|e| warn!("UDP ERR: {:?}", e)); + }, + Ok(Async::NotReady) => break, + Ok(Async::Ready(None)) | Err(_) => return Err(()), } } |