summaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 15:40:09 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 16:29:43 +0200
commitc382222eab9e3814f4df75fd25f8e9e31484b5e0 (patch)
tree910b69829baae426668c82c83314dcdd9b208437 /device
parentdevice: remove recursion from insertion and connect parent pointers (diff)
downloadwireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.tar.xz
wireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.zip
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device')
-rw-r--r--device/allowedips.go58
-rw-r--r--device/allowedips_rand_test.go96
2 files changed, 82 insertions, 72 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index d613121..7af9fc7 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ }
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 48a5bcd..c5f80fe 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -7,6 +7,7 @@ package device
import (
"math/rand"
+ "net"
"sort"
"testing"
)
@@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var slow SlowRouter
- var peers []*Peer
- var allowedIPs AllowedIPs
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n++ {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv4(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr4[:], cidr, peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr6[:], cidr, peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv6(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ for p := 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.LookupIPv4(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.LookupIPv6(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}