aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips_rand_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips_rand_test.go')
-rw-r--r--device/allowedips_rand_test.go120
1 files changed, 65 insertions, 55 deletions
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 59c10f7..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,25 +1,28 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
+ "net"
+ "net/netip"
"sort"
"testing"
)
const (
- NumberOfPeers = 100
- NumberOfAddresses = 250
- NumberOfTests = 10000
+ NumberOfPeers = 100
+ NumberOfPeerRemovals = 4
+ NumberOfAddresses = 250
+ NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
- cidr uint
+ cidr uint8
bits []byte
}
@@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
-func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
@@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
- var peers []*Peer
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n += 1 {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
+ var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
- for n := 0; n < NumberOfPeers; n += 1 {
+ for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ for n := 0; n < NumberOfAddresses; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ var p int
+ for p = 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.Lookup(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.Lookup(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) || p >= NumberOfPeerRemovals {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+ for ; p < len(peers); p++ {
+ allowedIPs.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}