diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 15:40:09 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-06-03 16:29:43 +0200 |
commit | c382222eab9e3814f4df75fd25f8e9e31484b5e0 (patch) | |
tree | 910b69829baae426668c82c83314dcdd9b208437 /device/allowedips.go | |
parent | device: remove recursion from insertion and connect parent pointers (diff) | |
download | wireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.tar.xz wireguard-go-c382222eab9e3814f4df75fd25f8e9e31484b5e0.zip |
device: remove nodes by peer in O(1) instead of O(n)
Now that we have parent pointers hooked up, we can simply go right to
the node and remove it in place, rather than having to recursively walk
the entire trie.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 58 |
1 files changed, 32 insertions, 26 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index d613121..7af9fc7 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() { } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } @@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + } } func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { |