diff options
author | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-18 13:13:55 +0100 |
---|---|---|
committer | Mathias Hall-Andersen <mathias@hall-andersen.dk> | 2019-11-18 13:13:55 +0100 |
commit | 3ba0247634bbaa1da61532ca43e67fb2ad6c1106 (patch) | |
tree | 87ffb281f76b335d2c6441168186c348dbb28d7e /src/wireguard/router/route.rs | |
parent | Bug fixes from compliance tests with WireGuard (diff) | |
download | wireguard-rs-3ba0247634bbaa1da61532ca43e67fb2ad6c1106.tar.xz wireguard-rs-3ba0247634bbaa1da61532ca43e67fb2ad6c1106.zip |
Better compartmentalization of cryptokey router
Diffstat (limited to 'src/wireguard/router/route.rs')
-rw-r--r-- | src/wireguard/router/route.rs | 246 |
1 files changed, 148 insertions, 98 deletions
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index 29e7635..e5f5955 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -1,113 +1,163 @@ -use super::super::{bind, tun, Endpoint}; -use super::device::DeviceInner; use super::ip::*; -use super::peer::PeerInner; -use super::types::Callbacks; -use log::trace; use zerocopy::LayoutVerified; use std::mem; -use std::net::{Ipv4Addr, Ipv6Addr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::sync::Arc; -#[inline(always)] -pub fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( - device: &Arc<DeviceInner<E, C, T, B>>, - packet: &[u8], -) -> Option<Arc<PeerInner<E, C, T, B>>> { - match packet.get(0)? >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, get route for IPv4 destination: {:?}", - Ipv4Addr::from(header.f_destination) - ); - - // check IPv4 source address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) +use spin::RwLock; +use treebitmap::address::Address; +use treebitmap::IpLookupTable; + +/* Functions for obtaining and validating "cryptokey" routes */ + +pub struct RoutingTable<T> { + ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>, + ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>, +} + +impl<T> RoutingTable<T> { + pub fn new() -> Self { + RoutingTable { + ipv4: RwLock::new(IpLookupTable::new()), + ipv6: RwLock::new(IpLookupTable::new()), } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, get route for IPv6 destination: {:?}", - Ipv6Addr::from(header.f_destination) - ); - - // check IPv6 source address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_destination)) - .and_then(|(_, _, p)| Some(p.clone())) + } + + fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)> + where + A: Address, + { + let mut res = Vec::new(); + for (ip, cidr, v) in table.iter() { + if Arc::ptr_eq(v, value) { + res.push((ip, cidr)) + } } - _ => None, + res } -} -#[inline(always)] -pub fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>( - device: &Arc<DeviceInner<E, C, T, B>>, - peer: &Arc<PeerInner<E, C, T, B>>, - packet: &[u8], -) -> Option<usize> { - match packet.get(0)? >> 4 { - VERSION_IP4 => { - // check length and cast to IPv4 header - let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, check route for IPv4 source: {:?}", - Ipv4Addr::from(header.f_source) - ); - - // check IPv4 source address - device - .ipv4 - .read() - .longest_match(Ipv4Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { - Some(header.f_total_len.get() as usize) - } else { - None - } - }) + pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> { + let mut res = vec![]; + res.extend( + Self::collect(&*self.ipv4.read(), value) + .into_iter() + .map(|(ip, cidr)| (IpAddr::V4(ip), cidr)), + ); + res.extend( + Self::collect(&*self.ipv6.read(), value) + .into_iter() + .map(|(ip, cidr)| (IpAddr::V6(ip), cidr)), + ); + res + } + + pub fn remove(&self, value: &Arc<T>) { + let mut v4 = self.ipv4.write(); + let mut v6 = self.ipv6.write(); + for (ip, cidr) in Self::collect(&*v4, value) { + v4.remove(ip, cidr); } - VERSION_IP6 => { - // check length and cast to IPv6 header - let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = - LayoutVerified::new_from_prefix(packet)?; - - log::trace!( - "Router, check route for IPv6 source: {:?}", - Ipv6Addr::from(header.f_source) - ); - - // check IPv6 source address - device - .ipv6 - .read() - .longest_match(Ipv6Addr::from(header.f_source)) - .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { - Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) - } else { - None - } - }) + for (ip, cidr) in Self::collect(&*v6, value) { + v6.remove(ip, cidr); } - _ => None, + } + + #[inline(always)] + pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, get route for IPv4 destination: {:?}", + Ipv4Addr::from(header.f_destination) + ); + + // check IPv4 source address + self.ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, get route for IPv6 destination: {:?}", + Ipv6Addr::from(header.f_destination) + ); + + // check IPv6 source address + self.ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_destination)) + .and_then(|(_, _, p)| Some(p.clone())) + } + _ => None, + } + } + + #[inline(always)] + pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> { + match packet.get(0)? >> 4 { + VERSION_IP4 => { + // check length and cast to IPv4 header + let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, check route for IPv4 source: {:?}", + Ipv4Addr::from(header.f_source) + ); + + // check IPv4 source address + self.ipv4 + .read() + .longest_match(Ipv4Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_total_len.get() as usize) + } else { + None + } + }) + } + VERSION_IP6 => { + // check length and cast to IPv6 header + let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = + LayoutVerified::new_from_prefix(packet)?; + + log::trace!( + "Router, check route for IPv6 source: {:?}", + Ipv6Addr::from(header.f_source) + ); + + // check IPv6 source address + self.ipv6 + .read() + .longest_match(Ipv6Addr::from(header.f_source)) + .and_then(|(_, _, p)| { + if Arc::ptr_eq(p, peer) { + Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) + } else { + None + } + }) + } + _ => None, + } + } + + pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) { + match ip { + IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), + IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), + }; } } |