aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r--device/allowedips_test.go140
1 files changed, 92 insertions, 48 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 075ff06..a4b08a3 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,40 +8,17 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
-/* Todo: More comprehensive
- */
-
type testPairCommonBits struct {
s1 []byte
s2 []byte
- match uint
-}
-
-type testPairTrieInsert struct {
- key []byte
- cidr uint
- peer *Peer
-}
-
-type testPairTrieLookup struct {
- key []byte
- peer *Peer
-}
-
-func printTrie(t *testing.T, p *trieEntry) {
- if p == nil {
- return
- }
- t.Log(p)
- printTrie(t, p.child[0])
- printTrie(t, p.child[1])
+ match uint8
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -62,29 +39,30 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
+ root := parentIndirection{&trie, 2}
- rand.Seed(1)
+ rng := rand.New(rand.NewSource(1))
const AddressLength = 4
- for n := 0; n < peerNumber; n += 1 {
+ for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < addressNumber; n += 1 {
+ for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % peerNumber
- trie = trie.insert(addr[:], cidr, peers[index])
+ rng.Read(addr[:])
+ cidr := uint8(rng.Uint32() % (AddressLength * 8))
+ index := rng.Int() % peerNumber
+ root.insert(addr[:], cidr, peers[index])
}
- for n := 0; n < b.N; n += 1 {
+ for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
- rand.Read(addr[:])
+ rng.Read(addr[:])
trie.lookup(addr[:])
}
}
@@ -117,21 +95,25 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
+
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
+ }
- insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
- trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -173,7 +155,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -181,14 +163,38 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- trie = nil
+ allowedIPs.RemoveByPeer(a)
+ allowedIPs.RemoveByPeer(b)
+ allowedIPs.RemoveByPeer(c)
+ allowedIPs.RemoveByPeer(d)
+ allowedIPs.RemoveByPeer(e)
+ allowedIPs.RemoveByPeer(g)
+ allowedIPs.RemoveByPeer(h)
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Expected removing all the peers to empty trie, but it did not")
+ }
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
+
+ insert(a, 1, 0, 0, 0, 32)
+ insert(a, 192, 0, 0, 0, 24)
+ assertEQ(a, 1, 0, 0, 0)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(a, 192, 0, 0, 0, 32)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(nil, 192, 0, 0, 0, 24)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(b, 192, 0, 0, 0, 24)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(a, 192, 0, 0, 0, 24)
+ assertNEQ(a, 192, 0, 0, 1)
+ remove(a, 1, 0, 0, 0, 32)
+ assertNEQ(a, 1, 0, 0, 0)
}
/* Test ported from kernel implementation:
@@ -204,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
@@ -215,13 +221,22 @@ func TestTrieIPv6(t *testing.T) {
return out[:]
}
- insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
+ }
+
+ remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -230,12 +245,24 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.lookup(addr)
+ p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
}
+ assertNEQ := func(peer *Peer, a, b, c, d uint32) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ p := allowedIPs.Lookup(addr)
+ if p == peer {
+ t.Error("Assert NEQ failed")
+ }
+ }
+
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
@@ -257,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
+
+ insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+ remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+ remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
}