summaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/allowedips.go58
-rw-r--r--device/allowedips_rand_test.go96
2 files changed, 82 insertions, 72 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index d613121..7af9fc7 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -85,30 +85,6 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
-
- if node.peer != p {
- return node
- }
-
- // remove peer & merge
-
- node.removeFromPeerEntries()
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
- }
- return node.child[0]
-}
-
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -261,8 +237,38 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ }
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 48a5bcd..c5f80fe 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -7,6 +7,7 @@ package device
import (
"math/rand"
+ "net"
"sort"
"testing"
)
@@ -64,68 +65,71 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var slow SlowRouter
- var peers []*Peer
- var allowedIPs AllowedIPs
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n++ {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv4(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint8(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- allowedIPs.Insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr4[:], cidr, peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(addr6[:], cidr, peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n++ {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := allowedIPs.LookupIPv6(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ for p := 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.LookupIPv4(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.LookupIPv6(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}