/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. */ package device import ( "errors" "math/bits" "net" "sync" "unsafe" ) type trieEntry struct { cidr uint child [2]*trieEntry bits net.IP peer *Peer // index of "branching" bit bit_at_byte uint bit_at_shift uint } func isLittleEndian() bool { one := uint32(1) return *(*byte)(unsafe.Pointer(&one)) != 0 } func swapU32(i uint32) uint32 { if !isLittleEndian() { return i } return bits.ReverseBytes32(i) } func swapU64(i uint64) uint64 { if !isLittleEndian() { return i } return bits.ReverseBytes64(i) } func commonBits(ip1 net.IP, ip2 net.IP) uint { size := len(ip1) if size == net.IPv4len { a := (*uint32)(unsafe.Pointer(&ip1[0])) b := (*uint32)(unsafe.Pointer(&ip2[0])) x := *a ^ *b return uint(bits.LeadingZeros32(swapU32(x))) } else if size == net.IPv6len { a := (*uint64)(unsafe.Pointer(&ip1[0])) b := (*uint64)(unsafe.Pointer(&ip2[0])) x := *a ^ *b if x != 0 { return uint(bits.LeadingZeros64(swapU64(x))) } a = (*uint64)(unsafe.Pointer(&ip1[8])) b = (*uint64)(unsafe.Pointer(&ip2[8])) x = *a ^ *b return 64 + uint(bits.LeadingZeros64(swapU64(x))) } else { panic("Wrong size bit string") } } func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { if node == nil { return node } // walk recursively node.child[0] = node.child[0].removeByPeer(p) node.child[1] = node.child[1].removeByPeer(p) if node.peer != p { return node } // remove peer & merge node.peer = nil if node.child[0] == nil { return node.child[1] } return node.child[0] } func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // at leaf if node == nil { return &trieEntry{ bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } } // traverse deeper common := commonBits(node.bits, ip) if node.cidr <= cidr && common >= node.cidr { if node.cidr == cidr { node.peer = peer return node } bit := node.choose(ip) node.child[bit] = node.child[bit].insert(ip, cidr, peer) return node } // split node newNode := &trieEntry{ bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } cidr = min(cidr, common) // check for shorter prefix if newNode.cidr == cidr { bit := newNode.choose(node.bits) newNode.child[bit] = node return newNode } // create new parent for node & newNode parent := &trieEntry{ bits: ip, peer: nil, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } bit := parent.choose(ip) parent.child[bit] = newNode parent.child[bit^1] = node return parent } func (node *trieEntry) lookup(ip net.IP) *Peer { var found *Peer size := uint(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } if node.bit_at_byte == size { break } bit := node.choose(ip) node = node.child[bit] } return found } func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { if node == nil { return results } if node.peer == p { mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) results = append(results, net.IPNet{ Mask: mask, IP: node.bits.Mask(mask), }) } results = node.child[0].entriesForPeer(p, results) results = node.child[1].entriesForPeer(p, results) return results } type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { table.mutex.RLock() defer table.mutex.RUnlock() allowed := make([]net.IPNet, 0, 10) allowed = table.IPv4.entriesForPeer(peer, allowed) allowed = table.IPv6.entriesForPeer(peer, allowed) return allowed } func (table *AllowedIPs) Reset() { table.mutex.Lock() defer table.mutex.Unlock() table.IPv4 = nil table.IPv6 = nil } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() table.IPv4 = table.IPv4.removeByPeer(peer) table.IPv6 = table.IPv6.removeByPeer(peer) } func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() switch len(ip) { case net.IPv6len: table.IPv6 = table.IPv6.insert(ip, cidr, peer) case net.IPv4len: table.IPv4 = table.IPv4.insert(ip, cidr, peer) default: panic(errors.New("inserting unknown address type")) } } func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() return table.IPv4.lookup(address) } func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() return table.IPv6.lookup(address) }