diff options
author | 2025-05-20 23:03:06 +0200 | |
---|---|---|
committer | 2025-05-20 23:03:06 +0200 | |
commit | 256bcbd70d5b4eaae2a9f21a9889498c0f89041c (patch) | |
tree | 99a8989021d6ccaa7700095f3def511554ba4ec3 | |
parent | version: bump snapshot (diff) | |
download | wireguard-go-256bcbd70d5b4eaae2a9f21a9889498c0f89041c.tar.xz wireguard-go-256bcbd70d5b4eaae2a9f21a9889498c0f89041c.zip |
device: add support for removing allowedips individually
This pairs with the recent change in wireguard-tools.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | device/allowedips.go | 87 | ||||
-rw-r--r-- | device/allowedips_test.go | 57 | ||||
-rw-r--r-- | device/uapi.go | 15 |
3 files changed, 125 insertions, 34 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index b40c817..d15373c 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -223,6 +223,60 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) } } +func (node *trieEntry) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() +} + +func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + var node *trieEntry + var exact bool + + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits())) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits())) + } else { + panic(errors.New("removing unknown address type")) + } + if !exact || node == nil || peer != node.peer { + return + } + node.remove() +} + func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() @@ -230,38 +284,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { var next *list.Element for elem := peer.trieEntries.Front(); elem != nil; elem = next { next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue - } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() + elem.Value.(*trieEntry).remove() } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 7df7da5..a4b08a3 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -101,6 +101,10 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d byte) { p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { @@ -176,6 +180,21 @@ func TestTrieIPv4(t *testing.T) { allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 24) + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 32) + assertEQ(a, 192, 0, 0, 1) + remove(nil, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(b, 192, 0, 0, 0, 24) + assertEQ(a, 192, 0, 0, 1) + remove(a, 192, 0, 0, 0, 24) + assertNEQ(a, 192, 0, 0, 1) + remove(a, 1, 0, 0, 0, 32) + assertNEQ(a, 1, 0, 0, 0) } /* Test ported from kernel implementation: @@ -211,6 +230,15 @@ func TestTrieIPv6(t *testing.T) { allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } + remove := func(peer *Peer, a, b, c, d uint32, cidr uint8) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + allowedIPs.Remove(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) + } + assertEQ := func(peer *Peer, a, b, c, d uint32) { var addr []byte addr = append(addr, expand(a)...) @@ -223,6 +251,18 @@ func TestTrieIPv6(t *testing.T) { } } + assertNEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := allowedIPs.Lookup(addr) + if p == peer { + t.Error("Assert NEQ failed") + } + } + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) @@ -244,4 +284,21 @@ func TestTrieIPv6(t *testing.T) { assertEQ(h, 0x24046800, 0x40040800, 0, 0) assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) + + insert(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 96) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(nil, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + assertNEQ(a, 0x24446801, 0x40e40800, 0xdeaebeef, 0xdefbeef) + remove(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) + remove(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + assertNEQ(a, 0x24446800, 0xf0e40800, 0xeeaebeef, 0x10101010) } diff --git a/device/uapi.go b/device/uapi.go index 521a741..cc69488 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -371,7 +371,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error device.allowedips.RemoveByPeer(peer.Peer) case "allowed_ip": - device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + add := true + verb := "Adding" + if len(value) > 0 && value[0] == '-' { + add = false + verb = "Removing" + value = value[1:] + } + device.log.Verbosef("%v - UAPI: %s allowedip", peer.Peer, verb) prefix, err := netip.ParsePrefix(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) @@ -379,7 +386,11 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error if peer.dummy { return nil } - device.allowedips.Insert(prefix, peer.Peer) + if add { + device.allowedips.Insert(prefix, peer.Peer) + } else { + device.allowedips.Remove(prefix, peer.Peer) + } case "protocol_version": if value != "1" { |