aboutsummaryrefslogtreecommitdiffstats
path: root/src/wireguard/router/route.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/wireguard/router/route.rs')
-rw-r--r--src/wireguard/router/route.rs27
1 files changed, 13 insertions, 14 deletions
diff --git a/src/wireguard/router/route.rs b/src/wireguard/router/route.rs
index 1c93009..40dc36b 100644
--- a/src/wireguard/router/route.rs
+++ b/src/wireguard/router/route.rs
@@ -4,7 +4,6 @@ use zerocopy::LayoutVerified;
use std::mem;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
-use std::sync::Arc;
use spin::RwLock;
use treebitmap::address::Address;
@@ -12,12 +11,12 @@ use treebitmap::IpLookupTable;
/* Functions for obtaining and validating "cryptokey" routes */
-pub struct RoutingTable<T> {
- ipv4: RwLock<IpLookupTable<Ipv4Addr, Arc<T>>>,
- ipv6: RwLock<IpLookupTable<Ipv6Addr, Arc<T>>>,
+pub struct RoutingTable<T: Eq + Clone> {
+ ipv4: RwLock<IpLookupTable<Ipv4Addr, T>>,
+ ipv6: RwLock<IpLookupTable<Ipv6Addr, T>>,
}
-impl<T> RoutingTable<T> {
+impl<T: Eq + Clone> RoutingTable<T> {
pub fn new() -> Self {
RoutingTable {
ipv4: RwLock::new(IpLookupTable::new()),
@@ -26,27 +25,27 @@ impl<T> RoutingTable<T> {
}
// collect keys mapping to the given value
- fn collect<A>(table: &IpLookupTable<A, Arc<T>>, value: &Arc<T>) -> Vec<(A, u32)>
+ fn collect<A>(table: &IpLookupTable<A, T>, value: &T) -> Vec<(A, u32)>
where
A: Address,
{
let mut res = Vec::new();
for (ip, cidr, v) in table.iter() {
- if Arc::ptr_eq(v, value) {
+ if v == value {
res.push((ip, cidr))
}
}
res
}
- pub fn insert(&self, ip: IpAddr, cidr: u32, value: Arc<T>) {
+ pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) {
match ip {
IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value),
IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value),
};
}
- pub fn list(&self, value: &Arc<T>) -> Vec<(IpAddr, u32)> {
+ pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> {
let mut res = vec![];
res.extend(
Self::collect(&*self.ipv4.read(), value)
@@ -61,7 +60,7 @@ impl<T> RoutingTable<T> {
res
}
- pub fn remove(&self, value: &Arc<T>) {
+ pub fn remove(&self, value: &T) {
let mut v4 = self.ipv4.write();
for (ip, cidr) in Self::collect(&*v4, value) {
v4.remove(ip, cidr);
@@ -74,7 +73,7 @@ impl<T> RoutingTable<T> {
}
#[inline(always)]
- pub fn get_route(&self, packet: &[u8]) -> Option<Arc<T>> {
+ pub fn get_route(&self, packet: &[u8]) -> Option<T> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
@@ -113,7 +112,7 @@ impl<T> RoutingTable<T> {
}
#[inline(always)]
- pub fn check_route(&self, peer: &Arc<T>, packet: &[u8]) -> Option<usize> {
+ pub fn check_route(&self, peer: &T, packet: &[u8]) -> Option<usize> {
match packet.get(0)? >> 4 {
VERSION_IP4 => {
// check length and cast to IPv4 header
@@ -130,7 +129,7 @@ impl<T> RoutingTable<T> {
.read()
.longest_match(Ipv4Addr::from(header.f_source))
.and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
+ if p == peer {
Some(header.f_total_len.get() as usize)
} else {
None
@@ -152,7 +151,7 @@ impl<T> RoutingTable<T> {
.read()
.longest_match(Ipv6Addr::from(header.f_source))
.and_then(|(_, _, p)| {
- if Arc::ptr_eq(p, peer) {
+ if p == peer {
Some(header.f_len.get() as usize + mem::size_of::<IPv6Header>())
} else {
None