aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/allowedips.go17
-rw-r--r--device/allowedips_rand_test.go4
-rw-r--r--device/allowedips_test.go6
-rw-r--r--device/receive.go4
-rw-r--r--device/send.go4
5 files changed, 18 insertions, 17 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 7af9fc7..95615ab 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -285,14 +285,15 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.lookup(address)
+ switch len(address) {
+ case net.IPv6len:
+ return table.IPv6.lookup(address)
+ case net.IPv4len:
+ return table.IPv4.lookup(address)
+ default:
+ panic(errors.New("looking up unknown address type"))
+ }
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index c5f80fe..8d1e633 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -108,7 +108,7 @@ func TestTrieRandom(t *testing.T) {
var addr4 [4]byte
rand.Read(addr4[:])
peer1 := slow4.Lookup(addr4[:])
- peer2 := allowedIPs.LookupIPv4(addr4[:])
+ peer2 := allowedIPs.Lookup(addr4[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
}
@@ -116,7 +116,7 @@ func TestTrieRandom(t *testing.T) {
var addr6 [16]byte
rand.Read(addr6[:])
peer1 = slow6.Lookup(addr6[:])
- peer2 = allowedIPs.LookupIPv6(addr6[:])
+ peer2 = allowedIPs.Lookup(addr6[:])
if peer1 != peer2 {
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index cbd32cc..7701cde 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -102,14 +102,14 @@ func TestTrieIPv4(t *testing.T) {
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -208,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := allowedIPs.LookupIPv6(addr)
+ p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
diff --git a/device/receive.go b/device/receive.go
index 1182246..5857481 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -447,7 +447,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.LookupIPv4(src) != peer {
+ if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip
}
@@ -464,7 +464,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.LookupIPv6(src) != peer {
+ if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip
}
diff --git a/device/send.go b/device/send.go
index a4f07e4..b05c69e 100644
--- a/device/send.go
+++ b/device/send.go
@@ -254,14 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.LookupIPv4(dst)
+ peer = device.allowedips.Lookup(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.LookupIPv6(dst)
+ peer = device.allowedips.Lookup(dst)
default:
device.log.Verbosef("Received packet with unknown IP version")