diff options
author | Jake McGinty <me@jake.su> | 2018-02-09 14:21:10 +0000 |
---|---|---|
committer | Jake McGinty <me@jake.su> | 2018-02-09 14:21:10 +0000 |
commit | 1ef81bf07f9bfbfd57c8c4d408f037eb20ce7367 (patch) | |
tree | eef7a5fbebeb4353f5e19b50a95ebf1dced1c24d | |
parent | router.rs (diff) | |
download | wireguard-rs-1ef81bf07f9bfbfd57c8c4d408f037eb20ce7367.tar.xz wireguard-rs-1ef81bf07f9bfbfd57c8c4d408f037eb20ce7367.zip |
router refactoring
-rw-r--r-- | src/interface/mod.rs | 9 | ||||
-rw-r--r-- | src/interface/peer_server.rs | 39 | ||||
-rw-r--r-- | src/ip_packet.rs | 33 | ||||
-rw-r--r-- | src/main.rs | 1 | ||||
-rw-r--r-- | src/protocol/peer.rs | 10 | ||||
-rw-r--r-- | src/router.rs | 35 |
6 files changed, 97 insertions, 30 deletions
diff --git a/src/interface/mod.rs b/src/interface/mod.rs index 73530fe..6d51902 100644 --- a/src/interface/mod.rs +++ b/src/interface/mod.rs @@ -8,6 +8,7 @@ use router::Router; use base64; use hex; use byteorder::{ByteOrder, BigEndian, LittleEndian}; +use failure::Error; use snow::NoiseBuilder; use protocol::Peer; use std::io; @@ -63,6 +64,14 @@ impl UtunPacket { &UtunPacket::Inet6(ref payload) => &payload, } } + + pub fn from(raw_packet: Vec<u8>) -> Result<UtunPacket, Error> { + match raw_packet[0] >> 4 { + 4 => Ok(UtunPacket::Inet4(raw_packet)), + 6 => Ok(UtunPacket::Inet6(raw_packet)), + _ => bail!("unrecognized IP version") + } + } } impl UtunCodec for VecUtunCodec { diff --git a/src/interface/peer_server.rs b/src/interface/peer_server.rs index 8b2d81e..dfe37e1 100644 --- a/src/interface/peer_server.rs +++ b/src/interface/peer_server.rs @@ -108,6 +108,14 @@ impl PeerServer { self.udp_tx.clone() } + fn send_to_peer(&self, payload: PeerServerMessage) { + self.handle.spawn(self.udp_tx.clone().send(payload).then(|_| Ok(()))); + } + + fn send_to_tunnel(&self, packet: UtunPacket) { + self.handle.spawn(self.tunnel_tx.clone().send(packet).then(|_| Ok(()))); + } + // TODO: create a transport packet (type 0x4) queue until a handshake has been completed fn handle_incoming_packet(&mut self, addr: SocketAddr, packet: Vec<u8>) -> Result<(), Error> { debug!("got a UDP packet from {:?} of length {}, packet type {}", &addr, packet.len(), packet[0]); @@ -150,7 +158,7 @@ impl PeerServer { let response_packet = peer.get_response_packet()?; - self.handle.spawn(self.udp_tx.clone().send((addr.clone(), response_packet)).then(|_| Ok(()))); + self.send_to_peer((addr.clone(), response_packet)); let dead_session = peer.ratchet_session()?; if let Some(session) = dead_session { let _ = state.index_map.remove(&session.our_index); @@ -211,25 +219,18 @@ impl PeerServer { let lookup = state.index_map.get(&our_index_received); if let Some(ref peer) = lookup { - let mut peer = peer.borrow_mut(); - - let res = peer.decrypt_transport_packet(our_index_received, nonce, &packet[16..]); - - if let Ok(raw_packet) = res { - trace_packet("received TRANSPORT: ", &raw_packet); - let utun_packet = match raw_packet[0] >> 4 { - 4 => UtunPacket::Inet4(raw_packet), - 6 => UtunPacket::Inet6(raw_packet), - _ => unimplemented!() - }; - self.handle.spawn(self.tunnel_tx.clone().send(utun_packet) - .then(|_| Ok(()))); - } else { - warn!("dropped incoming tranport packet that neither the current nor past session could decrypt"); - } + let raw_packet = { + let mut peer = peer.borrow_mut(); + peer.decrypt_transport_packet(our_index_received, nonce, &packet[16..])? + }; + + state.router.validate_source(&raw_packet, peer)?; + + trace_packet("received TRANSPORT: ", &raw_packet); + self.send_to_tunnel(UtunPacket::from(raw_packet)?); } }, - _ => unimplemented!() + _ => bail!("unknown wireguard message type") } Ok(()) } @@ -285,7 +286,7 @@ impl PeerServer { trace_packet("received UTUN packet: ", packet.payload()); let state = self.shared_state.borrow(); let mut out_packet = vec![0u8; 1500]; - let peer = state.router.route_to_peer(&packet); + let peer = state.router.route_to_peer(packet.payload()); if let Some(peer) = peer { let mut peer = peer.borrow_mut(); diff --git a/src/ip_packet.rs b/src/ip_packet.rs new file mode 100644 index 0000000..4aa7ecb --- /dev/null +++ b/src/ip_packet.rs @@ -0,0 +1,33 @@ +use failure::Error; +use pnet::packet::ipv4::Ipv4Packet; +use pnet::packet::ipv6::Ipv6Packet; +use std::net::{Ipv4Addr, Ipv6Addr, IpAddr, SocketAddr}; + +pub enum IpPacket<'a> { + V4(Ipv4Packet<'a>), + V6(Ipv6Packet<'a>), +} + +impl<'a> IpPacket<'a> { + pub fn new(packet: &'a [u8]) -> Option<Self> { + match packet[0] >> 4 { + 4 => Ipv4Packet::new(&packet).map(|packet| IpPacket::V4(packet)), + 6 => Ipv6Packet::new(&packet).map(|packet| IpPacket::V6(packet)), + _ => None + } + } + + pub fn get_source(&self) -> IpAddr { + match *self { + IpPacket::V4(ref packet) => packet.get_source().into(), + IpPacket::V6(ref packet) => packet.get_source().into(), + } + } + + pub fn get_destination(&self) -> IpAddr { + match *self { + IpPacket::V4(ref packet) => packet.get_destination().into(), + IpPacket::V6(ref packet) => packet.get_destination().into(), + } + } +} diff --git a/src/main.rs b/src/main.rs index 572f4ec..7d4bd28 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,6 +38,7 @@ mod protocol; mod types; mod anti_replay; mod router; +mod ip_packet; use std::path::PathBuf; diff --git a/src/protocol/peer.rs b/src/protocol/peer.rs index 14d8691..8f461fb 100644 --- a/src/protocol/peer.rs +++ b/src/protocol/peer.rs @@ -32,6 +32,16 @@ pub struct Peer { pub last_handshake: Option<SystemTime>, } +impl PartialEq for Peer { + fn eq(&self, other: &Peer) -> bool { + self.info.pub_key == other.info.pub_key + } + + fn ne(&self, other: &Peer) -> bool { + self.info.pub_key != other.info.pub_key + } +} + pub struct Session { pub noise: snow::Session, pub our_index: u32, diff --git a/src/router.rs b/src/router.rs index 697907e..a542a4d 100644 --- a/src/router.rs +++ b/src/router.rs @@ -1,9 +1,11 @@ +use failure::Error; use interface::{SharedPeer, UtunPacket}; +use protocol::Peer; use treebitmap::{IpLookupTable, IpLookupTableOps}; use std::net::{Ipv4Addr, Ipv6Addr, IpAddr, SocketAddr}; +use ip_packet::IpPacket; use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv6::Ipv6Packet; -use pnet::packet::ethernet::{EtherTypes, EthernetPacket}; /// The `Router` struct is, as one might expect, the authority for the IP routing table. pub struct Router { @@ -32,16 +34,27 @@ impl Router { } } - pub fn route_to_peer(&self, packet: &UtunPacket) -> Option<SharedPeer> { - match packet { - &UtunPacket::Inet4(ref packet) => { - let destination = Ipv4Packet::new(&packet).unwrap().get_destination(); - self.ip4_map.longest_match(destination).map(|(_, _, peer)| peer.clone()) - }, - &UtunPacket::Inet6(ref packet) => { - let destination = Ipv6Packet::new(&packet).unwrap().get_destination(); - self.ip6_map.longest_match(destination).map(|(_, _, peer)| peer.clone()) - } + fn get_peer_from_ip(&self, ip: IpAddr) -> Option<SharedPeer> { + match ip { + IpAddr::V4(ip) => self.ip4_map.longest_match(ip).map(|(_, _, peer)| peer.clone()), + IpAddr::V6(ip) => self.ip6_map.longest_match(ip).map(|(_, _, peer)| peer.clone()) } } + + pub fn route_to_peer(&self, packet: &[u8]) -> Option<SharedPeer> { + match IpPacket::new(&packet) { + Some(packet) => self.get_peer_from_ip(packet.get_destination()), + _ => None + } + } + + pub fn validate_source(&self, packet: &[u8], peer: &SharedPeer) -> Result<(), Error> { + let routed_peer = match IpPacket::new(&packet) { + Some(packet) => self.get_peer_from_ip(packet.get_source()), + _ => None + }.ok_or_else(|| format_err!("no peer found on route"))?; + + ensure!(&routed_peer == peer, "peer mismatch"); + Ok(()) + } }
\ No newline at end of file |