diff options
author | Jake McGinty <me@jake.su> | 2018-02-12 19:20:03 +0000 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2018-02-12 19:20:03 +0000 |
commit | 07a4a45035137876f38ae45ba8a900404a07f02d (patch) | |
tree | 4c76594397696d35e36f51a287dc9a91bc9565e5 /src | |
parent | only count authenticated packets in tx/rx numbers (diff) | |
download | wireguard-rs-07a4a45035137876f38ae45ba8a900404a07f02d.tar.xz wireguard-rs-07a4a45035137876f38ae45ba8a900404a07f02d.zip |
wait for first packet before next -> current transition
Diffstat (limited to 'src')
-rw-r--r-- | src/interface/peer_server.rs | 53 | ||||
-rw-r--r-- | src/protocol/peer.rs | 139 |
2 files changed, 104 insertions, 88 deletions
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index a142d31..e48db10 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -141,28 +141,22 @@ impl PeerServer { }; let mut peer = peer_ref.borrow_mut(); - let (response, next_index, dead_index) = peer.process_incoming_handshake(addr, their_index, timestamp.into(), noise)?; + let (response, next_index) = peer.process_incoming_handshake(addr, their_index, timestamp.into(), noise)?; let _ = state.index_map.insert(next_index, peer_ref.clone()); - if let Some(index) = dead_index { - let _ = state.index_map.remove(&index); - } self.send_to_peer((addr, response)); info!("sent handshake response, ratcheted session."); }, 2 => { - let their_index = LittleEndian::read_u32(&packet[4..]); - let our_index = LittleEndian::read_u32(&packet[8..]); - let peer_ref = state.index_map.get(&our_index).unwrap().clone(); + let our_index = LittleEndian::read_u32(&packet[8..]); + let peer_ref = state.index_map.get(&our_index) + .ok_or_else(|| format_err!("unknown our_index"))? + .clone(); let mut peer = peer_ref.borrow_mut(); - peer.sessions.next.as_mut().unwrap().their_index = their_index; - let payload_len = peer.next_noise().expect("pending noise session") - .read_message(&packet[12..60], &mut []) - .map_err(SyncFailure::new)?; - - ensure!(payload_len == 0, "non-zero payload length in handshake response"); - - peer.ratchet_session()?; + let dead_index = peer.process_incoming_handshake_response(&packet)?; + if let Some(index) = dead_index { + let _ = state.index_map.remove(&index); + } info!("got handshake response, ratcheted session."); // TODO neither of these timers are to spec, but are simple functional placeholders @@ -200,22 +194,27 @@ impl PeerServer { let our_index_received = LittleEndian::read_u32(&packet[4..]); let nonce = LittleEndian::read_u64(&packet[8..]); - let lookup = state.index_map.get(&our_index_received); - if let Some(ref peer) = lookup { - let raw_packet = { - let mut peer = peer.borrow_mut(); - peer.handle_incoming_transport(our_index_received, nonce, addr, &packet[16..])? - }; + let peer_ref = state.index_map.get(&our_index_received) + .ok_or_else(|| format_err!("unknown our_index"))? + .clone(); - if raw_packet.len() == 0 { - return Ok(()) // short-circuit on keep-alives - } + let (raw_packet, dead_index) = { + let mut peer = peer_ref.borrow_mut(); + peer.handle_incoming_transport(our_index_received, nonce, addr, &packet[16..])? + }; - state.router.validate_source(&raw_packet, peer)?; + if let Some(index) = dead_index { + let _ = state.index_map.remove(&index); + } - trace_packet("received TRANSPORT: ", &raw_packet); - self.send_to_tunnel(UtunPacket::from(raw_packet)?); + if raw_packet.len() == 0 { + return Ok(()) // short-circuit on keep-alives } + + state.router.validate_source(&raw_packet, &peer_ref)?; + + trace_packet("received TRANSPORT: ", &raw_packet); + self.send_to_tunnel(UtunPacket::from(raw_packet)?); }, _ => bail!("unknown wireguard message type") } diff --git a/src/protocol/peer.rs b/src/protocol/peer.rs index ad6d361..3e1bac5 100644 --- a/src/protocol/peer.rs +++ b/src/protocol/peer.rs @@ -1,13 +1,13 @@ use anti_replay::AntiReplay; use byteorder::{ByteOrder, BigEndian, LittleEndian}; use blake2_rfc::blake2s::{Blake2s, blake2s}; -use consts::{TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE}; +use consts::{TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE, MAX_SEGMENT_SIZE}; use failure::{Error, SyncFailure}; use pnet::packet::Packet; use pnet::packet::ip::IpNextHeaderProtocols; use pnet::packet::ipv4::{self, MutableIpv4Packet}; use pnet::packet::icmp::{self, MutableIcmpPacket, IcmpTypes, echo_reply, echo_request}; -use std::{self, io}; +use std::{self, io, mem}; use std::fmt::{self, Debug, Display, Formatter}; use std::net::{Ipv4Addr, IpAddr, SocketAddr, ToSocketAddrs}; use std::str::FromStr; @@ -45,6 +45,11 @@ impl PartialEq for Peer { } } +#[derive(PartialEq)] +enum SessionType { + Past, Current, Next +} + pub struct Session { pub noise: snow::Session, pub our_index: u32, @@ -115,55 +120,14 @@ impl Peer { } pub fn set_next_session(&mut self, session: Session) { - let _ = std::mem::replace(&mut self.sessions.next, Some(session)); + let _ = mem::replace(&mut self.sessions.next, Some(session)); } - pub fn ratchet_session(&mut self) -> Result<Option<Session>, Error> { - let next = std::mem::replace(&mut self.sessions.next, None) - .ok_or_else(|| format_err!("next session is missing"))?; - let next = next.into_transport_mode(); - - let current = std::mem::replace(&mut self.sessions.current, Some(next)); - let dead = std::mem::replace(&mut self.sessions.past, current); - - self.last_handshake = Some(SystemTime::now()); - Ok(dead) - } - - pub fn handle_incoming_transport(&mut self, our_index: u32, nonce: u64, addr: SocketAddr, packet: &[u8]) -> Result<Vec<u8>, Error> { - - let session = self.sessions.current.as_mut().filter(|session| session.our_index == our_index) - .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index)) - .ok_or_else(|| format_err!("couldn't find available session"))?; - - session.anti_replay.update(nonce)?; - - let mut raw_packet = vec![0u8; 1500]; - session.noise.set_receiving_nonce(nonce) - .map_err(SyncFailure::new)?; - let len = session.noise.read_message(packet, &mut raw_packet) - .map_err(SyncFailure::new)?; - - self.rx_bytes += packet.len() as u64; - self.info.endpoint = Some(addr); // update peer endpoint after successful authentication - - raw_packet.truncate(len); - Ok(raw_packet) - } - - pub fn handle_outgoing_transport(&mut self, packet: &[u8]) -> Result<(SocketAddr, Vec<u8>), Error> { - let session = self.sessions.current.as_mut().ok_or_else(|| format_err!("no current noise session"))?; - let endpoint = self.info.endpoint.ok_or_else(|| format_err!("no known peer endpoint"))?; - - let mut out_packet = vec![0u8; packet.len() + TRANSPORT_OVERHEAD]; - out_packet[0] = 4; - LittleEndian::write_u32(&mut out_packet[4..], session.their_index); - LittleEndian::write_u64(&mut out_packet[8..], session.noise.sending_nonce().map_err(SyncFailure::new)?); - let len = session.noise.write_message(packet, &mut out_packet[16..]) - .map_err(SyncFailure::new)?; - self.tx_bytes += len as u64; - out_packet.truncate(TRANSPORT_HEADER_SIZE + len); - Ok((endpoint, out_packet)) + fn find_session(&mut self, our_index: u32) -> Result<(&mut Session, SessionType), Error> { + self.sessions.next.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Next)) + .or(self.sessions.current.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Current))) + .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Past))) + .ok_or_else(|| format_err!("couldn't find available session")) } pub fn current_noise(&mut self) -> Option<&mut snow::Session> { @@ -174,14 +138,6 @@ impl Peer { } } - pub fn next_noise(&mut self) -> Option<&mut snow::Session> { - if let Some(ref mut session) = self.sessions.next { - Some(&mut session.noise) - } else { - None - } - } - pub fn our_next_index(&self) -> Option<u32> { if let Some(ref session) = self.sessions.next { Some(session.our_index) @@ -230,7 +186,7 @@ impl Peer { /// /// Returns: the response packet (type 0x02), and an optional dead session index that was removed. pub fn process_incoming_handshake(&mut self, addr: SocketAddr, their_index: u32, timestamp: TAI64N, mut noise: snow::Session) - -> Result<(Vec<u8>, u32, Option<u32>), Error> { + -> Result<(Vec<u8>, u32), Error> { if let Some(ref last_tai64n) = self.last_handshake_tai64n { ensure!(×tamp > last_tai64n, "handshake timestamp earlier than last handshake's timestamp"); @@ -249,12 +205,10 @@ impl Peer { let response_packet = self.get_response_packet(&mut next_session)?; self.set_next_session(next_session); - let dead_index = self.ratchet_session()?.map(|session| session.our_index); - self.info.endpoint = Some(addr); // update peer endpoint after successful authentication self.last_handshake_tai64n = Some(timestamp); - Ok((response_packet, next_index, dead_index)) + Ok((response_packet, next_index)) } fn get_response_packet(&mut self, next_session: &mut Session) -> Result<Vec<u8>, Error> { @@ -273,6 +227,69 @@ impl Peer { Ok(packet) } + pub fn process_incoming_handshake_response(&mut self, packet: &[u8]) -> Result<Option<u32>, Error> { + let their_index = LittleEndian::read_u32(&packet[4..]); + let mut session = mem::replace(&mut self.sessions.next, None).ok_or_else(|| format_err!("no next session"))?; + let len = session.noise.read_message(&packet[12..60], &mut []).map_err(SyncFailure::new)?; + + ensure!(len == 0, "non-zero payload length in handshake response"); + session.their_index = their_index; + + let session = session.into_transport_mode(); + + let current = mem::replace(&mut self.sessions.current, Some(session)); + let dead = mem::replace(&mut self.sessions.past, current); + + self.last_handshake = Some(SystemTime::now()); + Ok(dead.map(|session| session.our_index)) + } + + pub fn handle_incoming_transport(&mut self, our_index: u32, nonce: u64, addr: SocketAddr, packet: &[u8]) + -> Result<(Vec<u8>, Option<u32>), Error> { + + let mut raw_packet = vec![0u8; MAX_SEGMENT_SIZE]; + let session_type = { + let (session, session_type) = self.find_session(our_index)?; + ensure!(session.noise.is_handshake_finished(), "session is not ready for transport packets"); + + session.anti_replay.update(nonce)?; + session.noise.set_receiving_nonce(nonce).map_err(SyncFailure::new)?; + let len = session.noise.read_message(packet, &mut raw_packet).map_err(SyncFailure::new)?; + raw_packet.truncate(len); + + session_type + }; + + let dead_index = if session_type == SessionType::Next { + let next = std::mem::replace(&mut self.sessions.next, None); + let current = std::mem::replace(&mut self.sessions.current, next); + let dead = std::mem::replace(&mut self.sessions.past, current); + dead.map(|session| session.our_index) + } else { + None + }; + + self.rx_bytes += packet.len() as u64; + self.info.endpoint = Some(addr); // update peer endpoint after successful authentication + + Ok((raw_packet, dead_index)) + } + + pub fn handle_outgoing_transport(&mut self, packet: &[u8]) -> Result<(SocketAddr, Vec<u8>), Error> { + let session = self.sessions.current.as_mut().ok_or_else(|| format_err!("no current noise session"))?; + let endpoint = self.info.endpoint.ok_or_else(|| format_err!("no known peer endpoint"))?; + let mut out_packet = vec![0u8; packet.len() + TRANSPORT_OVERHEAD]; + + out_packet[0] = 4; + LittleEndian::write_u32(&mut out_packet[4..], session.their_index); + LittleEndian::write_u64(&mut out_packet[8..], session.noise.sending_nonce().map_err(SyncFailure::new)?); + let len = session.noise.write_message(packet, &mut out_packet[16..]) + .map_err(SyncFailure::new)?; + self.tx_bytes += len as u64; + out_packet.truncate(TRANSPORT_HEADER_SIZE + len); + Ok((endpoint, out_packet)) + } + pub fn to_config_string(&self) -> String { let mut s = format!("public_key={}\n", hex::encode(&self.info.pub_key)); if let Some(ref psk) = self.info.psk { |