aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips_rand_test.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 15:40:09 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 16:29:43 +0200
commitc382222eab9e3814f4df75fd25f8e9e31484b5e0 (patch)
tree910b69829baae426668c82c83314dcdd9b208437 /device/allowedips_rand_test.go
parentdevice: remove recursion from insertion and connect parent pointers (diff)
downloadwireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.tar.xz
wireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.zip
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to '')
-rw-r--r--device/allowedips_rand_test.go96
1 files changed, 50 insertions, 46 deletions
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 48a5bcd..c5f80fe 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -7,6 +7,7 @@ package device
import (
"math/rand"
+ "net"
"sort"
"testing"
)
@@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var slow SlowRouter
- var peers []*Peer
- var allowedIPs AllowedIPs
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n++ {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv4(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr4[:], cidr, peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr6[:], cidr, peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv6(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ for p := 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.LookupIPv4(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.LookupIPv6(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}