aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2025-05-20 23:03:06 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2025-05-20 23:03:06 +0200
commit256bcbd70d5b4eaae2a9f21a9889498c0f89041c (patch)
tree99a8989021d6ccaa7700095f3def511554ba4ec3
parentversion: bump snapshot (diff)
downloadwireguard-go-256bcbd70d5b4eaae2a9f21a9889498c0f89041c.tar.xz
wireguard-go-256bcbd70d5b4eaae2a9f21a9889498c0f89041c.zip
device: add support for removing allowedips individually
This pairs with the recent change in wireguard-tools. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--device/allowedips.go87
-rw-r--r--device/allowedips_test.go57
-rw-r--r--device/uapi.go15
3 files changed, 125 insertions, 34 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index b40c817..d15373c 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
}
}
+func (node *trieEntry) remove() {
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ return
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ return
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ return
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+}
+
+func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+ var node *trieEntry
+ var exact bool
+
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
+ } else {
+ panic(errors.New("removing unknown address type"))
+ }
+ if !exact || node == nil || peer != node.peer {
+ return
+ }
+ node.remove()
+}
+
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
@@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
- node := elem.Value.(*trieEntry)
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] != nil && node.child[1] != nil {
- continue
- }
- bit := 0
- if node.child[0] == nil {
- bit = 1
- }
- child := node.child[bit]
- if child != nil {
- child.parent = node.parent
- }
- *node.parent.parentBit = child
- if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
- node.zeroizePointers()
- continue
- }
- parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
- if parent.peer != nil {
- node.zeroizePointers()
- continue
- }
- child = parent.child[node.parent.parentBitType^1]
- if child != nil {
- child.parent = parent.parent
- }
- *parent.parent.parentBit = child
- node.zeroizePointers()
- parent.zeroizePointers()
+ elem.Value.(*trieEntry).remove()
}
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 7df7da5..a4b08a3 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
+ remove := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
+ }
+
assertEQ := func(peer *Peer, a, b, c, d byte) {
p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
@@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) {
allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
+
+ insert(a, 1, 0, 0, 0, 32)
+ insert(a, 192, 0, 0, 0, 24)
+ assertEQ(a, 1, 0, 0, 0)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(a, 192, 0, 0, 0, 32)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(nil, 192, 0, 0, 0, 24)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(b, 192, 0, 0, 0, 24)
+ assertEQ(a, 192, 0, 0, 1)
+ remove(a, 192, 0, 0, 0, 24)
+ assertNEQ(a, 192, 0, 0, 1)
+ remove(a, 1, 0, 0, 0, 32)
+ assertNEQ(a, 1, 0, 0, 0)
}
/* Test ported from kernel implementation:
@@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) {
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
+ remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
+ }
+
assertEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
@@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) {
}
}
+ assertNEQ := func(peer *Peer, a, b, c, d uint32) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ p := allowedIPs.Lookup(addr)
+ if p == peer {
+ t.Error("Assert NEQ failed")
+ }
+ }
+
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
@@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) {
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
+
+ insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+ remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef)
+ remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
+ remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+ assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010)
}
diff --git a/device/uapi.go b/device/uapi.go
index 521a741..cc69488 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -371,7 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
device.allowedips.RemoveByPeer(peer.Peer)
case "allowed_ip":
- device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
+ add := true
+ verb := "Adding"
+ if len(value) > 0 && value[0] == '-' {
+ add = false
+ verb = "Removing"
+ value = value[1:]
+ }
+ device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
@@ -379,7 +386,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if peer.dummy {
return nil
}
- device.allowedips.Insert(prefix, peer.Peer)
+ if add {
+ device.allowedips.Insert(prefix, peer.Peer)
+ } else {
+ device.allowedips.Remove(prefix, peer.Peer)
+ }
case "protocol_version":
if value != "1" {