aboutsummaryrefslogtreecommitdiffstats
path: root/trie.go
diff options
context:
space:
mode:
Diffstat (limited to 'trie.go')
-rw-r--r--trie.go228
1 files changed, 228 insertions, 0 deletions
diff --git a/trie.go b/trie.go
new file mode 100644
index 0000000..405ffc3
--- /dev/null
+++ b/trie.go
@@ -0,0 +1,228 @@
+package main
+
+import (
+ "errors"
+ "net"
+)
+
+/* Binary trie
+ *
+ * The net.IPs used here are not formatted the
+ * same way as those created by the "net" functions.
+ * Here the IPs are slices of either 4 or 16 byte (not always 16)
+ *
+ * Synchronization done separately
+ * See: routing.go
+ */
+
+type Trie struct {
+ cidr uint
+ child [2]*Trie
+ bits []byte
+ peer *Peer
+
+ // index of "branching" bit
+
+ bit_at_byte uint
+ bit_at_shift uint
+}
+
+/* Finds length of matching prefix
+ *
+ * TODO: Only use during insertion (xor + prefix mask for lookup)
+ * Check out
+ * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
+ * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
+ *
+ * Assumption:
+ * len(ip1) == len(ip2)
+ * len(ip1) mod 4 = 0
+ */
+func commonBits(ip1 []byte, ip2 []byte) uint {
+ var i uint
+ size := uint(len(ip1))
+
+ for i = 0; i < size; i++ {
+ v := ip1[i] ^ ip2[i]
+ if v != 0 {
+ v >>= 1
+ if v == 0 {
+ return i*8 + 7
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 6
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 5
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 4
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 3
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 2
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 1
+ }
+ return i * 8
+ }
+ }
+ return i * 8
+}
+
+func (node *Trie) RemovePeer(p *Peer) *Trie {
+ if node == nil {
+ return node
+ }
+
+ // walk recursively
+
+ node.child[0] = node.child[0].RemovePeer(p)
+ node.child[1] = node.child[1].RemovePeer(p)
+
+ if node.peer != p {
+ return node
+ }
+
+ // remove peer & merge
+
+ node.peer = nil
+ if node.child[0] == nil {
+ return node.child[1]
+ }
+ return node.child[0]
+}
+
+func (node *Trie) choose(ip net.IP) byte {
+ return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+}
+
+func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie {
+
+ // at leaf
+
+ if node == nil {
+ return &Trie{
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+ }
+
+ // traverse deeper
+
+ common := commonBits(node.bits, ip)
+ if node.cidr <= cidr && common >= node.cidr {
+ if node.cidr == cidr {
+ node.peer = peer
+ return node
+ }
+ bit := node.choose(ip)
+ node.child[bit] = node.child[bit].Insert(ip, cidr, peer)
+ return node
+ }
+
+ // split node
+
+ newNode := &Trie{
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+
+ cidr = min(cidr, common)
+
+ // check for shorter prefix
+
+ if newNode.cidr == cidr {
+ bit := newNode.choose(node.bits)
+ newNode.child[bit] = node
+ return newNode
+ }
+
+ // create new parent for node & newNode
+
+ parent := &Trie{
+ bits: ip,
+ peer: nil,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+
+ bit := parent.choose(ip)
+ parent.child[bit] = newNode
+ parent.child[bit^1] = node
+
+ return parent
+}
+
+func (node *Trie) Lookup(ip net.IP) *Peer {
+ var found *Peer
+ size := uint(len(ip))
+ for node != nil && commonBits(node.bits, ip) >= node.cidr {
+ if node.peer != nil {
+ found = node.peer
+ }
+ if node.bit_at_byte == size {
+ break
+ }
+ bit := node.choose(ip)
+ node = node.child[bit]
+ }
+ return found
+}
+
+func (node *Trie) Count() uint {
+ if node == nil {
+ return 0
+ }
+ l := node.child[0].Count()
+ r := node.child[1].Count()
+ return l + r
+}
+
+func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet {
+ if node == nil {
+ return results
+ }
+ if node.peer == p {
+ var mask net.IPNet
+ mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ if len(node.bits) == net.IPv4len {
+ mask.IP = net.IPv4(
+ node.bits[0],
+ node.bits[1],
+ node.bits[2],
+ node.bits[3],
+ )
+ } else if len(node.bits) == net.IPv6len {
+ mask.IP = node.bits
+ } else {
+ panic(errors.New("bug: unexpected address length"))
+ }
+ results = append(results, mask)
+ }
+ results = node.child[0].AllowedIPs(p, results)
+ results = node.child[1].AllowedIPs(p, results)
+ return results
+}