summaryrefslogtreecommitdiffstats
path: root/src/wireguard/router/route.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/router/route.rs')
-rw-r--r--src/wireguard/router/route.rs246
1 files changed, 148 insertions, 98 deletions
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs
index 29e7635..e5f5955 100644
--- a/src/wireguard/router/route.rs
+++ b/src/wireguard/router/route.rs
@@ -1,113 +1,163 @@
-use super::super::{bind, tun, Endpoint};
-use super::device::DeviceInner;
use super::ip::*;
-use super::peer::PeerInner;
-use super::types::Callbacks;
-use log::trace;
use zerocopy::LayoutVerified;
use std::mem;
-use std::net::{Ipv4Addr, Ipv6Addr};
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
-#[inline(always)]
-pub fn get_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
- device: &Arc<DeviceInner<E, C, T, B>>,
- packet: &[u8],
-) -> Option<Arc<PeerInner<E, C, T, B>>> {
- match packet.get(0)? >> 4 {
- VERSION_IP4 => {
- // check length and cast to IPv4 header
- let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- log::trace!(
- "Router, get route for IPv4 destination: {:?}",
- Ipv4Addr::from(header.f_destination)
- );
-
- // check IPv4 source address
- device
- .ipv4
- .read()
- .longest_match(Ipv4Addr::from(header.f_destination))
- .and_then(|(_, _, p)| Some(p.clone()))
+use spin::RwLock;
+use treebitmap::address::Address;
+use treebitmap::IpLookupTable;
+
+/* Functions for obtaining and validating "cryptokey" routes */
+
+pub struct RoutingTable<T> {
+ ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>,
+ ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>,
+}
+
+impl<T> RoutingTable<T> {
+ pub fn new() -> Self {
+ RoutingTable {
+ ipv4: RwLock::new(IpLookupTable::new()),
+ ipv6: RwLock::new(IpLookupTable::new()),
}
- VERSION_IP6 => {
- // check length and cast to IPv6 header
- let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- log::trace!(
- "Router, get route for IPv6 destination: {:?}",
- Ipv6Addr::from(header.f_destination)
- );
-
- // check IPv6 source address
- device
- .ipv6
- .read()
- .longest_match(Ipv6Addr::from(header.f_destination))
- .and_then(|(_, _, p)| Some(p.clone()))
+ }
+
+ fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)>
+ where
+ A: Address,
+ {
+ let mut res = Vec::new();
+ for (ip, cidr, v) in table.iter() {
+ if Arc::ptr_eq(v, value) {
+ res.push((ip, cidr))
+ }
}
- _ => None,
+ res
}
-}
-#[inline(always)]
-pub fn check_route<E: Endpoint, C: Callbacks, T: tun::Writer, B: bind::Writer<E>>(
- device: &Arc<DeviceInner<E, C, T, B>>,
- peer: &Arc<PeerInner<E, C, T, B>>,
- packet: &[u8],
-) -> Option<usize> {
- match packet.get(0)? >> 4 {
- VERSION_IP4 => {
- // check length and cast to IPv4 header
- let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- log::trace!(
- "Router, check route for IPv4 source: {:?}",
- Ipv4Addr::from(header.f_source)
- );
-
- // check IPv4 source address
- device
- .ipv4
- .read()
- .longest_match(Ipv4Addr::from(header.f_source))
- .and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
- Some(header.f_total_len.get() as usize)
- } else {
- None
- }
- })
+ pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> {
+ let mut res = vec![];
+ res.extend(
+ Self::collect(&*self.ipv4.read(), value)
+ .into_iter()
+ .map(|(ip, cidr)| (IpAddr::V4(ip), cidr)),
+ );
+ res.extend(
+ Self::collect(&*self.ipv6.read(), value)
+ .into_iter()
+ .map(|(ip, cidr)| (IpAddr::V6(ip), cidr)),
+ );
+ res
+ }
+
+ pub fn remove(&self, value: &Arc<T>) {
+ let mut v4 = self.ipv4.write();
+ let mut v6 = self.ipv6.write();
+ for (ip, cidr) in Self::collect(&*v4, value) {
+ v4.remove(ip, cidr);
}
- VERSION_IP6 => {
- // check length and cast to IPv6 header
- let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
- LayoutVerified::new_from_prefix(packet)?;
-
- log::trace!(
- "Router, check route for IPv6 source: {:?}",
- Ipv6Addr::from(header.f_source)
- );
-
- // check IPv6 source address
- device
- .ipv6
- .read()
- .longest_match(Ipv6Addr::from(header.f_source))
- .and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
- Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
- } else {
- None
- }
- })
+ for (ip, cidr) in Self::collect(&*v6, value) {
+ v6.remove(ip, cidr);
}
- _ => None,
+ }
+
+ #[inline(always)]
+ pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> {
+ match packet.get(0)? >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ log::trace!(
+ "Router, get route for IPv4 destination: {:?}",
+ Ipv4Addr::from(header.f_destination)
+ );
+
+ // check IPv4 source address
+ self.ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ log::trace!(
+ "Router, get route for IPv6 destination: {:?}",
+ Ipv6Addr::from(header.f_destination)
+ );
+
+ // check IPv6 source address
+ self.ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_destination))
+ .and_then(|(_, _, p)| Some(p.clone()))
+ }
+ _ => None,
+ }
+ }
+
+ #[inline(always)]
+ pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> {
+ match packet.get(0)? >> 4 {
+ VERSION_IP4 => {
+ // check length and cast to IPv4 header
+ let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ log::trace!(
+ "Router, check route for IPv4 source: {:?}",
+ Ipv4Addr::from(header.f_source)
+ );
+
+ // check IPv4 source address
+ self.ipv4
+ .read()
+ .longest_match(Ipv4Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, peer) {
+ Some(header.f_total_len.get() as usize)
+ } else {
+ None
+ }
+ })
+ }
+ VERSION_IP6 => {
+ // check length and cast to IPv6 header
+ let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) =
+ LayoutVerified::new_from_prefix(packet)?;
+
+ log::trace!(
+ "Router, check route for IPv6 source: {:?}",
+ Ipv6Addr::from(header.f_source)
+ );
+
+ // check IPv6 source address
+ self.ipv6
+ .read()
+ .longest_match(Ipv6Addr::from(header.f_source))
+ .and_then(|(_, _, p)| {
+ if Arc::ptr_eq(p, peer) {
+ Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
+ } else {
+ None
+ }
+ })
+ }
+ _ => None,
+ }
+ }
+
+ pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) {
+ match ip {
+ IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value),
+ IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value),
+ };
}
}