diff options
Diffstat (limited to 'device/allowedips_test.go')
-rw-r--r-- | device/allowedips_test.go | 71 |
1 files changed, 64 insertions, 7 deletions
diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 225c788..a4b08a3 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -39,12 +39,12 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, _ int, b *testing.B) { var trie *trieEntry var peers []*Peer root := parentIndirection{&trie, 2} - rand.Seed(1) + rng := rand.New(rand.NewSource(1)) const AddressLength = 4 @@ -54,15 +54,15 @@ func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber + rng.Read(addr[:]) + cidr := uint8(rng.Uint32() % (AddressLength * 8)) + index := rng.Int() % peerNumber root.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n++ { var addr [AddressLength]byte - rand.Read(addr[:]) + rng.Read(addr[:]) trie.lookup(addr[:]) } } @@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { @@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 24) + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 32) + assertEQ(a, 192, 0, 0, 1) + remove(nil, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(b, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 24) + assertNEQ(a, 192, 0, 0, 1) + remove(a, 1, 0, 0, 0, 32) + assertNEQ(a, 1, 0, 0, 0) } /* Test ported from kernel implementation: @@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d uint32) { var addr []byte addr = append(addr, expand(a)...) @@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) { } } + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := allowedIPs.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) @@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) { assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) + + insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) } |