aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorJake McGinty <me@jake.su>2018-02-12 19:20:03 +0000
committerJake McGinty <me@jake.su>2018-02-12 19:20:03 +0000
commit07a4a45035137876f38ae45ba8a900404a07f02d (patch)
tree4c76594397696d35e36f51a287dc9a91bc9565e5 /src
parentonly count authenticated packets in tx/rx numbers (diff)
downloadwireguard-rs-07a4a45035137876f38ae45ba8a900404a07f02d.tar.xz
wireguard-rs-07a4a45035137876f38ae45ba8a900404a07f02d.zip
wait for first packet before next -> current transition
Diffstat (limited to 'src')
-rw-r--r--src/interface/peer_server.rs53
-rw-r--r--src/protocol/peer.rs139
2 files changed, 104 insertions, 88 deletions
diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs
index a142d31..e48db10 100644
--- a/src/interface/peer_server.rs
+++ b/src/interface/peer_server.rs
@@ -141,28 +141,22 @@ impl PeerServer {
};
let mut peer = peer_ref.borrow_mut();
- let (response, next_index, dead_index) = peer.process_incoming_handshake(addr, their_index, timestamp.into(), noise)?;
+ let (response, next_index) = peer.process_incoming_handshake(addr, their_index, timestamp.into(), noise)?;
let _ = state.index_map.insert(next_index, peer_ref.clone());
- if let Some(index) = dead_index {
- let _ = state.index_map.remove(&index);
- }
self.send_to_peer((addr, response));
info!("sent handshake response, ratcheted session.");
},
2 => {
- let their_index = LittleEndian::read_u32(&packet[4..]);
- let our_index = LittleEndian::read_u32(&packet[8..]);
- let peer_ref = state.index_map.get(&our_index).unwrap().clone();
+ let our_index = LittleEndian::read_u32(&packet[8..]);
+ let peer_ref = state.index_map.get(&our_index)
+ .ok_or_else(|| format_err!("unknown our_index"))?
+ .clone();
let mut peer = peer_ref.borrow_mut();
- peer.sessions.next.as_mut().unwrap().their_index = their_index;
- let payload_len = peer.next_noise().expect("pending noise session")
- .read_message(&packet[12..60], &mut [])
- .map_err(SyncFailure::new)?;
-
- ensure!(payload_len == 0, "non-zero payload length in handshake response");
-
- peer.ratchet_session()?;
+ let dead_index = peer.process_incoming_handshake_response(&packet)?;
+ if let Some(index) = dead_index {
+ let _ = state.index_map.remove(&index);
+ }
info!("got handshake response, ratcheted session.");
// TODO neither of these timers are to spec, but are simple functional placeholders
@@ -200,22 +194,27 @@ impl PeerServer {
let our_index_received = LittleEndian::read_u32(&packet[4..]);
let nonce = LittleEndian::read_u64(&packet[8..]);
- let lookup = state.index_map.get(&our_index_received);
- if let Some(ref peer) = lookup {
- let raw_packet = {
- let mut peer = peer.borrow_mut();
- peer.handle_incoming_transport(our_index_received, nonce, addr, &packet[16..])?
- };
+ let peer_ref = state.index_map.get(&our_index_received)
+ .ok_or_else(|| format_err!("unknown our_index"))?
+ .clone();
- if raw_packet.len() == 0 {
- return Ok(()) // short-circuit on keep-alives
- }
+ let (raw_packet, dead_index) = {
+ let mut peer = peer_ref.borrow_mut();
+ peer.handle_incoming_transport(our_index_received, nonce, addr, &packet[16..])?
+ };
- state.router.validate_source(&raw_packet, peer)?;
+ if let Some(index) = dead_index {
+ let _ = state.index_map.remove(&index);
+ }
- trace_packet("received TRANSPORT: ", &raw_packet);
- self.send_to_tunnel(UtunPacket::from(raw_packet)?);
+ if raw_packet.len() == 0 {
+ return Ok(()) // short-circuit on keep-alives
}
+
+ state.router.validate_source(&raw_packet, &peer_ref)?;
+
+ trace_packet("received TRANSPORT: ", &raw_packet);
+ self.send_to_tunnel(UtunPacket::from(raw_packet)?);
},
_ => bail!("unknown wireguard message type")
}
diff --git a/src/protocol/peer.rs b/src/protocol/peer.rs
index ad6d361..3e1bac5 100644
--- a/src/protocol/peer.rs
+++ b/src/protocol/peer.rs
@@ -1,13 +1,13 @@
use anti_replay::AntiReplay;
use byteorder::{ByteOrder, BigEndian, LittleEndian};
use blake2_rfc::blake2s::{Blake2s, blake2s};
-use consts::{TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE};
+use consts::{TRANSPORT_OVERHEAD, TRANSPORT_HEADER_SIZE, MAX_SEGMENT_SIZE};
use failure::{Error, SyncFailure};
use pnet::packet::Packet;
use pnet::packet::ip::IpNextHeaderProtocols;
use pnet::packet::ipv4::{self, MutableIpv4Packet};
use pnet::packet::icmp::{self, MutableIcmpPacket, IcmpTypes, echo_reply, echo_request};
-use std::{self, io};
+use std::{self, io, mem};
use std::fmt::{self, Debug, Display, Formatter};
use std::net::{Ipv4Addr, IpAddr, SocketAddr, ToSocketAddrs};
use std::str::FromStr;
@@ -45,6 +45,11 @@ impl PartialEq for Peer {
}
}
+#[derive(PartialEq)]
+enum SessionType {
+ Past, Current, Next
+}
+
pub struct Session {
pub noise: snow::Session,
pub our_index: u32,
@@ -115,55 +120,14 @@ impl Peer {
}
pub fn set_next_session(&mut self, session: Session) {
- let _ = std::mem::replace(&mut self.sessions.next, Some(session));
+ let _ = mem::replace(&mut self.sessions.next, Some(session));
}
- pub fn ratchet_session(&mut self) -> Result<Option<Session>, Error> {
- let next = std::mem::replace(&mut self.sessions.next, None)
- .ok_or_else(|| format_err!("next session is missing"))?;
- let next = next.into_transport_mode();
-
- let current = std::mem::replace(&mut self.sessions.current, Some(next));
- let dead = std::mem::replace(&mut self.sessions.past, current);
-
- self.last_handshake = Some(SystemTime::now());
- Ok(dead)
- }
-
- pub fn handle_incoming_transport(&mut self, our_index: u32, nonce: u64, addr: SocketAddr, packet: &[u8]) -> Result<Vec<u8>, Error> {
-
- let session = self.sessions.current.as_mut().filter(|session| session.our_index == our_index)
- .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index))
- .ok_or_else(|| format_err!("couldn't find available session"))?;
-
- session.anti_replay.update(nonce)?;
-
- let mut raw_packet = vec![0u8; 1500];
- session.noise.set_receiving_nonce(nonce)
- .map_err(SyncFailure::new)?;
- let len = session.noise.read_message(packet, &mut raw_packet)
- .map_err(SyncFailure::new)?;
-
- self.rx_bytes += packet.len() as u64;
- self.info.endpoint = Some(addr); // update peer endpoint after successful authentication
-
- raw_packet.truncate(len);
- Ok(raw_packet)
- }
-
- pub fn handle_outgoing_transport(&mut self, packet: &[u8]) -> Result<(SocketAddr, Vec<u8>), Error> {
- let session = self.sessions.current.as_mut().ok_or_else(|| format_err!("no current noise session"))?;
- let endpoint = self.info.endpoint.ok_or_else(|| format_err!("no known peer endpoint"))?;
-
- let mut out_packet = vec![0u8; packet.len() + TRANSPORT_OVERHEAD];
- out_packet[0] = 4;
- LittleEndian::write_u32(&mut out_packet[4..], session.their_index);
- LittleEndian::write_u64(&mut out_packet[8..], session.noise.sending_nonce().map_err(SyncFailure::new)?);
- let len = session.noise.write_message(packet, &mut out_packet[16..])
- .map_err(SyncFailure::new)?;
- self.tx_bytes += len as u64;
- out_packet.truncate(TRANSPORT_HEADER_SIZE + len);
- Ok((endpoint, out_packet))
+ fn find_session(&mut self, our_index: u32) -> Result<(&mut Session, SessionType), Error> {
+ self.sessions.next.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Next))
+ .or(self.sessions.current.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Current)))
+ .or(self.sessions.past.as_mut().filter(|session| session.our_index == our_index).map(|s| (s, SessionType::Past)))
+ .ok_or_else(|| format_err!("couldn't find available session"))
}
pub fn current_noise(&mut self) -> Option<&mut snow::Session> {
@@ -174,14 +138,6 @@ impl Peer {
}
}
- pub fn next_noise(&mut self) -> Option<&mut snow::Session> {
- if let Some(ref mut session) = self.sessions.next {
- Some(&mut session.noise)
- } else {
- None
- }
- }
-
pub fn our_next_index(&self) -> Option<u32> {
if let Some(ref session) = self.sessions.next {
Some(session.our_index)
@@ -230,7 +186,7 @@ impl Peer {
///
/// Returns: the response packet (type 0x02), and an optional dead session index that was removed.
pub fn process_incoming_handshake(&mut self, addr: SocketAddr, their_index: u32, timestamp: TAI64N, mut noise: snow::Session)
- -> Result<(Vec<u8>, u32, Option<u32>), Error> {
+ -> Result<(Vec<u8>, u32), Error> {
if let Some(ref last_tai64n) = self.last_handshake_tai64n {
ensure!(&timestamp > last_tai64n, "handshake timestamp earlier than last handshake's timestamp");
@@ -249,12 +205,10 @@ impl Peer {
let response_packet = self.get_response_packet(&mut next_session)?;
self.set_next_session(next_session);
- let dead_index = self.ratchet_session()?.map(|session| session.our_index);
-
self.info.endpoint = Some(addr); // update peer endpoint after successful authentication
self.last_handshake_tai64n = Some(timestamp);
- Ok((response_packet, next_index, dead_index))
+ Ok((response_packet, next_index))
}
fn get_response_packet(&mut self, next_session: &mut Session) -> Result<Vec<u8>, Error> {
@@ -273,6 +227,69 @@ impl Peer {
Ok(packet)
}
+ pub fn process_incoming_handshake_response(&mut self, packet: &[u8]) -> Result<Option<u32>, Error> {
+ let their_index = LittleEndian::read_u32(&packet[4..]);
+ let mut session = mem::replace(&mut self.sessions.next, None).ok_or_else(|| format_err!("no next session"))?;
+ let len = session.noise.read_message(&packet[12..60], &mut []).map_err(SyncFailure::new)?;
+
+ ensure!(len == 0, "non-zero payload length in handshake response");
+ session.their_index = their_index;
+
+ let session = session.into_transport_mode();
+
+ let current = mem::replace(&mut self.sessions.current, Some(session));
+ let dead = mem::replace(&mut self.sessions.past, current);
+
+ self.last_handshake = Some(SystemTime::now());
+ Ok(dead.map(|session| session.our_index))
+ }
+
+ pub fn handle_incoming_transport(&mut self, our_index: u32, nonce: u64, addr: SocketAddr, packet: &[u8])
+ -> Result<(Vec<u8>, Option<u32>), Error> {
+
+ let mut raw_packet = vec![0u8; MAX_SEGMENT_SIZE];
+ let session_type = {
+ let (session, session_type) = self.find_session(our_index)?;
+ ensure!(session.noise.is_handshake_finished(), "session is not ready for transport packets");
+
+ session.anti_replay.update(nonce)?;
+ session.noise.set_receiving_nonce(nonce).map_err(SyncFailure::new)?;
+ let len = session.noise.read_message(packet, &mut raw_packet).map_err(SyncFailure::new)?;
+ raw_packet.truncate(len);
+
+ session_type
+ };
+
+ let dead_index = if session_type == SessionType::Next {
+ let next = std::mem::replace(&mut self.sessions.next, None);
+ let current = std::mem::replace(&mut self.sessions.current, next);
+ let dead = std::mem::replace(&mut self.sessions.past, current);
+ dead.map(|session| session.our_index)
+ } else {
+ None
+ };
+
+ self.rx_bytes += packet.len() as u64;
+ self.info.endpoint = Some(addr); // update peer endpoint after successful authentication
+
+ Ok((raw_packet, dead_index))
+ }
+
+ pub fn handle_outgoing_transport(&mut self, packet: &[u8]) -> Result<(SocketAddr, Vec<u8>), Error> {
+ let session = self.sessions.current.as_mut().ok_or_else(|| format_err!("no current noise session"))?;
+ let endpoint = self.info.endpoint.ok_or_else(|| format_err!("no known peer endpoint"))?;
+ let mut out_packet = vec![0u8; packet.len() + TRANSPORT_OVERHEAD];
+
+ out_packet[0] = 4;
+ LittleEndian::write_u32(&mut out_packet[4..], session.their_index);
+ LittleEndian::write_u64(&mut out_packet[8..], session.noise.sending_nonce().map_err(SyncFailure::new)?);
+ let len = session.noise.write_message(packet, &mut out_packet[16..])
+ .map_err(SyncFailure::new)?;
+ self.tx_bytes += len as u64;
+ out_packet.truncate(TRANSPORT_HEADER_SIZE + len);
+ Ok((endpoint, out_packet))
+ }
+
pub fn to_config_string(&self) -> String {
let mut s = format!("public_key={}\n", hex::encode(&self.info.pub_key));
if let Some(ref psk) = self.info.psk {