summaryrefslogtreecommitdiffstats
path: root/device/allowedips.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips.go')
-rw-r--r--device/allowedips.go58
1 files changed, 32 insertions, 26 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index d613121..7af9fc7 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ }
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {