diff options
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 89 |
1 files changed, 56 insertions, 33 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index fa46f97..d15373c 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device @@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) } } +func (node *trieEntry) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() +} + +func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + var node *trieEntry + var exact bool + + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits())) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits())) + } else { + panic(errors.New("removing unknown address type")) + } + if !exact || node == nil || peer != node.peer { + return + } + node.remove() +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() @@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue - } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() + elem.Value.(*trieEntry).remove() } } |