From c382222eab9e3814f4df75fd25f8e9e31484b5e0 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 3 Jun 2021 15:40:09 +0200 Subject: device: remove nodes by peer in O(1) instead of O(n) Now that we have parent pointers hooked up, we can simply go right to the node and remove it in place, rather than having to recursively walk the entire trie. Signed-off-by: Jason A. Donenfeld --- device/allowedips.go | 58 +++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 26 deletions(-) (limited to 'device/allowedips.go') diff --git a/device/allowedips.go b/device/allowedips.go index d613121..7af9fc7 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + } } func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { -- cgit v1.2.3-59-g8ed1b