From 2326d6a4d75f9f3736046cc526eb593a403d4c7a Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 13 May 2018 19:33:41 +0200 Subject: Odds and ends --- Makefile | 3 - allowedips.go | 273 ++++++++++++++++++++++++++++++++++++++++++++++++ allowedips_rand_test.go | 131 +++++++++++++++++++++++ allowedips_test.go | 260 +++++++++++++++++++++++++++++++++++++++++++++ device.go | 4 +- keypair.go | 2 +- logger.go | 2 +- noise-helpers.go | 3 +- noise-types.go | 2 +- peer.go | 24 ++--- receive.go | 26 +++-- routing.go | 70 ------------- trie.go | 233 ----------------------------------------- trie_rand_test.go | 131 ----------------------- trie_test.go | 260 --------------------------------------------- tun_darwin.go | 4 +- tun_linux.go | 1 + tun_windows.go | 4 +- uapi.go | 6 +- 19 files changed, 707 insertions(+), 732 deletions(-) create mode 100644 allowedips.go create mode 100644 allowedips_rand_test.go create mode 100644 allowedips_test.go delete mode 100644 routing.go delete mode 100644 trie.go delete mode 100644 trie_rand_test.go delete mode 100644 trie_test.go diff --git a/Makefile b/Makefile index 5b23ecc..77eaac9 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go) clean: rm -f wireguard-go -cloc: - cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go)) - .PHONY: clean cloc diff --git a/allowedips.go b/allowedips.go new file mode 100644 index 0000000..df53abf --- /dev/null +++ b/allowedips.go @@ -0,0 +1,273 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. + */ + +package main + +import ( + "errors" + "net" + "sync" +) + +type trieEntry struct { + cidr uint + child [2]*trieEntry + bits []byte + peer *Peer + + // index of "branching" bit + + bit_at_byte uint + bit_at_shift uint +} + +/* Finds length of matching prefix + * + * TODO: Only use during insertion (xor + prefix mask for lookup) + * Check out + * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) + * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match + * + * Assumption: + * len(ip1) == len(ip2) + * len(ip1) mod 4 = 0 + */ +func commonBits(ip1 []byte, ip2 []byte) uint { + var i uint + size := uint(len(ip1)) + + for i = 0; i < size; i++ { + v := ip1[i] ^ ip2[i] + if v != 0 { + v >>= 1 + if v == 0 { + return i*8 + 7 + } + + v >>= 1 + if v == 0 { + return i*8 + 6 + } + + v >>= 1 + if v == 0 { + return i*8 + 5 + } + + v >>= 1 + if v == 0 { + return i*8 + 4 + } + + v >>= 1 + if v == 0 { + return i*8 + 3 + } + + v >>= 1 + if v == 0 { + return i*8 + 2 + } + + v >>= 1 + if v == 0 { + return i*8 + 1 + } + return i * 8 + } + } + return i * 8 +} + +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { + if node == nil { + return node + } + + // walk recursively + + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) + + if node.peer != p { + return node + } + + // remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *trieEntry) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { + + // at leaf + + if node == nil { + return &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // traverse deeper + + common := commonBits(node.bits, ip) + if node.cidr <= cidr && common >= node.cidr { + if node.cidr == cidr { + node.peer = peer + return node + } + bit := node.choose(ip) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) + return node + } + + // split node + + newNode := &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + + // check for shorter prefix + + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode + } + + // create new parent for node & newNode + + parent := &trieEntry{ + bits: ip, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(ip) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *trieEntry) lookup(ip net.IP) *Peer { + var found *Peer + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(ip) + node = node.child[bit] + } + return found +} + +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } + if node.peer == p { + var mask net.IPNet + mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) + if len(node.bits) == net.IPv4len { + mask.IP = net.IPv4( + node.bits[0], + node.bits[1], + node.bits[2], + node.bits[3], + ) + } else if len(node.bits) == net.IPv6len { + mask.IP = node.bits + } else { + panic(errors.New("unexpected address length")) + } + results = append(results, mask) + } + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) + return results +} + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} diff --git a/allowedips_rand_test.go b/allowedips_rand_test.go new file mode 100644 index 0000000..6ec039d --- /dev/null +++ b/allowedips_rand_test.go @@ -0,0 +1,131 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. + */ + +package main + +import ( + "math/rand" + "sort" + "testing" +) + +const ( + NumberOfPeers = 100 + NumberOfAddresses = 250 + NumberOfTests = 10000 +) + +type SlowNode struct { + peer *Peer + cidr uint + bits []byte +} + +type SlowRouter []*SlowNode + +func (r SlowRouter) Len() int { + return len(r) +} + +func (r SlowRouter) Less(i, j int) bool { + return r[i].cidr > r[j].cidr +} + +func (r SlowRouter) Swap(i, j int) { + r[i], r[j] = r[j], r[i] +} + +func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { + for _, t := range r { + if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { + t.peer = peer + t.bits = addr + return r + } + } + r = append(r, &SlowNode{ + cidr: cidr, + bits: addr, + peer: peer, + }) + sort.Sort(r) + return r +} + +func (r SlowRouter) Lookup(addr []byte) *Peer { + for _, t := range r { + common := commonBits(t.bits, addr) + if common >= t.cidr { + return t.peer + } + } + return nil +} + +func TestTrieRandomIPv4(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("trieEntry did not match naive implementation, for:", addr) + } + } +} + +func TestTrieRandomIPv6(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 16 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("trieEntry did not match naive implementation, for:", addr) + } + } +} diff --git a/allowedips_test.go b/allowedips_test.go new file mode 100644 index 0000000..7b73af3 --- /dev/null +++ b/allowedips_test.go @@ -0,0 +1,260 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. + */ + +package main + +import ( + "math/rand" + "net" + "testing" +) + +/* Todo: More comprehensive + */ + +type testPairCommonBits struct { + s1 []byte + s2 []byte + match uint +} + +type testPairTrieInsert struct { + key []byte + cidr uint + peer *Peer +} + +type testPairTrieLookup struct { + key []byte + peer *Peer +} + +func printTrie(t *testing.T, p *trieEntry) { + if p == nil { + return + } + t.Log(p) + printTrie(t, p.child[0]) + printTrie(t, p.child[1]) +} + +func TestCommonBits(t *testing.T) { + + tests := []testPairCommonBits{ + {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, + {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, + {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, + {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, + {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, + } + + for _, p := range tests { + v := commonBits(p.s1, p.s2) + if v != p.match { + t.Error( + "For slice", p.s1, p.s2, + "expected match", p.match, + ",but got", v, + ) + } + } +} + +func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { + var trie *trieEntry + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < peerNumber; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < addressNumber; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % peerNumber + trie = trie.insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < b.N; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + trie.lookup(addr[:]) + } +} + +func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv4len, b) +} + +func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv4len, b) +} + +func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv6len, b) +} + +func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv6len, b) +} + +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv4(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p != peer { + t.Error("Assert EQ failed") + } + } + + assertNEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p == peer { + t.Error("Assert NEQ failed") + } + } + + insert(a, 192, 168, 4, 0, 24) + insert(b, 192, 168, 4, 4, 32) + insert(c, 192, 168, 0, 0, 16) + insert(d, 192, 95, 5, 64, 27) + insert(c, 192, 95, 5, 65, 27) + insert(e, 0, 0, 0, 0, 0) + insert(g, 64, 15, 112, 0, 20) + insert(h, 64, 15, 123, 211, 25) + insert(a, 10, 0, 0, 0, 25) + insert(b, 10, 0, 0, 128, 25) + insert(a, 10, 1, 0, 0, 30) + insert(b, 10, 1, 0, 4, 30) + insert(c, 10, 1, 0, 8, 29) + insert(d, 10, 1, 0, 16, 29) + + assertEQ(a, 192, 168, 4, 20) + assertEQ(a, 192, 168, 4, 0) + assertEQ(b, 192, 168, 4, 4) + assertEQ(c, 192, 168, 200, 182) + assertEQ(c, 192, 95, 5, 68) + assertEQ(e, 192, 95, 5, 96) + assertEQ(g, 64, 15, 116, 26) + assertEQ(g, 64, 15, 127, 3) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 64, 0, 0, 0, 32) + insert(a, 128, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 32) + insert(a, 255, 0, 0, 0, 32) + + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 64, 0, 0, 0) + assertEQ(a, 128, 0, 0, 0) + assertEQ(a, 192, 0, 0, 0) + assertEQ(a, 255, 0, 0, 0) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 1, 0, 0, 0) + assertNEQ(a, 64, 0, 0, 0) + assertNEQ(a, 128, 0, 0, 0) + assertNEQ(a, 192, 0, 0, 0) + assertNEQ(a, 255, 0, 0, 0) + + trie = nil + + insert(a, 192, 168, 0, 0, 16) + insert(a, 192, 168, 0, 0, 24) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 192, 168, 0, 1) +} + +/* Test ported from kernel implementation: + * selftest/routingtable.h + */ +func TestTrieIPv6(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + f := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + expand := func(a uint32) []byte { + var out [4]byte + out[0] = byte(a >> 24 & 0xff) + out[1] = byte(a >> 16 & 0xff) + out[2] = byte(a >> 8 & 0xff) + out[3] = byte(a & 0xff) + return out[:] + } + + insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + trie = trie.insert(addr, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.lookup(addr) + if p != peer { + t.Error("Assert EQ failed") + } + } + + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) + insert(c, 0x26075300, 0x60006b00, 0, 0, 64) + insert(e, 0, 0, 0, 0, 0) + insert(f, 0, 0, 0, 0, 0) + insert(g, 0x24046800, 0, 0, 0, 32) + insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) + insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) + insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + + assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) + assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) + assertEQ(f, 0x26075300, 0x60006b01, 0, 0) + assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) + assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0, 0) + assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) + assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) +} diff --git a/device.go b/device.go index 99e451e..34af419 100644 --- a/device.go +++ b/device.go @@ -46,7 +46,7 @@ type Device struct { routing struct { mutex sync.RWMutex - table RoutingTable + table AllowedIPs } peers struct { @@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets - device.routing.table.RemovePeer(peer) + device.routing.table.RemoveByPeer(peer) peer.Stop() // remove from peer map diff --git a/keypair.go b/keypair.go index 6f6f7c0..ea72a11 100644 --- a/keypair.go +++ b/keypair.go @@ -33,7 +33,7 @@ type Keypairs struct { mutex sync.RWMutex current *Keypair previous *Keypair - next *Keypair // not yet "confirmed by transport" + next *Keypair } func (kp *Keypairs) Current() *Keypair { diff --git a/logger.go b/logger.go index 784235c..b8012aa 100644 --- a/logger.go +++ b/logger.go @@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger { logger.Debug = log.New(logDebug, "DEBUG: "+prepend, - log.Ldate|log.Ltime|log.Lshortfile, + log.Ldate|log.Ltime, ) logger.Info = log.New(logInfo, diff --git a/noise-helpers.go b/noise-helpers.go index 6e23d83..63e45b3 100644 --- a/noise-helpers.go +++ b/noise-helpers.go @@ -71,14 +71,13 @@ func isZero(val []byte) bool { return acc == 1 } +/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ func setZero(arr []byte) { for i := range arr { arr[i] = 0 } } -/* curve25519 wrappers */ - func newPrivateKey() (sk NoisePrivateKey, err error) { // clamping: https://cr.yp.to/ecdh.html _, err = rand.Read(sk[:]) diff --git a/noise-types.go b/noise-types.go index 58aa0c2..2635e01 100644 --- a/noise-types.go +++ b/noise-types.go @@ -30,7 +30,7 @@ func loadExactHex(dst []byte, src string) error { return err } if len(slice) != len(dst) { - return errors.New("Hex string does not fit the slice") + return errors.New("hex string does not fit the slice") } copy(dst, slice) return nil diff --git a/peer.go b/peer.go index f49f806..d574c71 100644 --- a/peer.go +++ b/peer.go @@ -61,7 +61,7 @@ type Peer struct { mutex sync.Mutex // held when stopping / starting routines starting sync.WaitGroup // routines pending start stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go-routines in peer + stop chan struct{} // size 0, stop all go routines in peer } mac CookieGenerator @@ -70,7 +70,7 @@ type Peer struct { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { if device.isClosed.Get() { - return nil, errors.New("Device closed") + return nil, errors.New("device closed") } // lock resources @@ -87,7 +87,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // check if over limit if len(device.peers.keyMap) >= MaxPeers { - return nil, errors.New("Too many peers") + return nil, errors.New("too many peers") } // create peer @@ -104,7 +104,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { _, ok := device.peers.keyMap[pk] if ok { - return nil, errors.New("Adding existing peer") + return nil, errors.New("adding existing peer") } device.peers.keyMap[pk] = peer @@ -134,26 +134,26 @@ func (peer *Peer) SendBuffer(buffer []byte) error { defer peer.device.net.mutex.RUnlock() if peer.device.net.bind == nil { - return errors.New("No bind") + return errors.New("no bind") } peer.mutex.RLock() defer peer.mutex.RUnlock() if peer.endpoint == nil { - return errors.New("No known endpoint for peer") + return errors.New("no known endpoint for peer") } return peer.device.net.bind.Send(buffer, peer.endpoint) } -/* Returns a short string identifier for logging - */ func (peer *Peer) String() string { - return fmt.Sprintf( - "peer(%s)", - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) + base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + abbreviatedKey := "invalid" + if len(base64Key) == 44 { + abbreviatedKey = base64Key[0:4] + "..." + base64Key[40:44] + } + return fmt.Sprintf("peer(%s)", abbreviatedKey) } func (peer *Peer) Start() { diff --git a/receive.go b/receive.go index 60a2510..32ff512 100644 --- a/receive.go +++ b/receive.go @@ -600,20 +600,24 @@ func (peer *Peer) RoutineSequentialReceiver() { // check if using new key-pair kp := &peer.keypairs - kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true if kp.next == elem.keypair { - old := kp.previous - kp.previous = kp.current - device.DeleteKeypair(old) - kp.current = kp.next - kp.next = nil - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: + kp.mutex.Lock() + if kp.next != elem.keypair { + kp.mutex.Unlock() + } else { + old := kp.previous + kp.previous = kp.current + device.DeleteKeypair(old) + kp.current = kp.next + kp.next = nil + kp.mutex.Unlock() + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } } } - kp.mutex.Unlock() peer.keepKeyFreshReceiving() peer.timersAnyAuthenticatedPacketTraversal() diff --git a/routing.go b/routing.go deleted file mode 100644 index 77c9b1e..0000000 --- a/routing.go +++ /dev/null @@ -1,70 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "errors" - "net" - "sync" -) - -type RoutingTable struct { - IPv4 *Trie - IPv6 *Trie - mutex sync.RWMutex -} - -func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { - table.mutex.RLock() - defer table.mutex.RUnlock() - - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.AllowedIPs(peer, allowed) - allowed = table.IPv6.AllowedIPs(peer, allowed) - return allowed -} - -func (table *RoutingTable) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil -} - -func (table *RoutingTable) RemovePeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = table.IPv4.RemovePeer(peer) - table.IPv6 = table.IPv6.RemovePeer(peer) -} - -func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.Insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.Insert(ip, cidr, peer) - default: - panic(errors.New("Inserting unknown address type")) - } -} - -func (table *RoutingTable) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.Lookup(address) -} - -func (table *RoutingTable) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.Lookup(address) -} diff --git a/trie.go b/trie.go deleted file mode 100644 index 03f0722..0000000 --- a/trie.go +++ /dev/null @@ -1,233 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "errors" - "net" -) - -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { - cidr uint - child [2]*Trie - bits []byte - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -/* Finds length of matching prefix - * - * TODO: Only use during insertion (xor + prefix mask for lookup) - * Check out - * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) - * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match - * - * Assumption: - * len(ip1) == len(ip2) - * len(ip1) mod 4 = 0 - */ -func commonBits(ip1 []byte, ip2 []byte) uint { - var i uint - size := uint(len(ip1)) - - for i = 0; i < size; i++ { - v := ip1[i] ^ ip2[i] - if v != 0 { - v >>= 1 - if v == 0 { - return i*8 + 7 - } - - v >>= 1 - if v == 0 { - return i*8 + 6 - } - - v >>= 1 - if v == 0 { - return i*8 + 5 - } - - v >>= 1 - if v == 0 { - return i*8 + 4 - } - - v >>= 1 - if v == 0 { - return i*8 + 3 - } - - v >>= 1 - if v == 0 { - return i*8 + 2 - } - - v >>= 1 - if v == 0 { - return i*8 + 1 - } - return i * 8 - } - } - return i * 8 -} - -func (node *Trie) RemovePeer(p *Peer) *Trie { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - -func (node *Trie) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 -} - -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { - - // at leaf - - if node == nil { - return &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) - return node - } - - // split node - - newNode := &Trie{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - cidr = min(cidr, common) - - // check for shorter prefix - - if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode - } - - // create new parent for node & newNode - - parent := &Trie{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent -} - -func (node *Trie) Lookup(ip net.IP) *Peer { - var found *Peer - size := uint(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bit_at_byte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - var mask net.IPNet - mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) - if len(node.bits) == net.IPv4len { - mask.IP = net.IPv4( - node.bits[0], - node.bits[1], - node.bits[2], - node.bits[3], - ) - } else if len(node.bits) == net.IPv6len { - mask.IP = node.bits - } else { - panic(errors.New("bug: unexpected address length")) - } - results = append(results, mask) - } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) - return results -} diff --git a/trie_rand_test.go b/trie_rand_test.go deleted file mode 100644 index 157c270..0000000 --- a/trie_rand_test.go +++ /dev/null @@ -1,131 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "math/rand" - "sort" - "testing" -) - -const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 -) - -type SlowNode struct { - peer *Peer - cidr uint - bits []byte -} - -type SlowRouter []*SlowNode - -func (r SlowRouter) Len() int { - return len(r) -} - -func (r SlowRouter) Less(i, j int) bool { - return r[i].cidr > r[j].cidr -} - -func (r SlowRouter) Swap(i, j int) { - r[i], r[j] = r[j], r[i] -} - -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { - for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { - t.peer = peer - t.bits = addr - return r - } - } - r = append(r, &SlowNode{ - cidr: cidr, - bits: addr, - peer: peer, - }) - sort.Sort(r) - return r -} - -func (r SlowRouter) Lookup(addr []byte) *Peer { - for _, t := range r { - common := commonBits(t.bits, addr) - if common >= t.cidr { - return t.peer - } - } - return nil -} - -func TestTrieRandomIPv4(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} - -func TestTrieRandomIPv6(t *testing.T) { - var trie *Trie - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} diff --git a/trie_test.go b/trie_test.go deleted file mode 100644 index 3c3b5ba..0000000 --- a/trie_test.go +++ /dev/null @@ -1,260 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld . All Rights Reserved. - */ - -package main - -import ( - "math/rand" - "net" - "testing" -) - -/* Todo: More comprehensive - */ - -type testPairCommonBits struct { - s1 []byte - s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *Trie) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) -} - -func TestCommonBits(t *testing.T) { - - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, - } - - for _, p := range tests { - v := commonBits(p.s1, p.s2) - if v != p.match { - t.Error( - "For slice", p.s1, p.s2, - "expected match", p.match, - ",but got", v, - ) - } - } -} - -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { - var trie *Trie - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < peerNumber; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < addressNumber; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - trie = trie.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < b.N; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.Lookup(addr[:]) - } -} - -func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) -} - -func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) -} - -func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) -} - -func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv4(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p != peer { - t.Error("Assert EQ failed") - } - } - - assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) - if p == peer { - t.Error("Assert NEQ failed") - } - } - - insert(a, 192, 168, 4, 0, 24) - insert(b, 192, 168, 4, 4, 32) - insert(c, 192, 168, 0, 0, 16) - insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) - insert(e, 0, 0, 0, 0, 0) - insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) - insert(a, 10, 0, 0, 0, 25) - insert(b, 10, 0, 0, 128, 25) - insert(a, 10, 1, 0, 0, 30) - insert(b, 10, 1, 0, 4, 30) - insert(c, 10, 1, 0, 8, 29) - insert(d, 10, 1, 0, 16, 29) - - assertEQ(a, 192, 168, 4, 20) - assertEQ(a, 192, 168, 4, 0) - assertEQ(b, 192, 168, 4, 4) - assertEQ(c, 192, 168, 200, 182) - assertEQ(c, 192, 95, 5, 68) - assertEQ(e, 192, 95, 5, 96) - assertEQ(g, 64, 15, 116, 26) - assertEQ(g, 64, 15, 127, 3) - - insert(a, 1, 0, 0, 0, 32) - insert(a, 64, 0, 0, 0, 32) - insert(a, 128, 0, 0, 0, 32) - insert(a, 192, 0, 0, 0, 32) - insert(a, 255, 0, 0, 0, 32) - - assertEQ(a, 1, 0, 0, 0) - assertEQ(a, 64, 0, 0, 0) - assertEQ(a, 128, 0, 0, 0) - assertEQ(a, 192, 0, 0, 0) - assertEQ(a, 255, 0, 0, 0) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 1, 0, 0, 0) - assertNEQ(a, 64, 0, 0, 0) - assertNEQ(a, 128, 0, 0, 0) - assertNEQ(a, 192, 0, 0, 0) - assertNEQ(a, 255, 0, 0, 0) - - trie = nil - - insert(a, 192, 168, 0, 0, 16) - insert(a, 192, 168, 0, 0, 24) - - trie = trie.RemovePeer(a) - - assertNEQ(a, 192, 168, 0, 1) -} - -/* Test ported from kernel implementation: - * selftest/routingtable.h - */ -func TestTrieIPv6(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - f := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *Trie - - expand := func(a uint32) []byte { - var out [4]byte - out[0] = byte(a >> 24 & 0xff) - out[1] = byte(a >> 16 & 0xff) - out[2] = byte(a >> 8 & 0xff) - out[3] = byte(a & 0xff) - return out[:] - } - - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - trie = trie.Insert(addr, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := trie.Lookup(addr) - if p != peer { - t.Error("Assert EQ failed") - } - } - - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) - insert(c, 0x26075300, 0x60006b00, 0, 0, 64) - insert(e, 0, 0, 0, 0, 0) - insert(f, 0, 0, 0, 0, 0) - insert(g, 0x24046800, 0, 0, 0, 32) - insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) - insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) - insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) - insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) - - assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) - assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) - assertEQ(f, 0x26075300, 0x60006b01, 0, 0) - assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) - assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0, 0) - assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) - assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) -} diff --git a/tun_darwin.go b/tun_darwin.go index 1d66c66..fa8efe0 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -224,7 +224,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { } func (tun *NativeTun) Close() error { - return tun.fd.Close() + err := tun.fd.Close() + close(tun.events) + return err } func (tun *NativeTun) setMTU(n int) error { diff --git a/tun_linux.go b/tun_linux.go index 18994cc..9f60d2b 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -392,6 +392,7 @@ func (tun *NativeTun) Close() error { return err } tun.closingWriter.Write([]byte{0}) + close(tun.events) return nil } diff --git a/tun_windows.go b/tun_windows.go index c0c9ff8..6eea5a3 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -125,7 +125,9 @@ func (f *NativeTUN) Events() chan TUNEvent { } func (f *NativeTUN) Close() error { - return windows.Close(f.fd) + close(f.events) + err := windows.Close(f.fd) + return err } func (f *NativeTUN) Write(b []byte) (int, error) { diff --git a/uapi.go b/uapi.go index 4b2038b..90c400a 100644 --- a/uapi.go +++ b/uapi.go @@ -91,7 +91,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - for _, ip := range device.routing.table.AllowedIPs(peer) { + for _, ip := range device.routing.table.EntriesForPeer(peer) { send("allowed_ip=" + ip.String()) } @@ -337,7 +337,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "replace_allowed_ips": - logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer) + logDebug.Println("UAPI: Removing all allowed EntriesForPeer for peer:", peer) if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) @@ -349,7 +349,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } device.routing.mutex.Lock() - device.routing.table.RemovePeer(peer) + device.routing.table.RemoveByPeer(peer) device.routing.mutex.Unlock() case "allowed_ip": -- cgit v1.2.3-59-g8ed1b