From 4a57024b94edf23a20f1e4289052d0717227683b Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 3 Jun 2021 13:51:03 +0200 Subject: device: reduce size of trie struct Signed-off-by: Jason A. Donenfeld --- device/allowedips.go | 68 ++++++++++++++++++++++-------------------- device/allowedips_rand_test.go | 8 ++--- device/allowedips_test.go | 11 +++---- device/misc.go | 7 ----- device/uapi.go | 4 +-- 5 files changed, 45 insertions(+), 53 deletions(-) (limited to 'device') diff --git a/device/allowedips.go b/device/allowedips.go index b6f096a..1564d2d 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -15,13 +15,13 @@ import ( ) type trieEntry struct { - child [2]*trieEntry - peer *Peer - bits net.IP - cidr uint - bit_at_byte uint - bit_at_shift uint - perPeerElem *list.Element + peer *Peer + child [2]*trieEntry + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits net.IP + perPeerElem *list.Element } func isLittleEndian() bool { @@ -45,24 +45,24 @@ func swapU64(i uint64) uint64 { return bits.ReverseBytes64(i) } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1 net.IP, ip2 net.IP) uint8 { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) b := (*uint32)(unsafe.Pointer(&ip2[0])) x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + return uint8(bits.LeadingZeros32(swapU32(x))) } else if size == net.IPv6len { a := (*uint64)(unsafe.Pointer(&ip1[0])) b := (*uint64)(unsafe.Pointer(&ip2[0])) x := *a ^ *b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(swapU64(x))) } a = (*uint64)(unsafe.Pointer(&ip1[8])) b = (*uint64)(unsafe.Pointer(&ip2[8])) x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + return 64 + uint8(bits.LeadingZeros64(swapU64(x))) } else { panic("Wrong size bit string") } @@ -104,7 +104,7 @@ func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { } func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 } func (node *trieEntry) maskSelf() { @@ -114,17 +114,17 @@ func (node *trieEntry) maskSelf() { } } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { +func (node *trieEntry) insert(ip net.IP, cidr uint8, peer *Peer) *trieEntry { // at leaf if node == nil { node := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: ip, + peer: peer, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } node.maskSelf() node.addToPeerEntries() @@ -149,16 +149,18 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: ip, + peer: peer, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } newNode.maskSelf() newNode.addToPeerEntries() - cidr = min(cidr, common) + if common < cidr { + cidr = common + } // check for shorter prefix @@ -171,11 +173,11 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // create new parent for node & newNode parent := &trieEntry{ - bits: append([]byte{}, ip...), - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + bits: append([]byte{}, ip...), + peer: nil, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } parent.maskSelf() @@ -188,12 +190,12 @@ func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { func (node *trieEntry) lookup(ip net.IP) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -208,7 +210,7 @@ type AllowedIPs struct { mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint) bool) { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) { table.mutex.RLock() defer table.mutex.RUnlock() @@ -228,7 +230,7 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.IPv6 = table.IPv6.removeByPeer(peer) } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index bb3fb43..2da8795 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -19,7 +19,7 @@ const ( type SlowNode struct { peer *Peer - cidr uint + cidr uint8 bits []byte } @@ -37,7 +37,7 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { +func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer @@ -80,7 +80,7 @@ func TestTrieRandomIPv4(t *testing.T) { for n := 0; n < NumberOfAddresses; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers trie = trie.insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) @@ -113,7 +113,7 @@ func TestTrieRandomIPv6(t *testing.T) { for n := 0; n < NumberOfAddresses; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers trie = trie.insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) diff --git a/device/allowedips_test.go b/device/allowedips_test.go index cdd65cf..8dc8438 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -11,13 +11,10 @@ import ( "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint + match uint8 } func TestCommonBits(t *testing.T) { @@ -57,7 +54,7 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber trie = trie.insert(addr[:], cidr, peers[index]) } @@ -99,7 +96,7 @@ func TestTrieIPv4(t *testing.T) { var trie *trieEntry - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { trie = trie.insert([]byte{a, b, c, d}, cidr, peer) } @@ -195,7 +192,7 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) diff --git a/device/misc.go b/device/misc.go index 2c2510f..4126704 100644 --- a/device/misc.go +++ b/device/misc.go @@ -39,10 +39,3 @@ func (a *AtomicBool) Set(val bool) { } atomic.StoreInt32(&a.int32, flag) } - -func min(a, b uint) uint { - if a > b { - return b - } - return a -} diff --git a/device/uapi.go b/device/uapi.go index 659af0a..66ecd48 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -121,7 +121,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)) sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval)) - device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint) bool { + device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool { sendf("allowed_ip=%s/%d", ip.String(), cidr) return true }) @@ -379,7 +379,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error return nil } ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer.Peer) + device.allowedips.Insert(network.IP, uint8(ones), peer.Peer) case "protocol_version": if value != "1" { -- cgit v1.2.3-59-g8ed1b