diff options
Diffstat (limited to 'device/allowedips.go')
-rw-r--r-- | device/allowedips.go | 374 |
1 files changed, 220 insertions, 154 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index efc27c0..d15373c 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,173 +1,201 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( + "container/list" + "encoding/binary" "errors" "math/bits" "net" + "net/netip" "sync" "unsafe" ) -type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 } -func isLittleEndian() bool { - one := uint32(1) - return *(*byte)(unsafe.Pointer(&one)) != 0 -} - -func swapU32(i uint32) uint32 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes32(i) -} - -func swapU64(i uint64) uint64 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes64(i) +type trieEntry struct { + peer *Peer + child [2]*trieEntry + parent parentIndirection + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits []byte + perPeerElem *list.Element } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { - a := (*uint32)(unsafe.Pointer(&ip1[0])) - b := (*uint32)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + a := binary.BigEndian.Uint32(ip1) + b := binary.BigEndian.Uint32(ip2) + x := a ^ b + return uint8(bits.LeadingZeros32(x)) } else if size == net.IPv6len { - a := (*uint64)(unsafe.Pointer(&ip1[0])) - b := (*uint64)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b + a := binary.BigEndian.Uint64(ip1) + b := binary.BigEndian.Uint64(ip2) + x := a ^ b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(x)) } - a = (*uint64)(unsafe.Pointer(&ip1[8])) - b = (*uint64)(unsafe.Pointer(&ip2[8])) - x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + a = binary.BigEndian.Uint64(ip1[8:]) + b = binary.BigEndian.Uint64(ip2[8:]) + x = a ^ b + return 64 + uint8(bits.LeadingZeros64(x)) } else { panic("Wrong size bit string") } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) +func (node *trieEntry) addToPeerEntries() { + node.perPeerElem = node.peer.trieEntries.PushBack(node) +} - if node.peer != p { - return node +func (node *trieEntry) removeFromPeerEntries() { + if node.perPeerElem != nil { + node.peer.trieEntries.Remove(node.perPeerElem) + node.perPeerElem = nil } +} - // remove peer & merge +func (node *trieEntry) choose(ip []byte) byte { + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 +} - node.peer = nil - if node.child[0] == nil { - return node.child[1] +func (node *trieEntry) maskSelf() { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + for i := 0; i < len(mask); i++ { + node.bits[i] &= mask[i] } - return node.child[0] } -func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +func (node *trieEntry) zeroizePointers() { + // Make the garbage collector's life slightly easier + node.peer = nil + node.child[0] = nil + node.child[1] = nil + node.parent.parentBit = nil } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { - - // at leaf - - if node == nil { - return &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return } + bit := node.choose(ip) + node = node.child[bit] } + return +} - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { + node := &trieEntry{ + peer: peer, + parent: trie, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node.maskSelf() + node.addToPeerEntries() + *trie.parentBit = node + return + } + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - - // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + peer: peer, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } + newNode.maskSelf() + newNode.addToPeerEntries() - cidr = min(cidr, common) - - // check for shorter prefix + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) + if common < cidr { + cidr = common + } + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), + } + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -176,76 +204,114 @@ func (node *trieEntry) lookup(ip net.IP) *Peer { return found } -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed + for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { + node := elem.Value.(*trieEntry) + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { + return + } + } +} + +func (node *trieEntry) remove() { + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + return + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + return + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + return + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() } -func (table *AllowedIPs) Reset() { +func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil + var node *trieEntry + var exact bool + + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits())) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits())) + } else { + panic(errors.New("removing unknown address type")) + } + if !exact || node == nil || peer != node.peer { + return + } + node.remove() } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + elem.Value.(*trieEntry).remove() + } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - return table.IPv6.lookup(address) + switch len(ip) { + case net.IPv6len: + return table.IPv6.lookup(ip) + case net.IPv4len: + return table.IPv4.lookup(ip) + default: + panic(errors.New("looking up unknown address type")) + } } |