aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 15:40:09 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2021-06-03 16:29:43 +0200
commitc382222eab9e3814f4df75fd25f8e9e31484b5e0 (patch)
tree910b69829baae426668c82c83314dcdd9b208437 /device/allowedips.go
parentdevice: remove recursion from insertion and connect parent pointers (diff)
downloadwireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.tar.xz
wireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.zip
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device/allowedips.go')
-rw-r--r--device/allowedips.go58
1 files changed, 32 insertions, 26 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index d613121..7af9fc7 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ }
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {