diff options
Diffstat (limited to 'src/wireguard/router/route.rs')
-rw-r--r-- | src/wireguard/router/route.rs | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs index 1c93009..40dc36b 100644 --- a/src/wireguard/router/route.rs +++ b/src/wireguard/router/route.rs @@ -4,7 +4,6 @@ use zerocopy::LayoutVerified; use std::mem; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; -use std::sync::Arc; use spin::RwLock; use treebitmap::address::Address; @@ -12,12 +11,12 @@ use treebitmap::IpLookupTable; /* Functions for obtaining and validating "cryptokey" routes */ -pub struct RoutingTable<T> { - ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>, - ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>, +pub struct RoutingTable<T: Eq + Clone> { + ipv4: RwLock<IpLookupTable<Ipv4Addr, T>>, + ipv6: RwLock<IpLookupTable<Ipv6Addr, T>>, } -impl<T> RoutingTable<T> { +impl<T: Eq + Clone> RoutingTable<T> { pub fn new() -> Self { RoutingTable { ipv4: RwLock::new(IpLookupTable::new()), @@ -26,27 +25,27 @@ impl<T> RoutingTable<T> { } // collect keys mapping to the given value - fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)> + fn collect<A>(table: &IpLookupTable<A, T>, value: &T) -> Vec<(A, u32)> where A: Address, { let mut res = Vec::new(); for (ip, cidr, v) in table.iter() { - if Arc::ptr_eq(v, value) { + if v == value { res.push((ip, cidr)) } } res } - pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) { + pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) { match ip { IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), }; } - pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> { + pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> { let mut res = vec![]; res.extend( Self::collect(&*self.ipv4.read(), value) @@ -61,7 +60,7 @@ impl<T> RoutingTable<T> { res } - pub fn remove(&self, value: &Arc<T>) { + pub fn remove(&self, value: &T) { let mut v4 = self.ipv4.write(); for (ip, cidr) in Self::collect(&*v4, value) { v4.remove(ip, cidr); @@ -74,7 +73,7 @@ impl<T> RoutingTable<T> { } #[inline(always)] - pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> { + pub fn get_route(&self, packet: &[u8]) -> Option<T> { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -113,7 +112,7 @@ impl<T> RoutingTable<T> { } #[inline(always)] - pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> { + pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option<usize> { match packet.get(0)? >> 4 { VERSION_IP4 => { // check length and cast to IPv4 header @@ -130,7 +129,7 @@ impl<T> RoutingTable<T> { .read() .longest_match(Ipv4Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_total_len.get() as usize) } else { None @@ -152,7 +151,7 @@ impl<T> RoutingTable<T> { .read() .longest_match(Ipv6Addr::from(header.f_source)) .and_then(|(_, _, p)| { - if Arc::ptr_eq(p, peer) { + if p == peer { Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>()) } else { None |