aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--device/allowedips.go119
-rw-r--r--device/allowedips_rand_test.go12
-rw-r--r--device/allowedips_test.go23
3 files changed, 95 insertions, 59 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 1564d2d..d613121 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -14,9 +14,15 @@ import (
"unsafe"
)
+type parentIndirection struct {
+ parentBit **trieEntry
+ parentBitType uint8
+}
+
type trieEntry struct {
peer *Peer
child [2]*trieEntry
+ parent parentIndirection
cidr uint8
bitAtByte uint8
bitAtShift uint8
@@ -114,43 +120,45 @@ func (node *trieEntry) maskSelf() {
}
}
-func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
-
- // at leaf
+func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+ for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
+ parent = node
+ if parent.cidr == cidr {
+ exact = true
+ return
+ }
+ bit := node.choose(ip)
+ node = node.child[bit]
+ }
+ return
+}
- if node == nil {
+func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+ if *trie.parentBit == nil {
node := &trieEntry{
- bits: ip,
peer: peer,
+ parent: trie,
+ bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
node.maskSelf()
node.addToPeerEntries()
- return node
+ *trie.parentBit = node
+ return
}
-
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.removeFromPeerEntries()
- node.peer = peer
- node.addToPeerEntries()
- return node
- }
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].insert(ip, cidr, peer)
- return node
+ node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
+ if exact {
+ node.removeFromPeerEntries()
+ node.peer = peer
+ node.addToPeerEntries()
+ return
}
- // split node
-
newNode := &trieEntry{
- bits: ip,
peer: peer,
+ bits: ip,
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
@@ -158,34 +166,61 @@ func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry {
newNode.maskSelf()
newNode.addToPeerEntries()
+ var down *trieEntry
+ if node == nil {
+ down = *trie.parentBit
+ } else {
+ bit := node.choose(ip)
+ down = node.child[bit]
+ if down == nil {
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ return
+ }
+ }
+ common := commonBits(down.bits, ip)
if common < cidr {
cidr = common
}
-
- // check for shorter prefix
+ parent := node
if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
+ bit := newNode.choose(down.bits)
+ down.parent = parentIndirection{&newNode.child[bit], bit}
+ newNode.child[bit] = down
+ if parent == nil {
+ newNode.parent = trie
+ *trie.parentBit = newNode
+ } else {
+ bit := parent.choose(newNode.bits)
+ newNode.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = newNode
+ }
+ return
}
- // create new parent for node & newNode
-
- parent := &trieEntry{
- bits: append([]byte{}, ip...),
- peer: nil,
+ node = &trieEntry{
+ bits: append([]byte{}, newNode.bits...),
cidr: cidr,
bitAtByte: cidr / 8,
bitAtShift: 7 - (cidr % 8),
}
- parent.maskSelf()
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
+ node.maskSelf()
+
+ bit := node.choose(down.bits)
+ down.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = down
+ bit = node.choose(newNode.bits)
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ if parent == nil {
+ node.parent = trie
+ *trie.parentBit = node
+ } else {
+ bit := parent.choose(node.bits)
+ node.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = node
+ }
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
@@ -236,9 +271,9 @@ func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
switch len(ip) {
case net.IPv6len:
- table.IPv6 = table.IPv6.insert(ip, cidr, peer)
+ parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len:
- table.IPv4 = table.IPv4.insert(ip, cidr, peer)
+ parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default:
panic(errors.New("inserting unknown address type"))
}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 2da8795..48a5bcd 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -65,9 +65,9 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
}
func TestTrieRandomIPv4(t *testing.T) {
- var trie *trieEntry
var slow SlowRouter
var peers []*Peer
+ var allowedIPs AllowedIPs
rand.Seed(1)
@@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) {
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
+ allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
@@ -90,7 +90,7 @@ func TestTrieRandomIPv4(t *testing.T) {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
+ peer2 := allowedIPs.LookupIPv4(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
@@ -98,9 +98,9 @@ func TestTrieRandomIPv4(t *testing.T) {
}
func TestTrieRandomIPv6(t *testing.T) {
- var trie *trieEntry
var slow SlowRouter
var peers []*Peer
+ var allowedIPs AllowedIPs
rand.Seed(1)
@@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) {
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
+ allowedIPs.Insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
@@ -123,7 +123,7 @@ func TestTrieRandomIPv6(t *testing.T) {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
+ peer2 := allowedIPs.LookupIPv6(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 8dc8438..cbd32cc 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -42,6 +42,7 @@ func TestCommonBits(t *testing.T) {
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
+ root := parentIndirection{&trie, 2}
rand.Seed(1)
@@ -56,7 +57,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test
rand.Read(addr[:])
cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
- trie = trie.insert(addr[:], cidr, peers[index])
+ root.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n++ {
@@ -94,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.LookupIPv4([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -150,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -158,12 +159,12 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- trie = nil
+ allowedIPs = AllowedIPs{}
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
@@ -181,7 +182,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
@@ -198,7 +199,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.insert(addr, cidr, peer)
+ allowedIPs.Insert(addr, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -207,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.lookup(addr)
+ p := allowedIPs.LookupIPv6(addr)
if p != peer {
t.Error("Assert EQ failed")
}