aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips.go')
-rw-r--r--device/allowedips.go89
1 files changed, 56 insertions, 33 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index fa46f97..d15373c 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
}
}
+func (node *trieEntry) remove() {
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ return
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ return
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ return
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+}
+
+func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+ var node *trieEntry
+ var exact bool
+
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
+ } else {
+ panic(errors.New("removing unknown address type"))
+ }
+ if !exact || node == nil || peer != node.peer {
+ return
+ }
+ node.remove()
+}
+
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
@@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
- node := elem.Value.(*trieEntry)
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] != nil && node.child[1] != nil {
- continue
- }
- bit := 0
- if node.child[0] == nil {
- bit = 1
- }
- child := node.child[bit]
- if child != nil {
- child.parent = node.parent
- }
- *node.parent.parentBit = child
- if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
- node.zeroizePointers()
- continue
- }
- parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
- if parent.peer != nil {
- node.zeroizePointers()
- continue
- }
- child = parent.child[node.parent.parentBitType^1]
- if child != nil {
- child.parent = parent.parent
- }
- *parent.parent.parentBit = child
- node.zeroizePointers()
- parent.zeroizePointers()
+ elem.Value.(*trieEntry).remove()
}
}