package main import ( "errors" "net" ) /* Binary trie * * The net.IPs used here are not formatted the * same way as those created by the "net" functions. * Here the IPs are slices of either 4 or 16 byte (not always 16) * * Synchronization done separately * See: routing.go */ type Trie struct { cidr uint child [2]*Trie bits []byte peer *Peer // index of "branching" bit bit_at_byte uint bit_at_shift uint } /* Finds length of matching prefix * * TODO: Only use during insertion (xor + prefix mask for lookup) * Check out * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match * * Assumption: * len(ip1) == len(ip2) * len(ip1) mod 4 = 0 */ func commonBits(ip1 []byte, ip2 []byte) uint { var i uint size := uint(len(ip1)) for i = 0; i < size; i++ { v := ip1[i] ^ ip2[i] if v != 0 { v >>= 1 if v == 0 { return i*8 + 7 } v >>= 1 if v == 0 { return i*8 + 6 } v >>= 1 if v == 0 { return i*8 + 5 } v >>= 1 if v == 0 { return i*8 + 4 } v >>= 1 if v == 0 { return i*8 + 3 } v >>= 1 if v == 0 { return i*8 + 2 } v >>= 1 if v == 0 { return i*8 + 1 } return i * 8 } } return i * 8 } func (node *Trie) RemovePeer(p *Peer) *Trie { if node == nil { return node } // walk recursively node.child[0] = node.child[0].RemovePeer(p) node.child[1] = node.child[1].RemovePeer(p) if node.peer != p { return node } // remove peer & merge node.peer = nil if node.child[0] == nil { return node.child[1] } return node.child[0] } func (node *Trie) choose(ip net.IP) byte { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { // at leaf if node == nil { return &Trie{ bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } } // traverse deeper common := commonBits(node.bits, ip) if node.cidr <= cidr && common >= node.cidr { if node.cidr == cidr { node.peer = peer return node } bit := node.choose(ip) node.child[bit] = node.child[bit].Insert(ip, cidr, peer) return node } // split node newNode := &Trie{ bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } cidr = min(cidr, common) // check for shorter prefix if newNode.cidr == cidr { bit := newNode.choose(node.bits) newNode.child[bit] = node return newNode } // create new parent for node & newNode parent := &Trie{ bits: ip, peer: nil, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } bit := parent.choose(ip) parent.child[bit] = newNode parent.child[bit^1] = node return parent } func (node *Trie) Lookup(ip net.IP) *Peer { var found *Peer size := uint(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } if node.bit_at_byte == size { break } bit := node.choose(ip) node = node.child[bit] } return found } func (node *Trie) Count() uint { if node == nil { return 0 } l := node.child[0].Count() r := node.child[1].Count() return l + r } func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { if node == nil { return results } if node.peer == p { var mask net.IPNet mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) if len(node.bits) == net.IPv4len { mask.IP = net.IPv4( node.bits[0], node.bits[1], node.bits[2], node.bits[3], ) } else if len(node.bits) == net.IPv6len { mask.IP = node.bits } else { panic(errors.New("bug: unexpected address length")) } results = append(results, mask) } results = node.child[0].AllowedIPs(p, results) results = node.child[1].AllowedIPs(p, results) return results }