aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2022-03-16 19:34:42 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2022-03-16 19:45:10 -0600
commitf3aff443a6e829519c4144b8c523d1335a8e66ef (patch)
tree437fe987b6f95965009ab6d876670777246a07be
parenttun/netstack: bump mod (diff)
downloadwireguard-go-f3aff443a6e829519c4144b8c523d1335a8e66ef.tar.xz
wireguard-go-f3aff443a6e829519c4144b8c523d1335a8e66ef.zip
device: make allowedips genericjd/generic-aip
The implementation of commonBits uses a horrific unsafe.Slice trick. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--device/allowedips.go202
-rw-r--r--device/allowedips_rand_test.go13
-rw-r--r--device/allowedips_test.go54
3 files changed, 151 insertions, 118 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 3cac694..c36ef3a 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -16,68 +16,86 @@ import (
"unsafe"
)
-type parentIndirection struct {
- parentBit **trieEntry
+type ipArray interface {
+ [4]byte | [16]byte
+}
+
+type parentIndirection[B ipArray] struct {
+ parentBit **trieEntry[B]
parentBitType uint8
}
-type trieEntry struct {
+type trieEntry[B ipArray] struct {
peer *Peer
- child [2]*trieEntry
- parent parentIndirection
+ child [2]*trieEntry[B]
+ parent parentIndirection[B]
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits []byte
+ bits B
perPeerElem *list.Element
}
-func commonBits(ip1, ip2 []byte) uint8 {
- size := len(ip1)
- if size == net.IPv4len {
- a := binary.BigEndian.Uint32(ip1)
- b := binary.BigEndian.Uint32(ip2)
- x := a ^ b
- return uint8(bits.LeadingZeros32(x))
- } else if size == net.IPv6len {
- a := binary.BigEndian.Uint64(ip1)
- b := binary.BigEndian.Uint64(ip2)
- x := a ^ b
- if x != 0 {
- return uint8(bits.LeadingZeros64(x))
- }
- a = binary.BigEndian.Uint64(ip1[8:])
- b = binary.BigEndian.Uint64(ip2[8:])
- x = a ^ b
- return 64 + uint8(bits.LeadingZeros64(x))
- } else {
- panic("Wrong size bit string")
+func commonBits4(ip1, ip2 [4]byte) uint8 {
+ a := binary.BigEndian.Uint32(ip1[:])
+ b := binary.BigEndian.Uint32(ip2[:])
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
+}
+
+func commonBits16(ip1, ip2 [16]byte) uint8 {
+ a := binary.BigEndian.Uint64(ip1[:8])
+ b := binary.BigEndian.Uint64(ip2[:8])
+ x := a ^ b
+ if x != 0 {
+ return uint8(bits.LeadingZeros64(x))
+ }
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
+}
+
+func giveMeA4[B ipArray](b B) [4]byte {
+ return *(*[4]byte)(unsafe.Slice(&b[0], 4))
+}
+
+func giveMeA16[B ipArray](b B) [16]byte {
+ return *(*[16]byte)(unsafe.Slice(&b[0], 16))
+}
+
+func commonBits[B ipArray](ip1, ip2 B) uint8 {
+ if len(ip1) == 4 {
+ return commonBits4(giveMeA4(ip1), giveMeA4(ip2))
+ } else if len(ip1) == 16 {
+ return commonBits16(giveMeA16(ip1), giveMeA16(ip2))
}
+ panic("Wrong size bit string")
}
-func (node *trieEntry) addToPeerEntries() {
+func (node *trieEntry[B]) addToPeerEntries() {
node.perPeerElem = node.peer.trieEntries.PushBack(node)
}
-func (node *trieEntry) removeFromPeerEntries() {
+func (node *trieEntry[B]) removeFromPeerEntries() {
if node.perPeerElem != nil {
node.peer.trieEntries.Remove(node.perPeerElem)
node.perPeerElem = nil
}
}
-func (node *trieEntry) choose(ip []byte) byte {
+func (node *trieEntry[B]) choose(ip B) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
-func (node *trieEntry) maskSelf() {
+func (node *trieEntry[B]) maskSelf() {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
for i := 0; i < len(mask); i++ {
node.bits[i] &= mask[i]
}
}
-func (node *trieEntry) zeroizePointers() {
+func (node *trieEntry[B]) zeroizePointers() {
// Make the garbage collector's life slightly easier
node.peer = nil
node.child[0] = nil
@@ -85,7 +103,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry[B]) nodePlacement(ip B, cidr uint8) (parent *trieEntry[B], exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -98,9 +116,9 @@ func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
+func (trie parentIndirection[B]) insert(ip B, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
- node := &trieEntry{
+ node := &trieEntry[B]{
peer: peer,
parent: trie,
bits: ip,
@@ -121,7 +139,7 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
return
}
- newNode := &trieEntry{
+ newNode := &trieEntry[B]{
peer: peer,
bits: ip,
cidr: cidr,
@@ -131,14 +149,14 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
newNode.maskSelf()
newNode.addToPeerEntries()
- var down *trieEntry
+ var down *trieEntry[B]
if node == nil {
down = *trie.parentBit
} else {
bit := node.choose(ip)
down = node.child[bit]
if down == nil {
- newNode.parent = parentIndirection{&node.child[bit], bit}
+ newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode
return
}
@@ -151,21 +169,21 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if newNode.cidr == cidr {
bit := newNode.choose(down.bits)
- down.parent = parentIndirection{&newNode.child[bit], bit}
+ down.parent = parentIndirection[B]{&newNode.child[bit], bit}
newNode.child[bit] = down
if parent == nil {
newNode.parent = trie
*trie.parentBit = newNode
} else {
bit := parent.choose(newNode.bits)
- newNode.parent = parentIndirection{&parent.child[bit], bit}
+ newNode.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = newNode
}
return
}
- node = &trieEntry{
- bits: append([]byte{}, newNode.bits...),
+ node = &trieEntry[B]{
+ bits: newNode.bits,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
@@ -173,22 +191,22 @@ func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
node.maskSelf()
bit := node.choose(down.bits)
- down.parent = parentIndirection{&node.child[bit], bit}
+ down.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = down
bit = node.choose(newNode.bits)
- newNode.parent = parentIndirection{&node.child[bit], bit}
+ newNode.parent = parentIndirection[B]{&node.child[bit], bit}
node.child[bit] = newNode
if parent == nil {
node.parent = trie
*trie.parentBit = node
} else {
bit := parent.choose(node.bits)
- node.parent = parentIndirection{&parent.child[bit], bit}
+ node.parent = parentIndirection[B]{&parent.child[bit], bit}
parent.child[bit] = node
}
}
-func (node *trieEntry) lookup(ip []byte) *Peer {
+func (node *trieEntry[B]) lookup(ip B) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -205,8 +223,8 @@ func (node *trieEntry) lookup(ip []byte) *Peer {
}
type AllowedIPs struct {
- IPv4 *trieEntry
- IPv6 *trieEntry
+ IPv4 *trieEntry[[4]byte]
+ IPv6 *trieEntry[[16]byte]
mutex sync.RWMutex
}
@@ -215,14 +233,51 @@ func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix)
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
- node := elem.Value.(*trieEntry)
- a, _ := netip.AddrFromSlice(node.bits)
- if !cb(netip.PrefixFrom(a, int(node.cidr))) {
- return
+ if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
+ if !cb(netip.PrefixFrom(netip.AddrFrom4(node.bits), int(node.cidr))) {
+ return
+ }
+ } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
+ if !cb(netip.PrefixFrom(netip.AddrFrom16(node.bits), int(node.cidr))) {
+ return
+ }
}
}
}
+func (node *trieEntry[B]) remove() {
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ return
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ return
+ }
+ parent := (*trieEntry[B])(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ return
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+}
+
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
@@ -230,38 +285,11 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
var next *list.Element
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
next = elem.Next()
- node := elem.Value.(*trieEntry)
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] != nil && node.child[1] != nil {
- continue
- }
- bit := 0
- if node.child[0] == nil {
- bit = 1
- }
- child := node.child[bit]
- if child != nil {
- child.parent = node.parent
+ if node, ok := elem.Value.(*trieEntry[[4]byte]); ok {
+ node.remove()
+ } else if node, ok := elem.Value.(*trieEntry[[16]byte]); ok {
+ node.remove()
}
- *node.parent.parentBit = child
- if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
- node.zeroizePointers()
- continue
- }
- parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
- if parent.peer != nil {
- node.zeroizePointers()
- continue
- }
- child = parent.child[node.parent.parentBitType^1]
- if child != nil {
- child.parent = parent.parent
- }
- *parent.parent.parentBit = child
- node.zeroizePointers()
- parent.zeroizePointers()
}
}
@@ -270,11 +298,9 @@ func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
defer table.mutex.Unlock()
if prefix.Addr().Is6() {
- ip := prefix.Addr().As16()
- parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ parentIndirection[[16]byte]{&table.IPv6, 2}.insert(prefix.Addr().As16(), uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
- ip := prefix.Addr().As4()
- parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ parentIndirection[[4]byte]{&table.IPv4, 2}.insert(prefix.Addr().As4(), uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
@@ -285,9 +311,9 @@ func (table *AllowedIPs) Lookup(ip []byte) *Peer {
defer table.mutex.RUnlock()
switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(ip)
+ return table.IPv6.lookup(*(*[16]byte)(ip))
case net.IPv4len:
- return table.IPv4.lookup(ip)
+ return table.IPv4.lookup(*(*[4]byte)(ip))
default:
panic(errors.New("looking up unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 0d3eecb..8c17d02 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -40,9 +40,18 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
+func commonBitsSlice(addr1, addr2 []byte) uint8 {
+ if len(addr1) == 4 {
+ return commonBits4(*(*[4]byte)(addr1), *(*[4]byte)(addr2))
+ } else if len(addr1) == 16 {
+ return commonBits16(*(*[16]byte)(addr1), *(*[16]byte)(addr2))
+ }
+ return 0
+}
+
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
- if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
+ if t.cidr == cidr && commonBitsSlice(t.bits, addr) >= cidr {
t.peer = peer
t.bits = addr
return r
@@ -59,7 +68,7 @@ func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
- common := commonBits(t.bits, addr)
+ common := commonBitsSlice(t.bits, addr)
if common >= t.cidr {
return t.peer
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 225c788..a0d286f 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -7,28 +7,28 @@ package device
import (
"math/rand"
- "net"
"net/netip"
"testing"
+ "unsafe"
)
-type testPairCommonBits struct {
- s1 []byte
- s2 []byte
+type testPairCommonBits4 struct {
+ s1 [4]byte
+ s2 [4]byte
match uint8
}
-func TestCommonBits(t *testing.T) {
- tests := []testPairCommonBits{
- {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
- {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
- {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
- {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
- {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+func TestCommonBits4(t *testing.T) {
+ tests := []testPairCommonBits4{
+ {s1: [4]byte{1, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 7},
+ {s1: [4]byte{0, 4, 53, 128}, s2: [4]byte{0, 0, 0, 0}, match: 13},
+ {s1: [4]byte{0, 4, 53, 253}, s2: [4]byte{0, 4, 53, 252}, match: 31},
+ {s1: [4]byte{192, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 15},
+ {s1: [4]byte{65, 168, 1, 1}, s2: [4]byte{192, 169, 1, 1}, match: 0},
}
for _, p := range tests {
- v := commonBits(p.s1, p.s2)
+ v := commonBits4(p.s1, p.s2)
if v != p.match {
t.Error(
"For slice", p.s1, p.s2,
@@ -39,48 +39,46 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
- var trie *trieEntry
+func benchmarkTrie[B ipArray](peerNumber, addressNumber int, b *testing.B) {
+ var trie *trieEntry[B]
var peers []*Peer
- root := parentIndirection{&trie, 2}
+ root := parentIndirection[B]{&trie, 2}
rand.Seed(1)
- const AddressLength = 4
-
for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < addressNumber; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
+ var addr B
+ rand.Read(unsafe.Slice(&addr[0], len(addr)))
+ cidr := uint8(rand.Uint32() % uint32(len(addr)*8))
index := rand.Int() % peerNumber
- root.insert(addr[:], cidr, peers[index])
+ root.insert(addr, cidr, peers[index])
}
for n := 0; n < b.N; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- trie.lookup(addr[:])
+ var addr B
+ rand.Read(unsafe.Slice(&addr[0], len(addr)))
+ trie.lookup(addr)
}
}
func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
- benchmarkTrie(100, 1000, net.IPv4len, b)
+ benchmarkTrie[[4]byte](100, 1000, b)
}
func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
- benchmarkTrie(10, 10, net.IPv4len, b)
+ benchmarkTrie[[4]byte](10, 10, b)
}
func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
- benchmarkTrie(100, 1000, net.IPv6len, b)
+ benchmarkTrie[[16]byte](100, 1000, b)
}
func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
- benchmarkTrie(10, 10, net.IPv6len, b)
+ benchmarkTrie[[16]byte](10, 10, b)
}
/* Test ported from kernel implementation: