diff options
Diffstat (limited to 'device/allowedips_rand_test.go')
-rw-r--r-- | device/allowedips_rand_test.go | 120 |
1 files changed, 65 insertions, 55 deletions
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 59c10f7..07065c3 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,25 +1,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" + "net" + "net/netip" "sort" "testing" ) const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 + NumberOfPeers = 100 + NumberOfPeerRemovals = 4 + NumberOfAddresses = 250 + NumberOfTests = 10000 ) type SlowNode struct { peer *Peer - cidr uint + cidr uint8 bits []byte } @@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { +func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer @@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ } } + return r[:n] } -func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry - var slow SlowRouter +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { + for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + for n := 0; n < NumberOfAddresses; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + var p int + for p = 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.Lookup(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.Lookup(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } + } + if p >= len(peers) || p >= NumberOfPeerRemovals { + break } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + for ; p < len(peers); p++ { + allowedIPs.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } } |