aboutsummaryrefslogtreecommitdiffstats
path: root/device/allowedips.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/allowedips.go')
-rw-r--r--device/allowedips.go89
1 files changed, 36 insertions, 53 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,15 +1,17 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
@@ -26,49 +28,28 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
-}
-
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
-}
-
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint8(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint8(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
@@ -85,7 +66,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +85,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +98,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +188,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +210,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +265,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}