diff options
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r-- | device/allowedips_test.go | 75 |
1 files changed, 31 insertions, 44 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 075ff06..cde068e 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,40 +8,17 @@ package device import ( "math/rand" "net" + "net/netip" "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) + match uint8 } func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, @@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} rand.Seed(1) const AddressLength = 4 - for n := 0; n < peerNumber; n += 1 { + for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < addressNumber; n += 1 { + for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + root.insert(addr[:], cidr, peers[index]) } - for n := 0; n < b.N; n += 1 { + for n := 0; n < b.N; n++ { var addr [AddressLength]byte rand.Read(addr[:]) trie.lookup(addr[:]) @@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } |