diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-13 19:33:41 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-13 19:34:28 +0200 |
commit | 2326d6a4d75f9f3736046cc526eb593a403d4c7a (patch) | |
tree | da31622899bb84e256b22addd1604b8250582c44 /allowedips.go | |
parent | Cleanup ratelimiter (diff) | |
download | wireguard-go-2326d6a4d75f9f3736046cc526eb593a403d4c7a.tar.xz wireguard-go-2326d6a4d75f9f3736046cc526eb593a403d4c7a.zip |
Odds and ends
Diffstat (limited to '')
-rw-r--r-- | allowedips.go (renamed from trie.go) | 110 |
1 files changed, 75 insertions, 35 deletions
@@ -8,21 +8,12 @@ package main import ( "errors" "net" + "sync" ) -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { +type trieEntry struct { cidr uint - child [2]*Trie + child [2]*trieEntry bits []byte peer *Peer @@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint { return i * 8 } -func (node *Trie) RemovePeer(p *Peer) *Trie { +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { if node == nil { return node } // walk recursively - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) if node.peer != p { return node @@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node.child[0] } -func (node *Trie) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // at leaf if node == nil { - return &Trie{ + return &trieEntry{ bits: ip, peer: peer, cidr: cidr, @@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return node } bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) return node } // split node - newNode := &Trie{ + newNode := &trieEntry{ bits: ip, peer: peer, cidr: cidr, @@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { // create new parent for node & newNode - parent := &Trie{ + parent := &trieEntry{ bits: ip, peer: nil, cidr: cidr, @@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return parent } -func (node *Trie) Lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip net.IP) *Peer { var found *Peer size := uint(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer { return found } -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { if node == nil { return results } @@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { } else if len(node.bits) == net.IPv6len { mask.IP = node.bits } else { - panic(errors.New("bug: unexpected address length")) + panic(errors.New("unexpected address length")) } results = append(results, mask) } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) return results } + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} |