aboutsummaryrefslogtreecommitdiffstats
path: root/allowedips.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-05-13 19:33:41 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-05-13 19:34:28 +0200
commit2326d6a4d75f9f3736046cc526eb593a403d4c7a (patch)
treeda31622899bb84e256b22addd1604b8250582c44 /allowedips.go
parentCleanup ratelimiter (diff)
downloadwireguard-go-2326d6a4d75f9f3736046cc526eb593a403d4c7a.tar.xz
wireguard-go-2326d6a4d75f9f3736046cc526eb593a403d4c7a.zip
Odds and ends
Diffstat (limited to '')
-rw-r--r--allowedips.go (renamed from trie.go)110
1 files changed, 75 insertions, 35 deletions
diff --git a/trie.go b/allowedips.go
index 03f0722..df53abf 100644
--- a/trie.go
+++ b/allowedips.go
@@ -8,21 +8,12 @@ package main
import (
"errors"
"net"
+ "sync"
)
-/* Binary trie
- *
- * The net.IPs used here are not formatted the
- * same way as those created by the "net" functions.
- * Here the IPs are slices of either 4 or 16 byte (not always 16)
- *
- * Synchronization done separately
- * See: routing.go
- */
-
-type Trie struct {
+type trieEntry struct {
cidr uint
- child [2]*Trie
+ child [2]*trieEntry
bits []byte
peer *Peer
@@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint {
return i * 8
}
-func (node *Trie) RemovePeer(p *Peer) *Trie {
+func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
- node.child[0] = node.child[0].RemovePeer(p)
- node.child[1] = node.child[1].RemovePeer(p)
+ node.child[0] = node.child[0].removeByPeer(p)
+ node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
@@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node.child[0]
}
-func (node *Trie) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
-func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
+func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// at leaf
if node == nil {
- return &Trie{
+ return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
@@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return node
}
bit := node.choose(ip)
- node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
+ node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
- newNode := &Trie{
+ newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
@@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
// create new parent for node & newNode
- parent := &Trie{
+ parent := &trieEntry{
bits: ip,
peer: nil,
cidr: cidr,
@@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
return parent
}
-func (node *Trie) Lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer {
return found
}
-func (node *Trie) Count() uint {
- if node == nil {
- return 0
- }
- l := node.child[0].Count()
- r := node.child[1].Count()
- return l + r
-}
-
-func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
+func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
@@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
} else if len(node.bits) == net.IPv6len {
mask.IP = node.bits
} else {
- panic(errors.New("bug: unexpected address length"))
+ panic(errors.New("unexpected address length"))
}
results = append(results, mask)
}
- results = node.child[0].AllowedIPs(p, results)
- results = node.child[1].AllowedIPs(p, results)
+ results = node.child[0].entriesForPeer(p, results)
+ results = node.child[1].entriesForPeer(p, results)
return results
}
+
+type AllowedIPs struct {
+ IPv4 *trieEntry
+ IPv6 *trieEntry
+ mutex sync.RWMutex
+}
+
+func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+
+ allowed := make([]net.IPNet, 0, 10)
+ allowed = table.IPv4.entriesForPeer(peer, allowed)
+ allowed = table.IPv6.entriesForPeer(peer, allowed)
+ return allowed
+}
+
+func (table *AllowedIPs) Reset() {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = nil
+ table.IPv6 = nil
+}
+
+func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = table.IPv4.removeByPeer(peer)
+ table.IPv6 = table.IPv6.removeByPeer(peer)
+}
+
+func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ switch len(ip) {
+ case net.IPv6len:
+ table.IPv6 = table.IPv6.insert(ip, cidr, peer)
+ case net.IPv4len:
+ table.IPv4 = table.IPv4.insert(ip, cidr, peer)
+ default:
+ panic(errors.New("inserting unknown address type"))
+ }
+}
+
+func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv4.lookup(address)
+}
+
+func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv6.lookup(address)
+}