diff options
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r-- | device/allowedips_test.go | 140 |
1 files changed, 92 insertions, 48 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 075ff06..a4b08a3 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -8,40 +8,17 @@ package device import ( "math/rand" "net" + "net/netip" "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) + match uint8 } func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, @@ -62,29 +39,30 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) const AddressLength = 4 - for n := 0; n < peerNumber; n += 1 { + for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < addressNumber; n += 1 { + for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + rng.Read(addr[:]) + cidr := uint8(rng.Uint32() % (AddressLength * 8)) + index := rng.Int() % peerNumber + root.insert(addr[:], cidr, peers[index]) } - for n := 0; n < b.N; n += 1 { + for n := 0; n < b.N; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) + rng.Read(addr[:]) trie.lookup(addr[:]) } } @@ -117,21 +95,25 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs + + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) + } - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + remove := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +155,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -181,14 +163,38 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 24) + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 32) + assertEQ(a, 192, 0, 0, 1) + remove(nil, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(b, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 24) + assertNEQ(a, 192, 0, 0, 1) + remove(a, 1, 0, 0, 0, 32) + assertNEQ(a, 1, 0, 0, 0) } /* Test ported from kernel implementation: @@ -204,7 +210,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -215,13 +221,22 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) + } + + remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,12 +245,24 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } } + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := allowedIPs.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) @@ -257,4 +284,21 @@ func TestTrieIPv6(t *testing.T) { assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) + + insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) } |