aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r--device/allowedips_test.go75
1 files changed, 31 insertions, 44 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 075ff06..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,40 +8,17 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
-/* Todo: More comprehensive
- */
-
type testPairCommonBits struct {
s1 []byte
s2 []byte
- match uint
-}
-
-type testPairTrieInsert struct {
- key []byte
- cidr uint
- peer *Peer
-}
-
-type testPairTrieLookup struct {
- key []byte
- peer *Peer
-}
-
-func printTrie(t *testing.T, p *trieEntry) {
- if p == nil {
- return
- }
- t.Log(p)
- printTrie(t, p.child[0])
- printTrie(t, p.child[1])
+ match uint8
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
+ root := parentIndirection{&trie, 2}
rand.Seed(1)
const AddressLength = 4
- for n := 0; n < peerNumber; n += 1 {
+ for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < addressNumber; n += 1 {
+ for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
- trie = trie.insert(addr[:], cidr, peers[index])
+ root.insert(addr[:], cidr, peers[index])
}
- for n := 0; n < b.N; n += 1 {
+ for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.lookup(addr[:])
@@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
- insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
- trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- trie = nil
+ allowedIPs.RemoveByPeer(a)
+ allowedIPs.RemoveByPeer(b)
+ allowedIPs.RemoveByPeer(c)
+ allowedIPs.RemoveByPeer(d)
+ allowedIPs.RemoveByPeer(e)
+ allowedIPs.RemoveByPeer(g)
+ allowedIPs.RemoveByPeer(h)
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Expected removing all the peers to empty trie, but it did not")
+ }
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
@@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
@@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:]
}
- insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.lookup(addr)
+ p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}