aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips.go')
-rw-r--r--device/allowedips.go359
1 files changed, 201 insertions, 158 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index efc27c0..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,173 +1,201 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
+ "container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
-type trieEntry struct {
- cidr uint
- child [2]*trieEntry
- bits net.IP
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
-}
-
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
+type parentIndirection struct {
+ parentBit **trieEntry
+ parentBitType uint8
}
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
+type trieEntry struct {
+ peer *Peer
+ child [2]*trieEntry
+ parent parentIndirection
+ cidr uint8
+ bitAtByte uint8
+ bitAtShift uint8
+ bits []byte
+ perPeerElem *list.Element
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
+func (node *trieEntry) addToPeerEntries() {
+ node.perPeerElem = node.peer.trieEntries.PushBack(node)
+}
- if node.peer != p {
- return node
+func (node *trieEntry) removeFromPeerEntries() {
+ if node.perPeerElem != nil {
+ node.peer.trieEntries.Remove(node.perPeerElem)
+ node.perPeerElem = nil
}
+}
- // remove peer & merge
+func (node *trieEntry) choose(ip []byte) byte {
+ return (ip[node.bitAtByte] >> node.bitAtShift) & 1
+}
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
+func (node *trieEntry) maskSelf() {
+ mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ for i := 0; i < len(mask); i++ {
+ node.bits[i] &= mask[i]
}
- return node.child[0]
}
-func (node *trieEntry) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+func (node *trieEntry) zeroizePointers() {
+ // Make the garbage collector's life slightly easier
+ node.peer = nil
+ node.child[0] = nil
+ node.child[1] = nil
+ node.parent.parentBit = nil
}
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
-
- // at leaf
-
- if node == nil {
- return &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
+ for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
+ parent = node
+ if parent.cidr == cidr {
+ exact = true
+ return
}
+ bit := node.choose(ip)
+ node = node.child[bit]
}
+ return
+}
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.peer = peer
- return node
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
+ if *trie.parentBit == nil {
+ node := &trieEntry{
+ peer: peer,
+ parent: trie,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].insert(ip, cidr, peer)
- return node
+ node.maskSelf()
+ node.addToPeerEntries()
+ *trie.parentBit = node
+ return
+ }
+ node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
+ if exact {
+ node.removeFromPeerEntries()
+ node.peer = peer
+ node.addToPeerEntries()
+ return
}
-
- // split node
newNode := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ peer: peer,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
+ newNode.maskSelf()
+ newNode.addToPeerEntries()
- cidr = min(cidr, common)
-
- // check for shorter prefix
+ var down *trieEntry
+ if node == nil {
+ down = *trie.parentBit
+ } else {
+ bit := node.choose(ip)
+ down = node.child[bit]
+ if down == nil {
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ return
+ }
+ }
+ common := commonBits(down.bits, ip)
+ if common < cidr {
+ cidr = common
+ }
+ parent := node
if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
+ bit := newNode.choose(down.bits)
+ down.parent = parentIndirection{&newNode.child[bit], bit}
+ newNode.child[bit] = down
+ if parent == nil {
+ newNode.parent = trie
+ *trie.parentBit = newNode
+ } else {
+ bit := parent.choose(newNode.bits)
+ newNode.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = newNode
+ }
+ return
}
- // create new parent for node & newNode
-
- parent := &trieEntry{
- bits: ip,
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ node = &trieEntry{
+ bits: append([]byte{}, newNode.bits...),
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
+ }
+ node.maskSelf()
+
+ bit := node.choose(down.bits)
+ down.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = down
+ bit = node.choose(newNode.bits)
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ if parent == nil {
+ node.parent = trie
+ *trie.parentBit = node
+ } else {
+ bit := parent.choose(node.bits)
+ node.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = node
}
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
- size := uint(len(ip))
+ size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
- if node.bit_at_byte == size {
+ if node.bitAtByte == size {
break
}
bit := node.choose(ip)
@@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
return found
}
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
- results = append(results, net.IPNet{
- Mask: mask,
- IP: node.bits.Mask(mask),
- })
- }
- results = node.child[0].entriesForPeer(p, results)
- results = node.child[1].entriesForPeer(p, results)
- return results
-}
-
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.entriesForPeer(peer, allowed)
- allowed = table.IPv6.entriesForPeer(peer, allowed)
- return allowed
-}
-
-func (table *AllowedIPs) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
+ for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
+ node := elem.Value.(*trieEntry)
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
+ return
+ }
+ }
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+ }
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- table.IPv6 = table.IPv6.insert(ip, cidr, peer)
- case net.IPv4len:
- table.IPv4 = table.IPv4.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.lookup(address)
+ switch len(ip) {
+ case net.IPv6len:
+ return table.IPv6.lookup(ip)
+ case net.IPv4len:
+ return table.IPv4.lookup(ip)
+ default:
+ panic(errors.New("looking up unknown address type"))
+ }
}