From 69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 3 Mar 2019 04:04:41 +0100 Subject: global: begin modularization --- Makefile | 8 +- allowedips.go | 251 -------------- allowedips_rand_test.go | 131 -------- allowedips_test.go | 260 -------------- bind_test.go | 55 --- conn.go | 180 ---------- conn_default.go | 170 ---------- conn_linux.go | 746 ----------------------------------------- constants.go | 41 --- cookie.go | 250 -------------- cookie_test.go | 191 ----------- device.go | 396 ---------------------- device/allowedips.go | 251 ++++++++++++++ device/allowedips_rand_test.go | 131 ++++++++ device/allowedips_test.go | 260 ++++++++++++++ device/bind_test.go | 55 +++ device/conn.go | 180 ++++++++++ device/conn_default.go | 170 ++++++++++ device/conn_linux.go | 746 +++++++++++++++++++++++++++++++++++++++++ device/constants.go | 41 +++ device/cookie.go | 250 ++++++++++++++ device/cookie_test.go | 191 +++++++++++ device/device.go | 396 ++++++++++++++++++++++ device/device_test.go | 48 +++ device/endpoint_test.go | 53 +++ device/indextable.go | 97 ++++++ device/ip.go | 22 ++ device/kdf_test.go | 84 +++++ device/keypair.go | 50 +++ device/logger.go | 59 ++++ device/mark_default.go | 12 + device/mark_unix.go | 64 ++++ device/misc.go | 48 +++ device/noise-helpers.go | 104 ++++++ device/noise-protocol.go | 600 +++++++++++++++++++++++++++++++++ device/noise-types.go | 81 +++++ device/noise_test.go | 144 ++++++++ device/peer.go | 270 +++++++++++++++ device/pools.go | 89 +++++ device/queueconstants.go | 16 + device/receive.go | 641 +++++++++++++++++++++++++++++++++++ device/send.go | 618 ++++++++++++++++++++++++++++++++++ device/timers.go | 227 +++++++++++++ device/tun.go | 55 +++ device/uapi.go | 426 +++++++++++++++++++++++ device/version.go | 3 + device_test.go | 48 --- endpoint_test.go | 53 --- go.mod | 6 +- go.sum | 13 +- helper_test.go | 92 ----- indextable.go | 97 ------ ip.go | 22 -- ipc/uapi_bsd.go | 202 +++++++++++ ipc/uapi_linux.go | 199 +++++++++++ ipc/uapi_windows.go | 76 +++++ kdf_test.go | 84 ----- keypair.go | 50 --- logger.go | 59 ---- main.go | 32 +- main_windows.go | 12 +- mark_default.go | 12 - mark_unix.go | 64 ---- misc.go | 48 --- noise-helpers.go | 104 ------ noise-protocol.go | 600 --------------------------------- noise-types.go | 81 ----- noise_test.go | 144 -------- peer.go | 270 --------------- pools.go | 89 ----- queueconstants.go | 16 - receive.go | 641 ----------------------------------- send.go | 618 ---------------------------------- timers.go | 227 ------------- tun.go | 55 --- tun/helper_test.go | 92 +++++ tun/tun_windows.go | 2 +- uapi.go | 425 ----------------------- uapi_bsd.go | 202 ----------- uapi_linux.go | 199 ----------- uapi_windows.go | 76 ----- 81 files changed, 7090 insertions(+), 7081 deletions(-) delete mode 100644 allowedips.go delete mode 100644 allowedips_rand_test.go delete mode 100644 allowedips_test.go delete mode 100644 bind_test.go delete mode 100644 conn.go delete mode 100644 conn_default.go delete mode 100644 conn_linux.go delete mode 100644 constants.go delete mode 100644 cookie.go delete mode 100644 cookie_test.go delete mode 100644 device.go create mode 100644 device/allowedips.go create mode 100644 device/allowedips_rand_test.go create mode 100644 device/allowedips_test.go create mode 100644 device/bind_test.go create mode 100644 device/conn.go create mode 100644 device/conn_default.go create mode 100644 device/conn_linux.go create mode 100644 device/constants.go create mode 100644 device/cookie.go create mode 100644 device/cookie_test.go create mode 100644 device/device.go create mode 100644 device/device_test.go create mode 100644 device/endpoint_test.go create mode 100644 device/indextable.go create mode 100644 device/ip.go create mode 100644 device/kdf_test.go create mode 100644 device/keypair.go create mode 100644 device/logger.go create mode 100644 device/mark_default.go create mode 100644 device/mark_unix.go create mode 100644 device/misc.go create mode 100644 device/noise-helpers.go create mode 100644 device/noise-protocol.go create mode 100644 device/noise-types.go create mode 100644 device/noise_test.go create mode 100644 device/peer.go create mode 100644 device/pools.go create mode 100644 device/queueconstants.go create mode 100644 device/receive.go create mode 100644 device/send.go create mode 100644 device/timers.go create mode 100644 device/tun.go create mode 100644 device/uapi.go create mode 100644 device/version.go delete mode 100644 device_test.go delete mode 100644 endpoint_test.go delete mode 100644 helper_test.go delete mode 100644 indextable.go delete mode 100644 ip.go create mode 100644 ipc/uapi_bsd.go create mode 100644 ipc/uapi_linux.go create mode 100644 ipc/uapi_windows.go delete mode 100644 kdf_test.go delete mode 100644 keypair.go delete mode 100644 logger.go delete mode 100644 mark_default.go delete mode 100644 mark_unix.go delete mode 100644 misc.go delete mode 100644 noise-helpers.go delete mode 100644 noise-protocol.go delete mode 100644 noise-types.go delete mode 100644 noise_test.go delete mode 100644 peer.go delete mode 100644 pools.go delete mode 100644 queueconstants.go delete mode 100644 receive.go delete mode 100644 send.go delete mode 100644 timers.go delete mode 100644 tun.go create mode 100644 tun/helper_test.go delete mode 100644 uapi.go delete mode 100644 uapi_bsd.go delete mode 100644 uapi_linux.go delete mode 100644 uapi_windows.go diff --git a/Makefile b/Makefile index cff8a7f..69a8100 100644 --- a/Makefile +++ b/Makefile @@ -24,10 +24,10 @@ MAKEFLAGS += --no-print-directory generate-version-and-build: @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \ - ver="$$(printf 'package main\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \ - [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ - echo "$$ver" > version.go && \ - git update-index --assume-unchanged version.go || true + ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \ + [ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \ + echo "$$ver" > device/version.go && \ + git update-index --assume-unchanged device/version.go || true @$(MAKE) wireguard-go wireguard-go: $(wildcard *.go) $(wildcard */*.go) diff --git a/allowedips.go b/allowedips.go deleted file mode 100644 index 2c4f601..0000000 --- a/allowedips.go +++ /dev/null @@ -1,251 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "errors" - "math/bits" - "net" - "sync" - "unsafe" -) - -type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -func isLittleEndian() bool { - one := uint32(1) - return *(*byte)(unsafe.Pointer(&one)) != 0 -} - -func swapU32(i uint32) uint32 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes32(i) -} - -func swapU64(i uint64) uint64 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes64(i) -} - -func commonBits(ip1 net.IP, ip2 net.IP) uint { - size := len(ip1) - if size == net.IPv4len { - a := (*uint32)(unsafe.Pointer(&ip1[0])) - b := (*uint32)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) - } else if size == net.IPv6len { - a := (*uint64)(unsafe.Pointer(&ip1[0])) - b := (*uint64)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) - } - a = (*uint64)(unsafe.Pointer(&ip1[8])) - b = (*uint64)(unsafe.Pointer(&ip2[8])) - x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) - } else { - panic("Wrong size bit string") - } -} - -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) - - if node.peer != p { - return node - } - - // remove peer & merge - - node.peer = nil - if node.child[0] == nil { - return node.child[1] - } - return node.child[0] -} - -func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 -} - -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { - - // at leaf - - if node == nil { - return &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - } - - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node - } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node - } - - // split node - - newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - cidr = min(cidr, common) - - // check for shorter prefix - - if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode - } - - // create new parent for node & newNode - - parent := &trieEntry{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), - } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent -} - -func (node *trieEntry) lookup(ip net.IP) *Peer { - var found *Peer - size := uint(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bit_at_byte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - -type AllowedIPs struct { - IPv4 *trieEntry - IPv6 *trieEntry - mutex sync.RWMutex -} - -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { - table.mutex.RLock() - defer table.mutex.RUnlock() - - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed -} - -func (table *AllowedIPs) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil -} - -func (table *AllowedIPs) RemoveByPeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) -} - -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) - default: - panic(errors.New("inserting unknown address type")) - } -} - -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.lookup(address) -} diff --git a/allowedips_rand_test.go b/allowedips_rand_test.go deleted file mode 100644 index 56a31c4..0000000 --- a/allowedips_rand_test.go +++ /dev/null @@ -1,131 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "math/rand" - "sort" - "testing" -) - -const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 -) - -type SlowNode struct { - peer *Peer - cidr uint - bits []byte -} - -type SlowRouter []*SlowNode - -func (r SlowRouter) Len() int { - return len(r) -} - -func (r SlowRouter) Less(i, j int) bool { - return r[i].cidr > r[j].cidr -} - -func (r SlowRouter) Swap(i, j int) { - r[i], r[j] = r[j], r[i] -} - -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { - for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { - t.peer = peer - t.bits = addr - return r - } - } - r = append(r, &SlowNode{ - cidr: cidr, - bits: addr, - peer: peer, - }) - sort.Sort(r) - return r -} - -func (r SlowRouter) Lookup(addr []byte) *Peer { - for _, t := range r { - common := commonBits(t.bits, addr) - if common >= t.cidr { - return t.peer - } - } - return nil -} - -func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} - -func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) - } - } -} diff --git a/allowedips_test.go b/allowedips_test.go deleted file mode 100644 index ca694ab..0000000 --- a/allowedips_test.go +++ /dev/null @@ -1,260 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "math/rand" - "net" - "testing" -) - -/* Todo: More comprehensive - */ - -type testPairCommonBits struct { - s1 []byte - s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) -} - -func TestCommonBits(t *testing.T) { - - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, - } - - for _, p := range tests { - v := commonBits(p.s1, p.s2) - if v != p.match { - t.Error( - "For slice", p.s1, p.s2, - "expected match", p.match, - ",but got", v, - ) - } - } -} - -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { - var trie *trieEntry - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < peerNumber; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < addressNumber; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < b.N; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.lookup(addr[:]) - } -} - -func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) -} - -func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) -} - -func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) -} - -func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) -} - -/* Test ported from kernel implementation: - * selftest/allowedips.h - */ -func TestTrieIPv4(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *trieEntry - - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) - if p != peer { - t.Error("Assert EQ failed") - } - } - - assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) - if p == peer { - t.Error("Assert NEQ failed") - } - } - - insert(a, 192, 168, 4, 0, 24) - insert(b, 192, 168, 4, 4, 32) - insert(c, 192, 168, 0, 0, 16) - insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) - insert(e, 0, 0, 0, 0, 0) - insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) - insert(a, 10, 0, 0, 0, 25) - insert(b, 10, 0, 0, 128, 25) - insert(a, 10, 1, 0, 0, 30) - insert(b, 10, 1, 0, 4, 30) - insert(c, 10, 1, 0, 8, 29) - insert(d, 10, 1, 0, 16, 29) - - assertEQ(a, 192, 168, 4, 20) - assertEQ(a, 192, 168, 4, 0) - assertEQ(b, 192, 168, 4, 4) - assertEQ(c, 192, 168, 200, 182) - assertEQ(c, 192, 95, 5, 68) - assertEQ(e, 192, 95, 5, 96) - assertEQ(g, 64, 15, 116, 26) - assertEQ(g, 64, 15, 127, 3) - - insert(a, 1, 0, 0, 0, 32) - insert(a, 64, 0, 0, 0, 32) - insert(a, 128, 0, 0, 0, 32) - insert(a, 192, 0, 0, 0, 32) - insert(a, 255, 0, 0, 0, 32) - - assertEQ(a, 1, 0, 0, 0) - assertEQ(a, 64, 0, 0, 0) - assertEQ(a, 128, 0, 0, 0) - assertEQ(a, 192, 0, 0, 0) - assertEQ(a, 255, 0, 0, 0) - - trie = trie.removeByPeer(a) - - assertNEQ(a, 1, 0, 0, 0) - assertNEQ(a, 64, 0, 0, 0) - assertNEQ(a, 128, 0, 0, 0) - assertNEQ(a, 192, 0, 0, 0) - assertNEQ(a, 255, 0, 0, 0) - - trie = nil - - insert(a, 192, 168, 0, 0, 16) - insert(a, 192, 168, 0, 0, 24) - - trie = trie.removeByPeer(a) - - assertNEQ(a, 192, 168, 0, 1) -} - -/* Test ported from kernel implementation: - * selftest/allowedips.h - */ -func TestTrieIPv6(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - f := &Peer{} - g := &Peer{} - h := &Peer{} - - var trie *trieEntry - - expand := func(a uint32) []byte { - var out [4]byte - out[0] = byte(a >> 24 & 0xff) - out[1] = byte(a >> 16 & 0xff) - out[2] = byte(a >> 8 & 0xff) - out[3] = byte(a & 0xff) - return out[:] - } - - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) - } - - assertEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := trie.lookup(addr) - if p != peer { - t.Error("Assert EQ failed") - } - } - - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) - insert(c, 0x26075300, 0x60006b00, 0, 0, 64) - insert(e, 0, 0, 0, 0, 0) - insert(f, 0, 0, 0, 0, 0) - insert(g, 0x24046800, 0, 0, 0, 32) - insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) - insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) - insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) - insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) - - assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) - assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) - assertEQ(f, 0x26075300, 0x60006b01, 0, 0) - assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) - assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0, 0) - assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) - assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) -} diff --git a/bind_test.go b/bind_test.go deleted file mode 100644 index c534646..0000000 --- a/bind_test.go +++ /dev/null @@ -1,55 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import "errors" - -type DummyDatagram struct { - msg []byte - endpoint Endpoint - world bool // better type -} - -type DummyBind struct { - in6 chan DummyDatagram - ou6 chan DummyDatagram - in4 chan DummyDatagram - ou4 chan DummyDatagram - closed bool -} - -func (b *DummyBind) SetMark(v uint32) error { - return nil -} - -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - datagram, ok := <-b.in6 - if !ok { - return 0, nil, errors.New("closed") - } - copy(buff, datagram.msg) - return len(datagram.msg), datagram.endpoint, nil -} - -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - datagram, ok := <-b.in4 - if !ok { - return 0, nil, errors.New("closed") - } - copy(buff, datagram.msg) - return len(datagram.msg), datagram.endpoint, nil -} - -func (b *DummyBind) Close() error { - close(b.in6) - close(b.in4) - b.closed = true - return nil -} - -func (b *DummyBind) Send(buff []byte, end Endpoint) error { - return nil -} diff --git a/conn.go b/conn.go deleted file mode 100644 index b19a9c2..0000000 --- a/conn.go +++ /dev/null @@ -1,180 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "errors" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/conn_default.go b/conn_default.go deleted file mode 100644 index 6f17de5..0000000 --- a/conn_default.go +++ /dev/null @@ -1,170 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "net" - "os" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type NativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*NativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // retrieve port - - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func extractErrno(err error) error { - opErr, ok := err.(*net.OpError) - if !ok { - return nil - } - syscallErr, ok := opErr.Err.(*os.SyscallError) - if !ok { - return nil - } - return syscallErr.Err -} - -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { - var err error - var bind NativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *NativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/conn_linux.go b/conn_linux.go deleted file mode 100644 index d3dbb98..0000000 --- a/conn_linux.go +++ /dev/null @@ -1,746 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ - -package main - -import ( - "errors" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" - "net" - "strconv" - "sync" - "syscall" - "unsafe" -) - -const ( - FD_ERR = -1 -) - -type IPv4Source struct { - src [4]byte - ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type NativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*NativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { - var err error - var bind NativeBind - var newPort uint16 - - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if succesful - - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - return nil, 0, err - } - } else { - port = newPort - } - - // attempt ipv4 bind, update port if succesful - - bind.sock4, newPort, err = create4(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *NativeBind) SetMark(value uint32) error { - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - -func (bind *NativeBind) Close() error { - var err1, err2, err3 error - if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) - } - if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) - } - err3 = bind.netlinkCancel.Cancel() - - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *NativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return syscall.EAFNOSUPPORT - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return syscall.EAFNOSUPPORT - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, - }, - } - - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func (bind *NativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/constants.go b/constants.go deleted file mode 100644 index ab93e5f..0000000 --- a/constants.go +++ /dev/null @@ -1,41 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "time" -) - -/* Specification constants */ - -const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ - RekeyTimeoutJitterMaxMs = 334 - RejectAfterTime = time.Second * 180 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 120 - HandshakeInitationRate = time.Second / 20 - PaddingMultiple = 16 -) - -const ( - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content -) - -/* Implementation constants */ - -const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 - UnderLoadAfterTime = time.Second // how long does the device remain under load after detected - MaxPeers = 1 << 16 // maximum number of configured peers -) diff --git a/cookie.go b/cookie.go deleted file mode 100644 index c648bf1..0000000 --- a/cookie.go +++ /dev/null @@ -1,250 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "crypto/hmac" - "crypto/rand" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "sync" - "time" -) - -type CookieChecker struct { - sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - secret [blake2s.Size]byte - secretSet time.Time - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -type CookieGenerator struct { - sync.RWMutex - mac1 struct { - key [blake2s.Size]byte - } - mac2 struct { - cookie [blake2s.Size128]byte - cookieSet time.Time - hasLastMAC1 bool - lastMAC1 [blake2s.Size128]byte - encryptionKey [chacha20poly1305.KeySize]byte - } -} - -func (st *CookieChecker) Init(pk NoisePublicKey) { - st.Lock() - defer st.Unlock() - - // mac1 state - - func() { - hash, _ := blake2s.New256(nil) - hash.Write([]byte(WGLabelMAC1)) - hash.Write(pk[:]) - hash.Sum(st.mac1.key[:0]) - }() - - // mac2 state - - func() { - hash, _ := blake2s.New256(nil) - hash.Write([]byte(WGLabelCookie)) - hash.Write(pk[:]) - hash.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.secretSet = time.Time{} -} - -func (st *CookieChecker) CheckMAC1(msg []byte) bool { - st.RLock() - defer st.RUnlock() - - size := len(msg) - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - var mac1 [blake2s.Size128]byte - - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - - return hmac.Equal(mac1[:], msg[smac1:smac2]) -} - -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { - st.RLock() - defer st.RUnlock() - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - return false - } - - // derive cookie key - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // calculate mac of packet (including mac1) - - smac2 := len(msg) - blake2s.Size128 - - var mac2 [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() - - return hmac.Equal(mac2[:], msg[smac2:]) -} - -func (st *CookieChecker) CreateReply( - msg []byte, - recv uint32, - src []byte, -) (*MessageCookieReply, error) { - - st.RLock() - - // refresh cookie secret - - if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { - st.RUnlock() - st.Lock() - _, err := rand.Read(st.mac2.secret[:]) - if err != nil { - st.Unlock() - return nil, err - } - st.mac2.secretSet = time.Now() - st.Unlock() - st.RLock() - } - - // derive cookie - - var cookie [blake2s.Size128]byte - func() { - mac, _ := blake2s.New128(st.mac2.secret[:]) - mac.Write(src) - mac.Sum(cookie[:0]) - }() - - // encrypt cookie - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - reply := new(MessageCookieReply) - reply.Type = MessageCookieReplyType - reply.Receiver = recv - - _, err := rand.Read(reply.Nonce[:]) - if err != nil { - st.RUnlock() - return nil, err - } - - xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) - xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) - - st.RUnlock() - - return reply, nil -} - -func (st *CookieGenerator) Init(pk NoisePublicKey) { - st.Lock() - defer st.Unlock() - - func() { - hash, _ := blake2s.New256(nil) - hash.Write([]byte(WGLabelMAC1)) - hash.Write(pk[:]) - hash.Sum(st.mac1.key[:0]) - }() - - func() { - hash, _ := blake2s.New256(nil) - hash.Write([]byte(WGLabelCookie)) - hash.Write(pk[:]) - hash.Sum(st.mac2.encryptionKey[:0]) - }() - - st.mac2.cookieSet = time.Time{} -} - -func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { - st.Lock() - defer st.Unlock() - - if !st.mac2.hasLastMAC1 { - return false - } - - var cookie [blake2s.Size128]byte - - xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) - _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) - - if err != nil { - return false - } - - st.mac2.cookieSet = time.Now() - st.mac2.cookie = cookie - return true -} - -func (st *CookieGenerator) AddMacs(msg []byte) { - - size := len(msg) - - smac2 := size - blake2s.Size128 - smac1 := smac2 - blake2s.Size128 - - mac1 := msg[smac1:smac2] - mac2 := msg[smac2:] - - st.Lock() - defer st.Unlock() - - // set mac1 - - func() { - mac, _ := blake2s.New128(st.mac1.key[:]) - mac.Write(msg[:smac1]) - mac.Sum(mac1[:0]) - }() - copy(st.mac2.lastMAC1[:], mac1) - st.mac2.hasLastMAC1 = true - - // set mac2 - - if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { - return - } - - func() { - mac, _ := blake2s.New128(st.mac2.cookie[:]) - mac.Write(msg[:smac2]) - mac.Sum(mac2[:0]) - }() -} diff --git a/cookie_test.go b/cookie_test.go deleted file mode 100644 index 0586260..0000000 --- a/cookie_test.go +++ /dev/null @@ -1,191 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "testing" -) - -func TestCookieMAC1(t *testing.T) { - - // setup generator / checker - - var ( - generator CookieGenerator - checker CookieChecker - ) - - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - pk := sk.publicKey() - - generator.Init(pk) - checker.Init(pk) - - // check mac1 - - src := []byte{192, 168, 13, 37, 10, 10, 10} - - checkMAC1 := func(msg []byte) { - generator.AddMacs(msg) - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC1([]byte{ - 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, - 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, - 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, - 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, - 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, - 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, - 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, - 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, - 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, - 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, - 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, - 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, - 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, - 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, - 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, - 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, - }) - - checkMAC1([]byte{ - 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, - 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, - 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, - 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - checkMAC1([]byte{ - 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, - 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, - 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, - 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, - }) - - // exchange cookie reply - - func() { - msg := []byte{ - 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, - 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, - 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, - 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, - 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, - 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, - 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, - 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, - 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, - 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, - 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, - 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, - 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, - 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, - 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, - 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, - } - generator.AddMacs(msg) - reply, err := checker.CreateReply(msg, 1377, src) - if err != nil { - t.Fatal("Failed to create cookie reply:", err) - } - if !generator.ConsumeReply(reply) { - t.Fatal("Failed to consume cookie reply") - } - }() - - // check mac2 - - checkMAC2 := func(msg []byte) { - generator.AddMacs(msg) - - if !checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if !checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - if checker.CheckMAC1(msg) { - t.Fatal("MAC1 generation/verification failed") - } - if checker.CheckMAC2(msg, src) { - t.Fatal("MAC2 generation/verification failed") - } - - msg[5] ^= 0x20 - - srcBad1 := []byte{192, 168, 13, 37, 40, 01} - if checker.CheckMAC2(msg, srcBad1) { - t.Fatal("MAC2 generation/verification failed") - } - - srcBad2 := []byte{192, 168, 13, 38, 40, 01} - if checker.CheckMAC2(msg, srcBad2) { - t.Fatal("MAC2 generation/verification failed") - } - } - - checkMAC2([]byte{ - 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, - 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, - 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, - 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, - 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, - 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, - 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, - 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, - }) - - checkMAC2([]byte{ - 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, - 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, - 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, - 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, - 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, - 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, - 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, - 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, - 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, - 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, - 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, - 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, - 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, - 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, - 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, - 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, - 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, - 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, - 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, - 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, - 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, - 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, - 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, - 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, - 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, - 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, - 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, - 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, - 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, - 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, - 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, - 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, - }) -} diff --git a/device.go b/device.go deleted file mode 100644 index 18e1138..0000000 --- a/device.go +++ /dev/null @@ -1,396 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "golang.zx2c4.com/wireguard/ratelimiter" - "golang.zx2c4.com/wireguard/tun" - "runtime" - "sync" - "sync/atomic" - "time" -) - -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - -type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - - state struct { - starting sync.WaitGroup - stopping sync.WaitGroup - sync.Mutex - changing AtomicBool - current bool - } - - net struct { - starting sync.WaitGroup - stopping sync.WaitGroup - sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) - } - - staticIdentity struct { - sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - } - - peers struct { - sync.RWMutex - keyMap map[NoisePublicKey]*Peer - } - - // unprotected / "self-synchronising resources" - - allowedips AllowedIPs - indexTable IndexTable - cookieChecker CookieChecker - - rate struct { - underLoadUntil atomic.Value - limiter ratelimiter.Ratelimiter - } - - pool struct { - messageBufferPool *sync.Pool - messageBufferReuseChan chan *[MaxMessageSize]byte - inboundElementPool *sync.Pool - inboundElementReuseChan chan *QueueInboundElement - outboundElementPool *sync.Pool - outboundElementReuseChan chan *QueueOutboundElement - } - - queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signals struct { - stop chan struct{} - } - - tun struct { - device tun.TUNDevice - mtu int32 - } -} - -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold device.peers.Mutex - */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { - - // stop routing and processing of packets - - device.allowedips.RemoveByPeer(peer) - peer.Stop() - - // remove from peer map - - delete(device.peers.keyMap, key) -} - -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - // compare to current state of device - - device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() - return - } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.isUp.Set(false) - break - } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if peer.persistentKeepaliveInterval > 0 { - peer.SendKeepalive() - } - } - device.peers.RUnlock() - - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() - } - device.peers.RUnlock() - } - - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() - - // check for state change in the mean time - - deviceUpdateState(device) -} - -func (device *Device) Up() { - - // closed device cannot be brought up - - if device.isClosed.Get() { - return - } - - device.isUp.Set(true) - deviceUpdateState(device) -} - -func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) -} - -func (device *Device) IsUnderLoad() bool { - - // check if currently under load - - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize - if underLoad { - device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) - return true - } - - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) -} - -func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { - - // lock required resources - - device.staticIdentity.Lock() - defer device.staticIdentity.Unlock() - - device.peers.Lock() - defer device.peers.Unlock() - - for _, peer := range device.peers.keyMap { - peer.handshake.mutex.RLock() - defer peer.handshake.mutex.RUnlock() - } - - // remove peers with matching public keys - - publicKey := sk.publicKey() - for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equals(publicKey) { - unsafeRemovePeer(device, peer, key) - } - } - - // update key material - - device.staticIdentity.privateKey = sk - device.staticIdentity.publicKey = publicKey - device.cookieChecker.Init(publicKey) - - // do static-static DH pre-computations - - rmKey := device.staticIdentity.privateKey.IsZero() - - for key, peer := range device.peers.keyMap { - - handshake := &peer.handshake - - if rmKey { - handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} - } else { - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) - } - - if isZero(handshake.precomputedStaticStatic[:]) { - unsafeRemovePeer(device, peer, key) - } - } - - return nil -} - -func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { - device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - - device.log = logger - - device.tun.device = tunDevice - mtu, err := device.tun.device.MTU() - if err != nil { - logger.Error.Println("Trouble determining MTU, assuming default:", err) - mtu = DefaultMTU - } - device.tun.mtu = int32(mtu) - - device.peers.keyMap = make(map[NoisePublicKey]*Peer) - - device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - - device.indexTable.Init() - device.allowedips.Reset() - - device.PopulatePools() - - // create queues - - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signals.stop = make(chan struct{}) - - // prepare net - - device.net.port = 0 - device.net.bind = nil - - // start workers - - cpus := runtime.NumCPU() - device.state.starting.Wait() - device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - for i := 0; i < cpus; i += 1 { - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() - } - - go device.RoutineReadFromTUN() - go device.RoutineTUNEventReader() - - device.state.starting.Wait() - - return device -} - -func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { - device.peers.RLock() - defer device.peers.RUnlock() - - return device.peers.keyMap[pk] -} - -func (device *Device) RemovePeer(key NoisePublicKey) { - device.peers.Lock() - defer device.peers.Unlock() - - // stop peer and remove from routing - - peer, ok := device.peers.keyMap[key] - if ok { - unsafeRemovePeer(device, peer, key) - } -} - -func (device *Device) RemoveAllPeers() { - device.peers.Lock() - defer device.peers.Unlock() - - for key, peer := range device.peers.keyMap { - unsafeRemovePeer(device, peer, key) - } - - device.peers.keyMap = make(map[NoisePublicKey]*Peer) -} - -func (device *Device) FlushPacketQueues() { - for { - select { - case elem, ok := <-device.queue.decryption: - if ok { - elem.Drop() - } - case elem, ok := <-device.queue.encryption: - if ok { - elem.Drop() - } - case <-device.queue.handshake: - default: - return - } - } - -} - -func (device *Device) Close() { - if device.isClosed.Swap(true) { - return - } - - device.state.starting.Wait() - - device.log.Info.Println("Device closing") - device.state.changing.Set(true) - device.state.Lock() - defer device.state.Unlock() - - device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) - - close(device.signals.stop) - - device.RemoveAllPeers() - - device.state.stopping.Wait() - device.FlushPacketQueues() - - device.rate.limiter.Close() - - device.state.changing.Set(false) - device.log.Info.Println("Interface closed") -} - -func (device *Device) Wait() chan struct{} { - return device.signals.stop -} diff --git a/device/allowedips.go b/device/allowedips.go new file mode 100644 index 0000000..efc27c0 --- /dev/null +++ b/device/allowedips.go @@ -0,0 +1,251 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "math/bits" + "net" + "sync" + "unsafe" +) + +type trieEntry struct { + cidr uint + child [2]*trieEntry + bits net.IP + peer *Peer + + // index of "branching" bit + + bit_at_byte uint + bit_at_shift uint +} + +func isLittleEndian() bool { + one := uint32(1) + return *(*byte)(unsafe.Pointer(&one)) != 0 +} + +func swapU32(i uint32) uint32 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes32(i) +} + +func swapU64(i uint64) uint64 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes64(i) +} + +func commonBits(ip1 net.IP, ip2 net.IP) uint { + size := len(ip1) + if size == net.IPv4len { + a := (*uint32)(unsafe.Pointer(&ip1[0])) + b := (*uint32)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + return uint(bits.LeadingZeros32(swapU32(x))) + } else if size == net.IPv6len { + a := (*uint64)(unsafe.Pointer(&ip1[0])) + b := (*uint64)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + if x != 0 { + return uint(bits.LeadingZeros64(swapU64(x))) + } + a = (*uint64)(unsafe.Pointer(&ip1[8])) + b = (*uint64)(unsafe.Pointer(&ip2[8])) + x = *a ^ *b + return 64 + uint(bits.LeadingZeros64(swapU64(x))) + } else { + panic("Wrong size bit string") + } +} + +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { + if node == nil { + return node + } + + // walk recursively + + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) + + if node.peer != p { + return node + } + + // remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *trieEntry) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { + + // at leaf + + if node == nil { + return &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // traverse deeper + + common := commonBits(node.bits, ip) + if node.cidr <= cidr && common >= node.cidr { + if node.cidr == cidr { + node.peer = peer + return node + } + bit := node.choose(ip) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) + return node + } + + // split node + + newNode := &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + + // check for shorter prefix + + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode + } + + // create new parent for node & newNode + + parent := &trieEntry{ + bits: ip, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(ip) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *trieEntry) lookup(ip net.IP) *Peer { + var found *Peer + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(ip) + node = node.child[bit] + } + return found +} + +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } + if node.peer == p { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + results = append(results, net.IPNet{ + Mask: mask, + IP: node.bits.Mask(mask), + }) + } + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) + return results +} + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go new file mode 100644 index 0000000..59c10f7 --- /dev/null +++ b/device/allowedips_rand_test.go @@ -0,0 +1,131 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "sort" + "testing" +) + +const ( + NumberOfPeers = 100 + NumberOfAddresses = 250 + NumberOfTests = 10000 +) + +type SlowNode struct { + peer *Peer + cidr uint + bits []byte +} + +type SlowRouter []*SlowNode + +func (r SlowRouter) Len() int { + return len(r) +} + +func (r SlowRouter) Less(i, j int) bool { + return r[i].cidr > r[j].cidr +} + +func (r SlowRouter) Swap(i, j int) { + r[i], r[j] = r[j], r[i] +} + +func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { + for _, t := range r { + if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { + t.peer = peer + t.bits = addr + return r + } + } + r = append(r, &SlowNode{ + cidr: cidr, + bits: addr, + peer: peer, + }) + sort.Sort(r) + return r +} + +func (r SlowRouter) Lookup(addr []byte) *Peer { + for _, t := range r { + common := commonBits(t.bits, addr) + if common >= t.cidr { + return t.peer + } + } + return nil +} + +func TestTrieRandomIPv4(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} + +func TestTrieRandomIPv6(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 16 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} diff --git a/device/allowedips_test.go b/device/allowedips_test.go new file mode 100644 index 0000000..075ff06 --- /dev/null +++ b/device/allowedips_test.go @@ -0,0 +1,260 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "net" + "testing" +) + +/* Todo: More comprehensive + */ + +type testPairCommonBits struct { + s1 []byte + s2 []byte + match uint +} + +type testPairTrieInsert struct { + key []byte + cidr uint + peer *Peer +} + +type testPairTrieLookup struct { + key []byte + peer *Peer +} + +func printTrie(t *testing.T, p *trieEntry) { + if p == nil { + return + } + t.Log(p) + printTrie(t, p.child[0]) + printTrie(t, p.child[1]) +} + +func TestCommonBits(t *testing.T) { + + tests := []testPairCommonBits{ + {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, + {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, + {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, + {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, + {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, + } + + for _, p := range tests { + v := commonBits(p.s1, p.s2) + if v != p.match { + t.Error( + "For slice", p.s1, p.s2, + "expected match", p.match, + ",but got", v, + ) + } + } +} + +func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { + var trie *trieEntry + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < peerNumber; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < addressNumber; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % peerNumber + trie = trie.insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < b.N; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + trie.lookup(addr[:]) + } +} + +func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv4len, b) +} + +func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv4len, b) +} + +func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv6len, b) +} + +func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv6len, b) +} + +/* Test ported from kernel implementation: + * selftest/allowedips.h + */ +func TestTrieIPv4(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p != peer { + t.Error("Assert EQ failed") + } + } + + assertNEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p == peer { + t.Error("Assert NEQ failed") + } + } + + insert(a, 192, 168, 4, 0, 24) + insert(b, 192, 168, 4, 4, 32) + insert(c, 192, 168, 0, 0, 16) + insert(d, 192, 95, 5, 64, 27) + insert(c, 192, 95, 5, 65, 27) + insert(e, 0, 0, 0, 0, 0) + insert(g, 64, 15, 112, 0, 20) + insert(h, 64, 15, 123, 211, 25) + insert(a, 10, 0, 0, 0, 25) + insert(b, 10, 0, 0, 128, 25) + insert(a, 10, 1, 0, 0, 30) + insert(b, 10, 1, 0, 4, 30) + insert(c, 10, 1, 0, 8, 29) + insert(d, 10, 1, 0, 16, 29) + + assertEQ(a, 192, 168, 4, 20) + assertEQ(a, 192, 168, 4, 0) + assertEQ(b, 192, 168, 4, 4) + assertEQ(c, 192, 168, 200, 182) + assertEQ(c, 192, 95, 5, 68) + assertEQ(e, 192, 95, 5, 96) + assertEQ(g, 64, 15, 116, 26) + assertEQ(g, 64, 15, 127, 3) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 64, 0, 0, 0, 32) + insert(a, 128, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 32) + insert(a, 255, 0, 0, 0, 32) + + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 64, 0, 0, 0) + assertEQ(a, 128, 0, 0, 0) + assertEQ(a, 192, 0, 0, 0) + assertEQ(a, 255, 0, 0, 0) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 1, 0, 0, 0) + assertNEQ(a, 64, 0, 0, 0) + assertNEQ(a, 128, 0, 0, 0) + assertNEQ(a, 192, 0, 0, 0) + assertNEQ(a, 255, 0, 0, 0) + + trie = nil + + insert(a, 192, 168, 0, 0, 16) + insert(a, 192, 168, 0, 0, 24) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 192, 168, 0, 1) +} + +/* Test ported from kernel implementation: + * selftest/allowedips.h + */ +func TestTrieIPv6(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + f := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + expand := func(a uint32) []byte { + var out [4]byte + out[0] = byte(a >> 24 & 0xff) + out[1] = byte(a >> 16 & 0xff) + out[2] = byte(a >> 8 & 0xff) + out[3] = byte(a & 0xff) + return out[:] + } + + insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + trie = trie.insert(addr, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.lookup(addr) + if p != peer { + t.Error("Assert EQ failed") + } + } + + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) + insert(c, 0x26075300, 0x60006b00, 0, 0, 64) + insert(e, 0, 0, 0, 0, 0) + insert(f, 0, 0, 0, 0, 0) + insert(g, 0x24046800, 0, 0, 0, 32) + insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) + insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) + insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + + assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) + assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) + assertEQ(f, 0x26075300, 0x60006b01, 0, 0) + assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) + assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0, 0) + assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) + assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) +} diff --git a/device/bind_test.go b/device/bind_test.go new file mode 100644 index 0000000..0c2e2cf --- /dev/null +++ b/device/bind_test.go @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import "errors" + +type DummyDatagram struct { + msg []byte + endpoint Endpoint + world bool // better type +} + +type DummyBind struct { + in6 chan DummyDatagram + ou6 chan DummyDatagram + in4 chan DummyDatagram + ou4 chan DummyDatagram + closed bool +} + +func (b *DummyBind) SetMark(v uint32) error { + return nil +} + +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + datagram, ok := <-b.in6 + if !ok { + return 0, nil, errors.New("closed") + } + copy(buff, datagram.msg) + return len(datagram.msg), datagram.endpoint, nil +} + +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + datagram, ok := <-b.in4 + if !ok { + return 0, nil, errors.New("closed") + } + copy(buff, datagram.msg) + return len(datagram.msg), datagram.endpoint, nil +} + +func (b *DummyBind) Close() error { + close(b.in6) + close(b.in4) + b.closed = true + return nil +} + +func (b *DummyBind) Send(buff []byte, end Endpoint) error { + return nil +} diff --git a/device/conn.go b/device/conn.go new file mode 100644 index 0000000..2594680 --- /dev/null +++ b/device/conn.go @@ -0,0 +1,180 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" +) + +const ( + ConnRoutineNumber = 2 +) + +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { + SetMark(value uint32) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error + Close() error +} + +/* An Endpoint maintains the source/destination caching for a peer + * + * dst : the remote address of a peer ("endpoint" in uapi terminology) + * src : the local address from which datagrams originate going to the peer + */ +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + ip4 := addr.IP.To4() + if ip4 != nil { + addr.IP = ip4 + } + return addr, err +} + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = CreateBind(netc.port, device) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(ConnRoutineNumber) + device.net.stopping.Add(ConnRoutineNumber) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} \ No newline at end of file diff --git a/device/conn_default.go b/device/conn_default.go new file mode 100644 index 0000000..8a86719 --- /dev/null +++ b/device/conn_default.go @@ -0,0 +1,170 @@ +// +build !linux android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "net" + "os" + "syscall" +) + +/* This code is meant to be a temporary solution + * on platforms for which the sticky socket / source caching behavior + * has not yet been implemented. + * + * See conn_linux.go for an implementation on the linux platform. + */ + +type NativeBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn +} + +type NativeEndpoint net.UDPAddr + +var _ Bind = (*NativeBind)(nil) +var _ Endpoint = (*NativeEndpoint)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (*NativeEndpoint)(addr), err +} + +func (_ *NativeEndpoint) ClearSrc() {} + +func (e *NativeEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *NativeEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *NativeEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.To4() + if out == nil { + out = addr.IP + } + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *NativeEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *NativeEndpoint) SrcToString() string { + return "" +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + + // listen + + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn, uaddr.Port, nil +} + +func extractErrno(err error) error { + opErr, ok := err.(*net.OpError) + if !ok { + return nil + } + syscallErr, ok := opErr.Err.(*os.SyscallError) + if !ok { + return nil + } + return syscallErr.Err +} + +func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { + var err error + var bind NativeBind + + port := int(uport) + + bind.ipv4, port, err = listenNet("udp4", port) + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { + return nil, 0, err + } + + bind.ipv6, port, err = listenNet("udp6", port) + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { + bind.ipv4.Close() + bind.ipv4 = nil + return nil, 0, err + } + + return &bind, uint16(port), nil +} + +func (bind *NativeBind) Close() error { + var err1, err2 error + if bind.ipv4 != nil { + err1 = bind.ipv4.Close() + } + if bind.ipv6 != nil { + err2 = bind.ipv6.Close() + } + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + if bind.ipv4 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv4.ReadFromUDP(buff) + if endpoint != nil { + endpoint.IP = endpoint.IP.To4() + } + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + if bind.ipv6 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv6.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { + var err error + nend := endpoint.(*NativeEndpoint) + if nend.IP.To4() != nil { + if bind.ipv4 == nil { + return syscall.EAFNOSUPPORT + } + _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } else { + if bind.ipv6 == nil { + return syscall.EAFNOSUPPORT + } + _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } + return err +} diff --git a/device/conn_linux.go b/device/conn_linux.go new file mode 100644 index 0000000..49949d5 --- /dev/null +++ b/device/conn_linux.go @@ -0,0 +1,746 @@ +// +build !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "errors" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/rwcancel" + "net" + "strconv" + "sync" + "syscall" + "unsafe" +) + +const ( + FD_ERR = -1 +) + +type IPv4Source struct { + src [4]byte + ifindex int32 +} + +type IPv6Source struct { + src [16]byte + //ifindex belongs in dst.ZoneId +} + +type NativeEndpoint struct { + dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte + src [unsafe.Sizeof(IPv6Source{})]byte + isV6 bool +} + +func (endpoint *NativeEndpoint) src4() *IPv4Source { + return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *NativeEndpoint) src6() *IPv6Source { + return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { + return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) +} + +func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { + return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) +} + +type NativeBind struct { + sock4 int + sock6 int + netlinkSock int + netlinkCancel *rwcancel.RWCancel + lastMark uint32 +} + +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = (*NativeBind)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + var end NativeEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := end.dst4() + end.isV6 = false + dst.Port = addr.Port + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := end.dst6() + end.isV6 = true + dst.Port = addr.Port + dst.ZoneId = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("Invalid IP address") +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil + +} + +func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { + var err error + var bind NativeBind + var newPort uint16 + + bind.netlinkSock, err = createNetlinkRouteSocket() + if err != nil { + return nil, 0, err + } + bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) + if err != nil { + unix.Close(bind.netlinkSock) + return nil, 0, err + } + + go bind.routineRouteListener(device) + + // attempt ipv6 bind, update port if succesful + + bind.sock6, newPort, err = create6(port) + if err != nil { + if err != syscall.EAFNOSUPPORT { + bind.netlinkCancel.Cancel() + return nil, 0, err + } + } else { + port = newPort + } + + // attempt ipv4 bind, update port if succesful + + bind.sock4, newPort, err = create4(port) + if err != nil { + if err != syscall.EAFNOSUPPORT { + bind.netlinkCancel.Cancel() + unix.Close(bind.sock6) + return nil, 0, err + } + } else { + port = newPort + } + + if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { + return nil, 0, errors.New("ipv4 and ipv6 not supported") + } + + return &bind, port, nil +} + +func (bind *NativeBind) SetMark(value uint32) error { + if bind.sock6 != -1 { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + if bind.sock4 != -1 { + err := unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + bind.lastMark = value + return nil +} + +func closeUnblock(fd int) error { + // shutdown to unblock readers and writers + unix.Shutdown(fd, unix.SHUT_RDWR) + return unix.Close(fd) +} + +func (bind *NativeBind) Close() error { + var err1, err2, err3 error + if bind.sock6 != -1 { + err1 = closeUnblock(bind.sock6) + } + if bind.sock4 != -1 { + err2 = closeUnblock(bind.sock4) + } + err3 = bind.netlinkCancel.Cancel() + + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return err3 +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + if bind.sock6 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } + n, err := receive6( + bind.sock6, + buff, + &end, + ) + return n, &end, err +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + if bind.sock4 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } + n, err := receive4( + bind.sock4, + buff, + &end, + ) + return n, &end, err +} + +func (bind *NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + if !nend.isV6 { + if bind.sock4 == -1 { + return syscall.EAFNOSUPPORT + } + return send4(bind.sock4, nend, buff) + } else { + if bind.sock6 == -1 { + return syscall.EAFNOSUPPORT + } + return send6(bind.sock6, nend, buff) + } +} + +func (end *NativeEndpoint) SrcIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.src4().src[0], + end.src4().src[1], + end.src4().src[2], + end.src4().src[3], + ) + } else { + return end.src6().src[:] + } +} + +func (end *NativeEndpoint) DstIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.dst4().Addr[0], + end.dst4().Addr[1], + end.dst4().Addr[2], + end.dst4().Addr[3], + ) + } else { + return end.dst6().Addr[:] + } +} + +func (end *NativeEndpoint) DstToBytes() []byte { + if !end.isV6 { + return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] + } else { + return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] + } +} + +func (end *NativeEndpoint) SrcToString() string { + return end.SrcIP().String() +} + +func (end *NativeEndpoint) DstToString() string { + var udpAddr net.UDPAddr + udpAddr.IP = end.DstIP() + if !end.isV6 { + udpAddr.Port = end.dst4().Port + } else { + udpAddr.Port = end.dst6().Port + } + return udpAddr.String() +} + +func (end *NativeEndpoint) ClearDst() { + for i := range end.dst { + end.dst[i] = 0 + } +} + +func (end *NativeEndpoint) ClearSrc() { + for i := range end.src { + end.src[i] = 0 + } +} + +func zoneToUint32(zone string) (uint32, error) { + if zone == "" { + return 0, nil + } + if intr, err := net.InterfaceByName(zone); err == nil { + return uint32(intr.Index), nil + } + n, err := strconv.ParseUint(zone, 10, 32) + return uint32(n), err +} + +func create4(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return FD_ERR, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), + } + + // set sockopts and bind + + if err := func() error { + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + }(); err != nil { + unix.Close(fd) + return FD_ERR, 0, err + } + + return fd, uint16(addr.Port), err +} + +func create6(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return FD_ERR, 0, err + } + + // set sockopts and bind + + addr := unix.SockaddrInet6{ + Port: int(port), + } + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + return FD_ERR, 0, err + } + + return fd, uint16(addr.Port), err +} + +func send4(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet4Pktinfo{ + Spec_dst: end.src4().src, + Ifindex: end.src4().ifindex, + }, + } + + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet4Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + } + + return err +} + +func send6(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet6Pktinfo{ + Addr: end.src6().src, + Ifindex: end.dst6().ZoneId, + }, + } + + if cmsg.pktinfo.Addr == [16]byte{} { + cmsg.pktinfo.Ifindex = 0 + } + + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + } + + return err +} + +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = false + + if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { + *end.dst4() = *newDst4 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + end.src4().src = cmsg.pktinfo.Spec_dst + end.src4().ifindex = cmsg.pktinfo.Ifindex + } + + return size, nil +} + +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = true + + if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { + *end.dst6() = *newDst6 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { + end.src6().src = cmsg.pktinfo.Addr + end.dst6().ZoneId = cmsg.pktinfo.Ifindex + } + + return size, nil +} + +func (bind *NativeBind) routineRouteListener(device *Device) { + type peerEndpointPtr struct { + peer *Peer + endpoint *Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(bind.netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !bind.netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { + peer.RUnlock() + continue + } + if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + peer.endpoint.(*NativeEndpoint).dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + peer.endpoint.(*NativeEndpoint).src4().src, + unix.RtAttr{ + Len: 8, + Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix + }, + uint32(bind.lastMark), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} diff --git a/device/constants.go b/device/constants.go new file mode 100644 index 0000000..27d910f --- /dev/null +++ b/device/constants.go @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "time" +) + +/* Specification constants */ + +const ( + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 + MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ + RekeyTimeoutJitterMaxMs = 334 + RejectAfterTime = time.Second * 180 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 120 + HandshakeInitationRate = time.Second / 20 + PaddingMultiple = 16 +) + +const ( + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = MaxSegmentSize // maximum size of transport message + MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content +) + +/* Implementation constants */ + +const ( + UnderLoadQueueSize = QueueHandshakeSize / 8 + UnderLoadAfterTime = time.Second // how long does the device remain under load after detected + MaxPeers = 1 << 16 // maximum number of configured peers +) diff --git a/device/cookie.go b/device/cookie.go new file mode 100644 index 0000000..2f21067 --- /dev/null +++ b/device/cookie.go @@ -0,0 +1,250 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/hmac" + "crypto/rand" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "sync" + "time" +) + +type CookieChecker struct { + sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + secret [blake2s.Size]byte + secretSet time.Time + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +type CookieGenerator struct { + sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + cookie [blake2s.Size128]byte + cookieSet time.Time + hasLastMAC1 bool + lastMAC1 [blake2s.Size128]byte + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +func (st *CookieChecker) Init(pk NoisePublicKey) { + st.Lock() + defer st.Unlock() + + // mac1 state + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) + }() + + // mac2 state + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.secretSet = time.Time{} +} + +func (st *CookieChecker) CheckMAC1(msg []byte) bool { + st.RLock() + defer st.RUnlock() + + size := len(msg) + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + var mac1 [blake2s.Size128]byte + + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + + return hmac.Equal(mac1[:], msg[smac1:smac2]) +} + +func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { + st.RLock() + defer st.RUnlock() + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + return false + } + + // derive cookie key + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // calculate mac of packet (including mac1) + + smac2 := len(msg) - blake2s.Size128 + + var mac2 [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() + + return hmac.Equal(mac2[:], msg[smac2:]) +} + +func (st *CookieChecker) CreateReply( + msg []byte, + recv uint32, + src []byte, +) (*MessageCookieReply, error) { + + st.RLock() + + // refresh cookie secret + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + st.RUnlock() + st.Lock() + _, err := rand.Read(st.mac2.secret[:]) + if err != nil { + st.Unlock() + return nil, err + } + st.mac2.secretSet = time.Now() + st.Unlock() + st.RLock() + } + + // derive cookie + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // encrypt cookie + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + reply := new(MessageCookieReply) + reply.Type = MessageCookieReplyType + reply.Receiver = recv + + _, err := rand.Read(reply.Nonce[:]) + if err != nil { + st.RUnlock() + return nil, err + } + + xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) + xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) + + st.RUnlock() + + return reply, nil +} + +func (st *CookieGenerator) Init(pk NoisePublicKey) { + st.Lock() + defer st.Unlock() + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) + }() + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.cookieSet = time.Time{} +} + +func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { + st.Lock() + defer st.Unlock() + + if !st.mac2.hasLastMAC1 { + return false + } + + var cookie [blake2s.Size128]byte + + xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) + _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) + + if err != nil { + return false + } + + st.mac2.cookieSet = time.Now() + st.mac2.cookie = cookie + return true +} + +func (st *CookieGenerator) AddMacs(msg []byte) { + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + mac1 := msg[smac1:smac2] + mac2 := msg[smac2:] + + st.Lock() + defer st.Unlock() + + // set mac1 + + func() { + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + }() + copy(st.mac2.lastMAC1[:], mac1) + st.mac2.hasLastMAC1 = true + + // set mac2 + + if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { + return + } + + func() { + mac, _ := blake2s.New128(st.mac2.cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() +} diff --git a/device/cookie_test.go b/device/cookie_test.go new file mode 100644 index 0000000..79a6a86 --- /dev/null +++ b/device/cookie_test.go @@ -0,0 +1,191 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "testing" +) + +func TestCookieMAC1(t *testing.T) { + + // setup generator / checker + + var ( + generator CookieGenerator + checker CookieChecker + ) + + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + pk := sk.publicKey() + + generator.Init(pk) + checker.Init(pk) + + // check mac1 + + src := []byte{192, 168, 13, 37, 10, 10, 10} + + checkMAC1 := func(msg []byte) { + generator.AddMacs(msg) + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC1([]byte{ + 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, + 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, + 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, + 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, + 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, + 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, + 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, + 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, + 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, + 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, + 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, + 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, + 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, + 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, + 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, + 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, + }) + + checkMAC1([]byte{ + 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, + 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, + 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, + 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + checkMAC1([]byte{ + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + // exchange cookie reply + + func() { + msg := []byte{ + 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, + 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, + 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, + 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, + 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, + 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, + 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, + 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, + 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, + 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, + 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, + 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, + 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, + 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, + 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, + 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, + } + generator.AddMacs(msg) + reply, err := checker.CreateReply(msg, 1377, src) + if err != nil { + t.Fatal("Failed to create cookie reply:", err) + } + if !generator.ConsumeReply(reply) { + t.Fatal("Failed to consume cookie reply") + } + }() + + // check mac2 + + checkMAC2 := func(msg []byte) { + generator.AddMacs(msg) + + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if !checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + if checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + srcBad1 := []byte{192, 168, 13, 37, 40, 01} + if checker.CheckMAC2(msg, srcBad1) { + t.Fatal("MAC2 generation/verification failed") + } + + srcBad2 := []byte{192, 168, 13, 38, 40, 01} + if checker.CheckMAC2(msg, srcBad2) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC2([]byte{ + 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, + 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, + 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, + 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, + 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, + 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, + 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, + 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, + }) + + checkMAC2([]byte{ + 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, + 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, + 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, + 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, + 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, + 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, + 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, + 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, + 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, + 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, + 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, + 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, + 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, + 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, + 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, + 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, + 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, + 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, + 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, + 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, + 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, + 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, + 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, + 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, + 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, + 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, + 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, + 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, + 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, + 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, + 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, + 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, + }) +} diff --git a/device/device.go b/device/device.go new file mode 100644 index 0000000..d6c96d6 --- /dev/null +++ b/device/device.go @@ -0,0 +1,396 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/tun" + "runtime" + "sync" + "sync/atomic" + "time" +) + +const ( + DeviceRoutineNumberPerCPU = 3 + DeviceRoutineNumberAdditional = 2 +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.Mutex + changing AtomicBool + current bool + } + + net struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + staticIdentity struct { + sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + peers struct { + sync.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + allowedips AllowedIPs + indexTable IndexTable + cookieChecker CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter ratelimiter.Ratelimiter + } + + pool struct { + messageBufferPool *sync.Pool + messageBufferReuseChan chan *[MaxMessageSize]byte + inboundElementPool *sync.Pool + inboundElementReuseChan chan *QueueInboundElement + outboundElementPool *sync.Pool + outboundElementReuseChan chan *QueueOutboundElement + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signals struct { + stop chan struct{} + } + + tun struct { + device tun.TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold device.peers.Mutex + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.allowedips.RemoveByPeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + // compare to current state of device + + device.state.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + device.state.Unlock() + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + + case false: + device.BindClose() + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.RUnlock() + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + device.state.Unlock() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.isUp.Set(true) + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.isUp.Set(false) + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.staticIdentity.Lock() + defer device.staticIdentity.Unlock() + + device.peers.Lock() + defer device.peers.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.staticIdentity.privateKey = sk + device.staticIdentity.publicKey = publicKey + device.cookieChecker.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.staticIdentity.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + handshake := &peer.handshake + + if rmKey { + handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + } + + if isZero(handshake.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + + device.tun.device = tunDevice + mtu, err := device.tun.device.MTU() + if err != nil { + logger.Error.Println("Trouble determining MTU, assuming default:", err) + mtu = DefaultMTU + } + device.tun.mtu = int32(mtu) + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + device.indexTable.Init() + device.allowedips.Reset() + + device.PopulatePools() + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signals.stop = make(chan struct{}) + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + cpus := runtime.NumCPU() + device.state.starting.Wait() + device.state.stopping.Wait() + device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + for i := 0; i < cpus; i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + + device.state.starting.Wait() + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.RLock() + defer device.peers.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.peers.Lock() + defer device.peers.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + device.peers.Lock() + defer device.peers.Unlock() + + for key, peer := range device.peers.keyMap { + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) FlushPacketQueues() { + for { + select { + case elem, ok := <-device.queue.decryption: + if ok { + elem.Drop() + } + case elem, ok := <-device.queue.encryption: + if ok { + elem.Drop() + } + case <-device.queue.handshake: + default: + return + } + } + +} + +func (device *Device) Close() { + if device.isClosed.Swap(true) { + return + } + + device.state.starting.Wait() + + device.log.Info.Println("Device closing") + device.state.changing.Set(true) + device.state.Lock() + defer device.state.Unlock() + + device.tun.device.Close() + device.BindClose() + + device.isUp.Set(false) + + close(device.signals.stop) + + device.RemoveAllPeers() + + device.state.stopping.Wait() + device.FlushPacketQueues() + + device.rate.limiter.Close() + + device.state.changing.Set(false) + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signals.stop +} diff --git a/device/device_test.go b/device/device_test.go new file mode 100644 index 0000000..db5a3c0 --- /dev/null +++ b/device/device_test.go @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +/* Create two device instances and simulate full WireGuard interaction + * without network dependencies + */ + +import "testing" + +func TestDevice(t *testing.T) { + + // prepare tun devices for generating traffic + + tun1, err := CreateDummyTUN("tun1") + if err != nil { + t.Error("failed to create tun:", err.Error()) + } + + tun2, err := CreateDummyTUN("tun2") + if err != nil { + t.Error("failed to create tun:", err.Error()) + } + + _ = tun1 + _ = tun2 + + // prepare endpoints + + end1, err := CreateDummyEndpoint() + if err != nil { + t.Error("failed to create endpoint:", err.Error()) + } + + end2, err := CreateDummyEndpoint() + if err != nil { + t.Error("failed to create endpoint:", err.Error()) + } + + _ = end1 + _ = end2 + + // create binds + +} diff --git a/device/endpoint_test.go b/device/endpoint_test.go new file mode 100644 index 0000000..1896790 --- /dev/null +++ b/device/endpoint_test.go @@ -0,0 +1,53 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "net" +) + +type DummyEndpoint struct { + src [16]byte + dst [16]byte +} + +func CreateDummyEndpoint() (*DummyEndpoint, error) { + var end DummyEndpoint + if _, err := rand.Read(end.src[:]); err != nil { + return nil, err + } + _, err := rand.Read(end.dst[:]) + return &end, err +} + +func (e *DummyEndpoint) ClearSrc() {} + +func (e *DummyEndpoint) SrcToString() string { + var addr net.UDPAddr + addr.IP = e.SrcIP() + addr.Port = 1000 + return addr.String() +} + +func (e *DummyEndpoint) DstToString() string { + var addr net.UDPAddr + addr.IP = e.DstIP() + addr.Port = 1000 + return addr.String() +} + +func (e *DummyEndpoint) SrcToBytes() []byte { + return e.src[:] +} + +func (e *DummyEndpoint) DstIP() net.IP { + return e.dst[:] +} + +func (e *DummyEndpoint) SrcIP() net.IP { + return e.src[:] +} diff --git a/device/indextable.go b/device/indextable.go new file mode 100644 index 0000000..4cba970 --- /dev/null +++ b/device/indextable.go @@ -0,0 +1,97 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/rand" + "sync" + "unsafe" +) + +type IndexTableEntry struct { + peer *Peer + handshake *Handshake + keypair *Keypair +} + +type IndexTable struct { + sync.RWMutex + table map[uint32]IndexTableEntry +} + +func randUint32() (uint32, error) { + var integer [4]byte + _, err := rand.Read(integer[:]) + return *(*uint32)(unsafe.Pointer(&integer[0])), err +} + +func (table *IndexTable) Init() { + table.Lock() + defer table.Unlock() + table.table = make(map[uint32]IndexTableEntry) +} + +func (table *IndexTable) Delete(index uint32) { + table.Lock() + defer table.Unlock() + delete(table.table, index) +} + +func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { + table.Lock() + defer table.Unlock() + entry, ok := table.table[index] + if !ok { + return + } + table.table[index] = IndexTableEntry{ + peer: entry.peer, + keypair: keypair, + handshake: nil, + } +} + +func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { + for { + // generate random index + + index, err := randUint32() + if err != nil { + return index, err + } + + // check if index used + + table.RLock() + _, ok := table.table[index] + table.RUnlock() + if ok { + continue + } + + // check again while locked + + table.Lock() + _, found := table.table[index] + if found { + table.Unlock() + continue + } + table.table[index] = IndexTableEntry{ + peer: peer, + handshake: handshake, + keypair: nil, + } + table.Unlock() + return index, nil + } +} + +func (table *IndexTable) Lookup(id uint32) IndexTableEntry { + table.RLock() + defer table.RUnlock() + return table.table[id] +} diff --git a/device/ip.go b/device/ip.go new file mode 100644 index 0000000..9d4fb74 --- /dev/null +++ b/device/ip.go @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "net" +) + +const ( + IPv4offsetTotalLength = 2 + IPv4offsetSrc = 12 + IPv4offsetDst = IPv4offsetSrc + net.IPv4len +) + +const ( + IPv6offsetPayloadLength = 4 + IPv6offsetSrc = 8 + IPv6offsetDst = IPv6offsetSrc + net.IPv6len +) diff --git a/device/kdf_test.go b/device/kdf_test.go new file mode 100644 index 0000000..11ea8d5 --- /dev/null +++ b/device/kdf_test.go @@ -0,0 +1,84 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "encoding/hex" + "golang.org/x/crypto/blake2s" + "testing" +) + +type KDFTest struct { + key string + input string + t0 string + t1 string + t2 string +} + +func assertEquals(t *testing.T, a string, b string) { + if a != b { + t.Fatal("expected", a, "=", b) + } +} + +func TestKDF(t *testing.T) { + tests := []KDFTest{ + { + key: "746573742d6b6579", + input: "746573742d696e707574", + t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", + t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", + t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", + }, + { + key: "776972656775617264", + input: "776972656775617264", + t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", + t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", + t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", + }, + { + key: "", + input: "", + t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", + t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", + t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", + }, + } + + var t0, t1, t2 [blake2s.Size]byte + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF3(&t0, &t1, &t2, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + t2s := hex.EncodeToString(t2[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + assertEquals(t, t2s, test.t2) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF2(&t0, &t1, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF1(&t0, key, input) + t0s := hex.EncodeToString(t0[:]) + assertEquals(t, t0s, test.t0) + } +} diff --git a/device/keypair.go b/device/keypair.go new file mode 100644 index 0000000..a9fbfce --- /dev/null +++ b/device/keypair.go @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/cipher" + "golang.zx2c4.com/wireguard/replay" + "sync" + "time" +) + +/* Due to limitations in Go and /x/crypto there is currently + * no way to ensure that key material is securely ereased in memory. + * + * Since this may harm the forward secrecy property, + * we plan to resolve this issue; whenever Go allows us to do so. + */ + +type Keypair struct { + sendNonce uint64 + send cipher.AEAD + receive cipher.AEAD + replayFilter replay.ReplayFilter + isInitiator bool + created time.Time + localIndex uint32 + remoteIndex uint32 +} + +type Keypairs struct { + sync.RWMutex + current *Keypair + previous *Keypair + next *Keypair +} + +func (kp *Keypairs) Current() *Keypair { + kp.RLock() + defer kp.RUnlock() + return kp.current +} + +func (device *Device) DeleteKeypair(key *Keypair) { + if key != nil { + device.indexTable.Delete(key.localIndex) + } +} diff --git a/device/logger.go b/device/logger.go new file mode 100644 index 0000000..7c8b704 --- /dev/null +++ b/device/logger.go @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "io" + "io/ioutil" + "log" + "os" +) + +const ( + LogLevelSilent = iota + LogLevelError + LogLevelInfo + LogLevelDebug +) + +type Logger struct { + Debug *log.Logger + Info *log.Logger + Error *log.Logger +} + +func NewLogger(level int, prepend string) *Logger { + output := os.Stdout + logger := new(Logger) + + logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { + if level >= LogLevelDebug { + return output, output, output + } + if level >= LogLevelInfo { + return output, output, ioutil.Discard + } + if level >= LogLevelError { + return output, ioutil.Discard, ioutil.Discard + } + return ioutil.Discard, ioutil.Discard, ioutil.Discard + }() + + logger.Debug = log.New(logDebug, + "DEBUG: "+prepend, + log.Ldate|log.Ltime, + ) + + logger.Info = log.New(logInfo, + "INFO: "+prepend, + log.Ldate|log.Ltime, + ) + logger.Error = log.New(logErr, + "ERROR: "+prepend, + log.Ldate|log.Ltime, + ) + return logger +} diff --git a/device/mark_default.go b/device/mark_default.go new file mode 100644 index 0000000..76b1015 --- /dev/null +++ b/device/mark_default.go @@ -0,0 +1,12 @@ +// +build !linux,!openbsd,!freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +func (bind *NativeBind) SetMark(mark uint32) error { + return nil +} diff --git a/device/mark_unix.go b/device/mark_unix.go new file mode 100644 index 0000000..ee64cc9 --- /dev/null +++ b/device/mark_unix.go @@ -0,0 +1,64 @@ +// +build android openbsd freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.org/x/sys/unix" + "runtime" +) + +var fwmarkIoctl int + +func init() { + switch runtime.GOOS { + case "linux", "android": + fwmarkIoctl = 36 /* unix.SO_MARK */ + case "freebsd": + fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ + case "openbsd": + fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ + } +} + +func (bind *NativeBind) SetMark(mark uint32) error { + var operr error + if fwmarkIoctl == 0 { + return nil + } + if bind.ipv4 != nil { + fd, err := bind.ipv4.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + if bind.ipv6 != nil { + fd, err := bind.ipv6.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + return nil +} diff --git a/device/misc.go b/device/misc.go new file mode 100644 index 0000000..a38d1c1 --- /dev/null +++ b/device/misc.go @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "sync/atomic" +) + +/* Atomic Boolean */ + +const ( + AtomicFalse = int32(iota) + AtomicTrue +) + +type AtomicBool struct { + int32 +} + +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.int32) == AtomicTrue +} + +func (a *AtomicBool) Swap(val bool) bool { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + return atomic.SwapInt32(&a.int32, flag) == AtomicTrue +} + +func (a *AtomicBool) Set(val bool) { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + atomic.StoreInt32(&a.int32, flag) +} + +func min(a, b uint) uint { + if a > b { + return b + } + return a +} diff --git a/device/noise-helpers.go b/device/noise-helpers.go new file mode 100644 index 0000000..4b09bf3 --- /dev/null +++ b/device/noise-helpers.go @@ -0,0 +1,104 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/subtle" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" + "hash" +) + +/* KDF related functions. + * HMAC-based Key Derivation Function (HKDF) + * https://tools.ietf.org/html/rfc5869 + */ + +func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Sum(sum[:0]) +} + +func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Write(in1) + mac.Sum(sum[:0]) +} + +func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { + HMAC1(t0, key, input) + HMAC1(t0, t0[:], []byte{0x1}) + return +} + +func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + setZero(prk[:]) + return +} + +func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + HMAC2(t2, prk[:], t1[:], []byte{0x3}) + setZero(prk[:]) + return +} + +func isZero(val []byte) bool { + acc := 1 + for _, b := range val { + acc &= subtle.ConstantTimeByteEq(b, 0) + } + return acc == 1 +} + +/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ +func setZero(arr []byte) { + for i := range arr { + arr[i] = 0 + } +} + +func (sk *NoisePrivateKey) clamp() { + sk[0] &= 248 + sk[31] = (sk[31] & 127) | 64 +} + +func newPrivateKey() (sk NoisePrivateKey, err error) { + _, err = rand.Read(sk[:]) + sk.clamp() + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} diff --git a/device/noise-protocol.go b/device/noise-protocol.go new file mode 100644 index 0000000..73826e1 --- /dev/null +++ b/device/noise-protocol.go @@ -0,0 +1,600 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" + "sync" + "time" +) + +const ( + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed + HandshakeResponseCreated + HandshakeResponseConsumed +) + +const ( + NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" + WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" + WGLabelMAC1 = "mac1----" + WGLabelCookie = "cookie--" +) + +const ( + MessageInitiationType = 1 + MessageResponseType = 2 + MessageCookieReplyType = 3 + MessageTransportType = 4 +) + +const ( + MessageInitiationSize = 148 // size of handshake initation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message +) + +const ( + MessageTransportOffsetReceiver = 4 + MessageTransportOffsetCounter = 8 + MessageTransportOffsetContent = 16 +) + +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit unsigned int (for now) + * + */ + +type MessageInitiation struct { + Type uint32 + Sender uint32 + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte + Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageResponse struct { + Type uint32 + Sender uint32 + Receiver uint32 + Ephemeral NoisePublicKey + Empty [poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageTransport struct { + Type uint32 + Receiver uint32 + Counter uint64 + Content []byte +} + +type MessageCookieReply struct { + Type uint32 + Receiver uint32 + Nonce [chacha20poly1305.NonceSizeX]byte + Cookie [blake2s.Size128 + poly1305.TagSize]byte +} + +type Handshake struct { + state int + mutex sync.RWMutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp tai64n.Timestamp + lastInitiationConsumption time.Time + lastSentHandshake time.Time +} + +var ( + InitialChainKey [blake2s.Size]byte + InitialHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte +) + +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) +} + +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hash, _ := blake2s.New256(nil) + hash.Write(h[:]) + hash.Write(data) + hash.Sum(dst[:0]) + hash.Reset() +} + +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + +func (h *Handshake) mixHash(data []byte) { + mixHash(&h.hash, &h.hash, data) +} + +func (h *Handshake) mixKey(data []byte) { + mixKey(&h.chainKey, &h.chainKey, data) +} + +/* Do basic precomputations + */ +func init() { + InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) +} + +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errors.New("static shared secret is zero") + } + + // create ephemeral key + + var err error + handshake.hash = InitialHash + handshake.chainKey = InitialChainKey + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + + // assign index + + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + + if err != nil { + return nil, err + } + + handshake.mixHash(handshake.remoteStatic[:]) + + msg := MessageInitiation{ + Type: MessageInitiationType, + Ephemeral: handshake.localEphemeral.publicKey(), + Sender: handshake.localIndex, + } + + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) + + // encrypt static key + + func() { + var key [chacha20poly1305.KeySize]byte + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) + }() + handshake.mixHash(msg.Static[:]) + + // encrypt timestamp + + timestamp := tai64n.Now() + func() { + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + }() + + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated + return &msg, nil +} + +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + if msg.Type != MessageInitiationType { + return nil + } + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) + + // decrypt static key + + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() + if err != nil { + return nil + } + mixHash(&hash, &hash, msg.Static[:]) + + // lookup peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + + handshake := &peer.handshake + if isZero(handshake.precomputedStaticStatic[:]) { + return nil + } + + // verify identity + + var timestamp tai64n.Timestamp + var key [chacha20poly1305.KeySize]byte + + handshake.mutex.RLock() + KDF2( + &chainKey, + &key, + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + if err != nil { + handshake.mutex.RUnlock() + return nil + } + mixHash(&hash, &hash, msg.Timestamp[:]) + + // protect against replay & flood + + var ok bool + ok = timestamp.After(handshake.lastTimestamp) + ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate + handshake.mutex.RUnlock() + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp + handshake.lastInitiationConsumption = time.Now() + handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return peer +} + +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initiation must be consumed first") + } + + // assign index + + var err error + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Receiver = handshake.remoteIndex + + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.mixHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.mixKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.mixKey(ss[:]) + }() + + // add preshared key + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + + handshake.mixHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) + }() + + handshake.state = HandshakeResponseCreated + + return &msg, nil +} + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + return nil + } + + // lookup handshake by receiver + + lookup := device.indexTable.Lookup(msg.Receiver) + handshake := lookup.handshake + if handshake == nil { + return nil + } + + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + ok := func() bool { + + // lock handshake state + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + if handshake.state != HandshakeInitiationCreated { + return false + } + + // lock private key for reading + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + // finish 3-way DH + + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) + + // authenticate transcript + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return false + } + mixHash(&hash, &hash, msg.Empty[:]) + return true + }() + + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return lookup.peer +} + +/* Derives a new keypair from the current handshake state + * + */ +func (peer *Peer) BeginSymmetricSession() error { + device := peer.device + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var isInitiator bool + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) + isInitiator = true + } else if handshake.state == HandshakeResponseCreated { + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) + isInitiator = false + } else { + return errors.New("invalid state for keypair derivation") + } + + // zero handshake + + setZero(handshake.chainKey[:]) + setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. + setZero(handshake.localEphemeral[:]) + peer.handshake.state = HandshakeZeroed + + // create AEAD instances + + keypair := new(Keypair) + keypair.send, _ = chacha20poly1305.New(sendKey[:]) + keypair.receive, _ = chacha20poly1305.New(recvKey[:]) + + setZero(sendKey[:]) + setZero(recvKey[:]) + + keypair.created = time.Now() + keypair.sendNonce = 0 + keypair.replayFilter.Init() + keypair.isInitiator = isInitiator + keypair.localIndex = peer.handshake.localIndex + keypair.remoteIndex = peer.handshake.remoteIndex + + // remap index + + device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) + handshake.localIndex = 0 + + // rotate key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + defer keypairs.Unlock() + + previous := keypairs.previous + next := keypairs.next + current := keypairs.current + + if isInitiator { + if next != nil { + keypairs.next = nil + keypairs.previous = next + device.DeleteKeypair(current) + } else { + keypairs.previous = current + } + device.DeleteKeypair(previous) + keypairs.current = keypair + } else { + keypairs.next = keypair + device.DeleteKeypair(next) + keypairs.previous = nil + device.DeleteKeypair(previous) + } + + return nil +} + +func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { + keypairs := &peer.keypairs + if keypairs.next != receivedKeypair { + return false + } + keypairs.Lock() + defer keypairs.Unlock() + if keypairs.next != receivedKeypair { + return false + } + old := keypairs.previous + keypairs.previous = keypairs.current + peer.device.DeleteKeypair(old) + keypairs.current = keypairs.next + keypairs.next = nil + return true +} diff --git a/device/noise-types.go b/device/noise-types.go new file mode 100644 index 0000000..82b12c1 --- /dev/null +++ b/device/noise-types.go @@ -0,0 +1,81 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/subtle" + "encoding/hex" + "errors" + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) + if err != nil { + return err + } + if len(slice) != len(dst) { + return errors.New("hex string does not fit the slice") + } + copy(dst, slice) + return nil +} + +func (key NoisePrivateKey) IsZero() bool { + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePrivateKey) FromHex(src string) (err error) { + err = loadExactHex(key[:], src) + key.clamp() + return +} + +func (key NoisePrivateKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoiseSymmetricKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoiseSymmetricKey) ToHex() string { + return hex.EncodeToString(key[:]) +} diff --git a/device/noise_test.go b/device/noise_test.go new file mode 100644 index 0000000..6ba3f2e --- /dev/null +++ b/device/noise_test.go @@ -0,0 +1,144 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1 := sk1.sharedSecret(pk2) + ss2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 { + t.Fatal("Failed to compute shared secet") + } +} + +func TestNoiseHandshake(t *testing.T) { + dev1 := randDevice(t) + dev2 := randDevice(t) + + defer dev1.Close() + defer dev2.Close() + + peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + + assertEqual( + t, + peer1.handshake.precomputedStaticStatic[:], + peer2.handshake.precomputedStaticStatic[:], + ) + + /* simulate handshake */ + + // initiation message + + t.Log("exchange initiation message") + + msg1, err := dev1.CreateMessageInitiation(peer2) + assertNil(t, err) + + packet := make([]byte, 0, 256) + writer := bytes.NewBuffer(packet) + err = binary.Write(writer, binary.LittleEndian, msg1) + assertNil(t, err) + peer := dev2.ConsumeMessageInitiation(msg1) + if peer == nil { + t.Fatal("handshake failed at initiation message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // response message + + t.Log("exchange response message") + + msg2, err := dev2.CreateMessageResponse(peer1) + assertNil(t, err) + + peer = dev1.ConsumeMessageResponse(msg2) + if peer == nil { + t.Fatal("handshake failed at response message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // key pairs + + t.Log("deriving keys") + + err = peer1.BeginSymmetricSession() + if err != nil { + t.Fatal("failed to derive keypair for peer 1", err) + } + + err = peer2.BeginSymmetricSession() + if err != nil { + t.Fatal("failed to derive keypair for peer 2", err) + } + + key1 := peer1.keypairs.next + key2 := peer2.keypairs.current + + // encrypting / decryption test + + t.Log("test key pairs") + + func() { + testMsg := []byte("wireguard test message 1") + var err error + var out []byte + var nonce [12]byte + out = key1.send.Seal(out, nonce[:], testMsg, nil) + out, err = key2.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() + + func() { + testMsg := []byte("wireguard test message 2") + var err error + var out []byte + var nonce [12]byte + out = key2.send.Seal(out, nonce[:], testMsg, nil) + out, err = key1.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() +} diff --git a/device/peer.go b/device/peer.go new file mode 100644 index 0000000..af3ef9d --- /dev/null +++ b/device/peer.go @@ -0,0 +1,270 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "encoding/base64" + "errors" + "fmt" + "sync" + "time" +) + +const ( + PeerRoutineNumber = 3 +) + +type Peer struct { + isRunning AtomicBool + sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer + keypairs Keypairs + handshake Handshake + device *Device + endpoint Endpoint + persistentKeepaliveInterval uint16 + + // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly + stats struct { + txBytes uint64 // bytes send to peer (endpoint) + rxBytes uint64 // bytes received from peer + lastHandshakeNano int64 // nano seconds since epoch + } + + timers struct { + retransmitHandshake *Timer + sendKeepalive *Timer + newHandshake *Timer + zeroKeyMaterial *Timer + persistentKeepalive *Timer + handshakeAttempts uint32 + needAnotherKeepalive AtomicBool + sentLastMinuteHandshake AtomicBool + } + + signals struct { + newKeypairArrived chan struct{} + flushNonceQueue chan struct{} + } + + queue struct { + nonce chan *QueueOutboundElement // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + inbound chan *QueueInboundElement // sequential ordering of work + packetInNonceQueueIsAwaitingKey AtomicBool + } + + routines struct { + sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop chan struct{} // size 0, stop all go routines in peer + } + + cookieGenerator CookieGenerator +} + +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("device closed") + } + + // lock resources + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + device.peers.Lock() + defer device.peers.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("too many peers") + } + + // create peer + + peer := new(Peer) + peer.Lock() + defer peer.Unlock() + + peer.cookieGenerator.Init(pk) + peer.device = device + peer.isRunning.Set(false) + + // map public key + + _, ok := device.peers.keyMap[pk] + if ok { + return nil, errors.New("adding existing peer") + } + device.peers.keyMap[pk] = peer + + // pre-compute DH + + handshake := &peer.handshake + handshake.mutex.Lock() + handshake.remoteStatic = pk + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.mutex.Unlock() + + // reset endpoint + + peer.endpoint = nil + + // start peer + + if peer.device.isUp.Get() { + peer.Start() + } + + return peer, nil +} + +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.RLock() + defer peer.device.net.RUnlock() + + if peer.device.net.bind == nil { + return errors.New("no bind") + } + + peer.RLock() + defer peer.RUnlock() + + if peer.endpoint == nil { + return errors.New("no known endpoint for peer") + } + + return peer.device.net.bind.Send(buffer, peer.endpoint) +} + +func (peer *Peer) String() string { + base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + abbreviatedKey := "invalid" + if len(base64Key) == 44 { + abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + } + return fmt.Sprintf("peer(%s)", abbreviatedKey) +} + +func (peer *Peer) Start() { + + // should never start a peer on a closed device + + if peer.device.isClosed.Get() { + return + } + + // prevent simultaneous start/stop operations + + peer.routines.Lock() + defer peer.routines.Unlock() + + if peer.isRunning.Get() { + return + } + + device := peer.device + device.log.Debug.Println(peer, "- Starting...") + + // reset routine state + + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + peer.routines.stop = make(chan struct{}) + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) + + // prepare queues + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + + peer.timersInit() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.signals.newKeypairArrived = make(chan struct{}, 1) + peer.signals.flushNonceQueue = make(chan struct{}, 1) + + // wait for routines to start + + go peer.RoutineNonce() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + peer.routines.starting.Wait() + peer.isRunning.Set(true) +} + +func (peer *Peer) ZeroAndFlushAll() { + device := peer.device + + // clear key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + device.DeleteKeypair(keypairs.previous) + device.DeleteKeypair(keypairs.current) + device.DeleteKeypair(keypairs.next) + keypairs.previous = nil + keypairs.current = nil + keypairs.next = nil + keypairs.Unlock() + + // clear handshake state + + handshake := &peer.handshake + handshake.mutex.Lock() + device.indexTable.Delete(handshake.localIndex) + handshake.Clear() + handshake.mutex.Unlock() + + peer.FlushNonceQueue() +} + +func (peer *Peer) Stop() { + + // prevent simultaneous start/stop operations + + if !peer.isRunning.Swap(false) { + return + } + + peer.routines.starting.Wait() + + peer.routines.Lock() + defer peer.routines.Unlock() + + peer.device.log.Debug.Println(peer, "- Stopping...") + + peer.timersStop() + + // stop & wait for ongoing peer routines + + close(peer.routines.stop) + peer.routines.stopping.Wait() + + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + + peer.ZeroAndFlushAll() +} + +var roamingDisabled bool + +func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { + if roamingDisabled { + return + } + peer.Lock() + peer.endpoint = endpoint + peer.Unlock() +} diff --git a/device/pools.go b/device/pools.go new file mode 100644 index 0000000..98f4ef1 --- /dev/null +++ b/device/pools.go @@ -0,0 +1,89 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import "sync" + +func (device *Device) PopulatePools() { + if PreallocatedBuffersPerPool == 0 { + device.pool.messageBufferPool = &sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + device.pool.inboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueInboundElement) + }, + } + device.pool.outboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueOutboundElement) + }, + } + } else { + device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) + } + device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.inboundElementReuseChan <- new(QueueInboundElement) + } + device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.outboundElementReuseChan <- new(QueueOutboundElement) + } + } +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + if PreallocatedBuffersPerPool == 0 { + return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) + } else { + return <-device.pool.messageBufferReuseChan + } +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + if PreallocatedBuffersPerPool == 0 { + device.pool.messageBufferPool.Put(msg) + } else { + device.pool.messageBufferReuseChan <- msg + } +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + if PreallocatedBuffersPerPool == 0 { + return device.pool.inboundElementPool.Get().(*QueueInboundElement) + } else { + return <-device.pool.inboundElementReuseChan + } +} + +func (device *Device) PutInboundElement(msg *QueueInboundElement) { + if PreallocatedBuffersPerPool == 0 { + device.pool.inboundElementPool.Put(msg) + } else { + device.pool.inboundElementReuseChan <- msg + } +} + +func (device *Device) GetOutboundElement() *QueueOutboundElement { + if PreallocatedBuffersPerPool == 0 { + return device.pool.outboundElementPool.Get().(*QueueOutboundElement) + } else { + return <-device.pool.outboundElementReuseChan + } +} + +func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { + if PreallocatedBuffersPerPool == 0 { + device.pool.outboundElementPool.Put(msg) + } else { + device.pool.outboundElementReuseChan <- msg + } +} diff --git a/device/queueconstants.go b/device/queueconstants.go new file mode 100644 index 0000000..3e94b7f --- /dev/null +++ b/device/queueconstants.go @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +/* Implementation specific constants */ + +const ( + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/device/receive.go b/device/receive.go new file mode 100644 index 0000000..5c837c1 --- /dev/null +++ b/device/receive.go @@ -0,0 +1,641 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +type QueueHandshakeElement struct { + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte +} + +type QueueInboundElement struct { + dropped int32 + sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keypair *Keypair + endpoint Endpoint +} + +func (elem *QueueInboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { + select { + case inboundQueue <- element: + select { + case decryptionQueue <- element: + return true + default: + element.Drop() + element.Unlock() + return false + } + default: + device.PutInboundElement(element) + return false + } +} + +func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { + select { + case queue <- element: + return true + default: + return false + } +} + +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) keepKeyFreshReceiving() { + if peer.timers.sentLastMinuteHandshake.Get() { + return + } + keypair := peer.keypairs.Current() + if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + peer.timers.sentLastMinuteHandshake.Set(true) + peer.SendHandshakeInitiation(false) + } +} + +/* Receives incoming datagrams for the device + * + * Every time the bind is updated a new routine is started for + * IPv4 and IPv6 (separately) + */ +func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.net.stopping.Done() + }() + + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") + device.net.starting.Done() + + // receive datagrams until conn is closed + + buffer := device.GetMessageBuffer() + + var ( + err error + size int + endpoint Endpoint + ) + + for { + + // read next datagram + + switch IP { + case ipv4.Version: + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + case ipv6.Version: + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + default: + panic("invalid IP version") + } + + if err != nil { + device.PutMessageBuffer(buffer) + return + } + + if size < MinMessageSize { + continue + } + + // check size of packet + + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + var okay bool + + switch msgType { + + // check if transport + + case MessageTransportType: + + // check size + + if len(packet) < MessageTransportSize { + continue + } + + // lookup key pair + + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } + + // check keypair expiry + + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } + + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffer + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.endpoint = endpoint + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + // add to decryption queues + + if peer.isRunning.Get() { + if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { + buffer = device.GetMessageBuffer() + } + } + + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + + default: + logDebug.Println("Received message with unknown type") + } + + if okay { + if (device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + )) { + buffer = device.GetMessageBuffer() + } + } + } +} + +func (device *Device) RoutineDecryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: decryption worker - stopped") + device.state.stopping.Done() + }() + logDebug.Println("Routine: decryption worker - started") + device.state.starting.Done() + + for { + select { + case <-device.signals.stop: + return + + case elem, ok := <-device.queue.decryption: + + if !ok { + return + } + + // check if dropped + + if elem.IsDropped() { + continue + } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // expand nonce + + nonce[0x4] = counter[0x0] + nonce[0x5] = counter[0x1] + nonce[0x6] = counter[0x2] + nonce[0x7] = counter[0x3] + + nonce[0x8] = counter[0x4] + nonce[0x9] = counter[0x5] + nonce[0xa] = counter[0x6] + nonce[0xb] = counter[0x7] + + // decrypt and release to consumer + + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + } + elem.Unlock() + } + } +} + +/* Handles incoming packets related to handshake + */ +func (device *Device) RoutineHandshake() { + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem QueueHandshakeElement + var ok bool + + defer func() { + logDebug.Println("Routine: handshake worker - stopped") + device.state.stopping.Done() + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + } + }() + + logDebug.Println("Routine: handshake worker - started") + device.state.starting.Done() + + for { + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + elem.buffer = nil + } + + select { + case elem, ok = <-device.queue.handshake: + case <-device.signals.stop: + return + } + + if !ok { + return + } + + // handle cookie fields and ratelimiting + + switch elem.msgType { + + case MessageCookieReplyType: + + // unmarshal packet + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") + return + } + + // lookup peer from index + + entry := device.indexTable.Lookup(reply.Receiver) + + if entry.peer == nil { + continue + } + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if !peer.cookieGenerator.ConsumeReply(&reply) { + logDebug.Println("Could not decrypt invalid cookie response") + } + } + + continue + + case MessageInitiationType, MessageResponseType: + + // check mac fields and maybe ratelimit + + if !device.cookieChecker.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") + continue + } + + // endpoints destination address is the source of the datagram + + if device.IsUnderLoad() { + + // verify MAC2 field + + if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { + device.SendHandshakeCookie(&elem) + continue + } + + // check ratelimiter + + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { + continue + } + } + + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } + + // handle handshake initiation/response content + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Received invalid initiation message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake initiation") + + peer.SendHandshakeResponse() + + case MessageResponseType: + + // unmarshal + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Received invalid response message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake response") + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // derive keypair + + err = peer.BeginSymmetricSession() + + if err != nil { + logError.Println(peer, "- Failed to derive keypair:", err) + continue + } + + peer.timersSessionDerived() + peer.timersHandshakeComplete() + peer.SendKeepalive() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + } +} + +func (peer *Peer) RoutineSequentialReceiver() { + + device := peer.device + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem *QueueInboundElement + var ok bool + + defer func() { + logDebug.Println(peer, "- Routine: sequential receiver - stopped") + peer.routines.stopping.Done() + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + } + }() + + logDebug.Println(peer, "- Routine: sequential receiver - started") + + peer.routines.starting.Done() + + for { + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + elem = nil + } + + select { + + case <-peer.routines.stop: + return + + case elem, ok = <-peer.queue.inbound: + + if !ok { + return + } + + // wait for decryption + + elem.Lock() + + if elem.IsDropped() { + continue + } + + // check for replay + + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // check for keepalive + + if len(elem.packet) == 0 { + logDebug.Println(peer, "- Receiving keepalive packet") + continue + } + peer.timersDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue + } + + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv4 source + + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer, + ) + continue + } + + case ipv6.Version: + + // strip padding + + if len(elem.packet) < ipv6.HeaderLen { + continue + } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.LookupIPv6(src) != peer { + logInfo.Println( + peer, + "sent packet with disallowed IPv6 source", + ) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer) + continue + } + + // write to tun device + + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } + } + } +} diff --git a/device/send.go b/device/send.go new file mode 100644 index 0000000..b4e23c7 --- /dev/null +++ b/device/send.go @@ -0,0 +1,618 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "sync" + "sync/atomic" + "time" +) + +/* Outbound flow + * + * 1. TUN queue + * 2. Routing (sequential) + * 3. Nonce assignment (sequential) + * 4. Encryption (parallel) + * 5. Transmission (sequential) + * + * The functions in this file occur (roughly) in the order in + * which the packets are processed. + * + * Locking, Producers and Consumers + * + * The order of packets (per peer) must be maintained, + * but encryption of packets happen out-of-order: + * + * The sequential consumers will attempt to take the lock, + * workers release lock when they have completed work (encryption) on the packet. + * + * If the element is inserted into the "encryption queue", + * the content is preceded by enough "junk" to contain the transport header + * (to allow the construction of transport messages in-place) + */ + +type QueueOutboundElement struct { + dropped int32 + sync.Mutex + buffer *[MaxMessageSize]byte // slice holding the packet data + packet []byte // slice of "buffer" (always!) + nonce uint64 // nonce for encryption + keypair *Keypair // keypair for encryption + peer *Peer // related peer +} + +func (device *Device) NewOutboundElement() *QueueOutboundElement { + elem := device.GetOutboundElement() + elem.dropped = AtomicFalse + elem.buffer = device.GetMessageBuffer() + elem.Mutex = sync.Mutex{} + elem.nonce = 0 + elem.keypair = nil + elem.peer = nil + return elem +} + +func (elem *QueueOutboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueOutboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + device.PutMessageBuffer(old.buffer) + device.PutOutboundElement(old) + default: + } + } + } +} + +func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { + select { + case outboundQueue <- element: + select { + case encryptionQueue <- element: + return + default: + element.Drop() + element.peer.device.PutMessageBuffer(element.buffer) + element.Unlock() + } + default: + element.peer.device.PutMessageBuffer(element.buffer) + element.peer.device.PutOutboundElement(element) + } +} + +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() bool { + if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { + return false + } + elem := peer.device.NewOutboundElement() + elem.packet = nil + select { + case peer.queue.nonce <- elem: + peer.device.log.Debug.Println(peer, "- Sending keepalive packet") + return true + default: + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + return false + } +} + +func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { + if !isRetry { + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + } + + peer.handshake.mutex.RLock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.RUnlock() + return nil + } + peer.handshake.mutex.RUnlock() + + peer.handshake.mutex.Lock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.Unlock() + return nil + } + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() + + peer.device.log.Debug.Println(peer, "- Sending handshake initiation") + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) + return err + } + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.cookieGenerator.AddMacs(packet) + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) + } + peer.timersHandshakeInitiated() + + return err +} + +func (peer *Peer) SendHandshakeResponse() error { + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() + + peer.device.log.Debug.Println(peer, "- Sending handshake response") + + response, err := peer.device.CreateMessageResponse(peer) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to create response message:", err) + return err + } + + var buff [MessageResponseSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.cookieGenerator.AddMacs(packet) + + err = peer.BeginSymmetricSession() + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) + return err + } + + peer.timersSessionDerived() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) + } + return err +} + +func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { + + device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) + + sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) + reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) + if err != nil { + device.log.Error.Println("Failed to create cookie reply:", err) + return err + } + + var buff [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, reply) + device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + if err != nil { + device.log.Error.Println("Failed to send cookie reply:", err) + } + return err +} + +func (peer *Peer) keepKeyFreshSending() { + keypair := peer.keypairs.Current() + if keypair == nil { + return + } + nonce := atomic.LoadUint64(&keypair.sendNonce) + if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) { + peer.SendHandshakeInitiation(false) + } +} + +/* Reads packets from the TUN and inserts + * into nonce queue for peer + * + * Obs. Single instance per TUN device + */ +func (device *Device) RoutineReadFromTUN() { + + logDebug := device.log.Debug + logError := device.log.Error + + defer func() { + logDebug.Println("Routine: TUN reader - stopped") + device.state.stopping.Done() + }() + + logDebug.Println("Routine: TUN reader - started") + device.state.starting.Done() + + var elem *QueueOutboundElement + + for { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + elem = device.NewOutboundElement() + + // read packet + + offset := MessageTransportHeaderSize + size, err := device.tun.device.Read(elem.buffer[:], offset) + + if err != nil { + if !device.isClosed.Get() { + logError.Println("Failed to read packet from TUN device:", err) + device.Close() + } + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + return + } + + if size == 0 || size > MaxContentSize { + continue + } + + elem.packet = elem.buffer[offset : offset+size] + + // lookup peer + + var peer *Peer + switch elem.packet[0] >> 4 { + case ipv4.Version: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.LookupIPv4(dst) + + case ipv6.Version: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.LookupIPv6(dst) + + default: + logDebug.Println("Received packet with unknown IP version") + } + + if peer == nil { + continue + } + + // insert into nonce/pre-handshake queue + + if peer.isRunning.Get() { + if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { + peer.SendHandshakeInitiation(false) + } + addToNonceQueue(peer.queue.nonce, elem, device) + elem = nil + } + } +} + +func (peer *Peer) FlushNonceQueue() { + select { + case peer.signals.flushNonceQueue <- struct{}{}: + default: + } +} + +/* Queues packets when there is no handshake. + * Then assigns nonces to packets sequentially + * and creates "work" structs for workers + * + * Obs. A single instance per peer + */ +func (peer *Peer) RoutineNonce() { + var keypair *Keypair + + device := peer.device + logDebug := device.log.Debug + + flush := func() { + for { + select { + case elem := <-peer.queue.nonce: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + default: + return + } + } + } + + defer func() { + flush() + logDebug.Println(peer, "- Routine: nonce worker - stopped") + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + peer.routines.stopping.Done() + }() + + peer.routines.starting.Done() + logDebug.Println(peer, "- Routine: nonce worker - started") + + for { + NextPacket: + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + + select { + case <-peer.routines.stop: + return + + case <-peer.signals.flushNonceQueue: + flush() + goto NextPacket + + case elem, ok := <-peer.queue.nonce: + + if !ok { + return + } + + // make sure to always pick the newest key + + for { + + // check validity of newest key pair + + keypair = peer.keypairs.Current() + if keypair != nil && keypair.sendNonce < RejectAfterMessages { + if time.Now().Sub(keypair.created) < RejectAfterTime { + break + } + } + peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) + + // no suitable key pair, request for new handshake + + select { + case <-peer.signals.newKeypairArrived: + default: + } + + peer.SendHandshakeInitiation(false) + + // wait for key to be established + + logDebug.Println(peer, "- Awaiting keypair") + + select { + case <-peer.signals.newKeypairArrived: + logDebug.Println(peer, "- Obtained awaited keypair") + + case <-peer.signals.flushNonceQueue: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + flush() + goto NextPacket + + case <-peer.routines.stop: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + return + } + } + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + + // populate work element + + elem.peer = peer + elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 + + // double check in case of race condition added by future code + + if elem.nonce >= RejectAfterMessages { + atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + goto NextPacket + } + + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.Lock() + + // add to parallel and sequential queue + addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) + } + } +} + +/* Encrypts the elements in the queue + * and marks them for sequential consumption (by releasing the mutex) + * + * Obs. One instance per core + */ +func (device *Device) RoutineEncryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + + defer func() { + for { + select { + case elem, ok := <-device.queue.encryption: + if ok && !elem.IsDropped() { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + elem.Unlock() + } + default: + goto out + } + } + out: + logDebug.Println("Routine: encryption worker - stopped") + device.state.stopping.Done() + }() + + logDebug.Println("Routine: encryption worker - started") + device.state.starting.Done() + + for { + + // fetch next element + + select { + case <-device.signals.stop: + return + + case elem, ok := <-device.queue.encryption: + + if !ok { + return + } + + // check if dropped + + if elem.IsDropped() { + continue + } + + // populate header fields + + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + + mtu := int(atomic.LoadInt32(&device.tun.mtu)) + lastUnit := len(elem.packet) % mtu + paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) + if paddedSize > mtu { + paddedSize = mtu + } + for i := len(elem.packet); i < paddedSize; i++ { + elem.packet = append(elem.packet, 0) + } + + // encrypt content and release to consumer + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.Unlock() + } + } +} + +/* Sequentially reads packets from queue and sends to endpoint + * + * Obs. Single instance per peer. + * The routine terminates then the outbound queue is closed. + */ +func (peer *Peer) RoutineSequentialSender() { + + device := peer.device + + logDebug := device.log.Debug + logError := device.log.Error + + defer func() { + for { + select { + case elem, ok := <-peer.queue.outbound: + if ok { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + elem.Drop() + } + device.PutOutboundElement(elem) + } + default: + goto out + } + } + out: + logDebug.Println(peer, "- Routine: sequential sender - stopped") + peer.routines.stopping.Done() + }() + + logDebug.Println(peer, "- Routine: sequential sender - started") + + peer.routines.starting.Done() + + for { + select { + + case <-peer.routines.stop: + return + + case elem, ok := <-peer.queue.outbound: + + if !ok { + return + } + + elem.Lock() + if elem.IsDropped() { + device.PutOutboundElement(elem) + continue + } + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + // send message and return buffer to pool + + length := uint64(len(elem.packet)) + err := peer.SendBuffer(elem.packet) + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + if err != nil { + logError.Println(peer, "- Failed to send data packet", err) + continue + } + atomic.AddUint64(&peer.stats.txBytes, length) + + if len(elem.packet) != MessageKeepaliveSize { + peer.timersDataSent() + } + peer.keepKeyFreshSending() + } + } +} diff --git a/device/timers.go b/device/timers.go new file mode 100644 index 0000000..5f28fcc --- /dev/null +++ b/device/timers.go @@ -0,0 +1,227 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This is based heavily on timers.c from the kernel implementation. + */ + +package device + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" +) + +/* This Timer structure and related functions should roughly copy the interface of + * the Linux kernel's struct timer_list. + */ + +type Timer struct { + *time.Timer + modifyingLock sync.RWMutex + runningLock sync.Mutex + isPending bool +} + +func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { + timer := &Timer{} + timer.Timer = time.AfterFunc(time.Hour, func() { + timer.runningLock.Lock() + + timer.modifyingLock.Lock() + if !timer.isPending { + timer.modifyingLock.Unlock() + timer.runningLock.Unlock() + return + } + timer.isPending = false + timer.modifyingLock.Unlock() + + expirationFunction(peer) + timer.runningLock.Unlock() + }) + timer.Stop() + return timer +} + +func (timer *Timer) Mod(d time.Duration) { + timer.modifyingLock.Lock() + timer.isPending = true + timer.Reset(d) + timer.modifyingLock.Unlock() +} + +func (timer *Timer) Del() { + timer.modifyingLock.Lock() + timer.isPending = false + timer.Stop() + timer.modifyingLock.Unlock() +} + +func (timer *Timer) DelSync() { + timer.Del() + timer.runningLock.Lock() + timer.Del() + timer.runningLock.Unlock() +} + +func (timer *Timer) IsPending() bool { + timer.modifyingLock.RLock() + defer timer.modifyingLock.RUnlock() + return timer.isPending +} + +func (peer *Peer) timersActive() bool { + return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 +} + +func expiredRetransmitHandshake(peer *Peer) { + if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } + + /* We drop all packets without a keypair and don't try again, + * if we try unsuccessfully for too long to make a handshake. + */ + peer.FlushNonceQueue() + + /* We set a timer for destroying any residue that might be left + * of a partial exchange. + */ + if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } + } else { + atomic.AddUint32(&peer.timers.handshakeAttempts, 1) + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + + peer.SendHandshakeInitiation(true) + } +} + +func expiredSendKeepalive(peer *Peer) { + peer.SendKeepalive() + if peer.timers.needAnotherKeepalive.Get() { + peer.timers.needAnotherKeepalive.Set(false) + if peer.timersActive() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } + } +} + +func expiredNewHandshake(peer *Peer) { + peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + peer.SendHandshakeInitiation(false) + +} + +func expiredZeroKeyMaterial(peer *Peer) { + peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.ZeroAndFlushAll() +} + +func expiredPersistentKeepalive(peer *Peer) { + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } +} + +/* Should be called after an authenticated data packet is sent. */ +func (peer *Peer) timersDataSent() { + if peer.timersActive() && !peer.timers.newHandshake.IsPending() { + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) + } +} + +/* Should be called after an authenticated data packet is received. */ +func (peer *Peer) timersDataReceived() { + if peer.timersActive() { + if !peer.timers.sendKeepalive.IsPending() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } else { + peer.timers.needAnotherKeepalive.Set(true) + } + } +} + +/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketSent() { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } +} + +/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketReceived() { + if peer.timersActive() { + peer.timers.newHandshake.Del() + } +} + +/* Should be called after a handshake initiation message is sent. */ +func (peer *Peer) timersHandshakeInitiated() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + } +} + +/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ +func (peer *Peer) timersHandshakeComplete() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Del() + } + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) +} + +/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ +func (peer *Peer) timersSessionDerived() { + if peer.timersActive() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } +} + +/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ +func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { + if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + } +} + +func (peer *Peer) timersInit() { + peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) + peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) + peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) + peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) + peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStop() { + peer.timers.retransmitHandshake.DelSync() + peer.timers.sendKeepalive.DelSync() + peer.timers.newHandshake.DelSync() + peer.timers.zeroKeyMaterial.DelSync() + peer.timers.persistentKeepalive.DelSync() +} diff --git a/device/tun.go b/device/tun.go new file mode 100644 index 0000000..bc5f1f1 --- /dev/null +++ b/device/tun.go @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.zx2c4.com/wireguard/tun" + "sync/atomic" +) + +const DefaultMTU = 1420 + +func (device *Device) RoutineTUNEventReader() { + setUp := false + logDebug := device.log.Debug + logInfo := device.log.Info + logError := device.log.Error + + logDebug.Println("Routine: event worker - started") + device.state.starting.Done() + + for event := range device.tun.device.Events() { + if event&tun.TUNEventMTUUpdate != 0 { + mtu, err := device.tun.device.MTU() + old := atomic.LoadInt32(&device.tun.mtu) + if err != nil { + logError.Println("Failed to load updated MTU of device:", err) + } else if int(old) != mtu { + if mtu+MessageTransportSize > MaxMessageSize { + logInfo.Println("MTU updated:", mtu, "(too large)") + } else { + logInfo.Println("MTU updated:", mtu) + } + atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + } + } + + if event&tun.TUNEventUp != 0 && !setUp { + logInfo.Println("Interface set up") + setUp = true + device.Up() + } + + if event&tun.TUNEventDown != 0 && setUp { + logInfo.Println("Interface set down") + setUp = false + device.Down() + } + } + + logDebug.Println("Routine: event worker - stopped") + device.state.stopping.Done() +} diff --git a/device/uapi.go b/device/uapi.go new file mode 100644 index 0000000..5c65917 --- /dev/null +++ b/device/uapi.go @@ -0,0 +1,426 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bufio" + "fmt" + "golang.zx2c4.com/wireguard/ipc" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" +) + +type IPCError struct { + int64 +} + +func (s *IPCError) Error() string { + return fmt.Sprintf("IPC error: %d", s.int64) +} + +func (s *IPCError) ErrorCode() int64 { + return s.int64 +} + +func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { + + device.log.Debug.Println("UAPI: Processing get operation") + + // create lines + + lines := make([]string, 0, 100) + send := func(line string) { + lines = append(lines, line) + } + + func() { + + // lock required resources + + device.net.RLock() + defer device.net.RUnlock() + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + device.peers.RLock() + defer device.peers.RUnlock() + + // serialize device related values + + if !device.staticIdentity.privateKey.IsZero() { + send("private_key=" + device.staticIdentity.privateKey.ToHex()) + } + + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) + } + + if device.net.fwmark != 0 { + send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + } + + // serialize each peer state + + for _, peer := range device.peers.keyMap { + peer.RLock() + defer peer.RUnlock() + + send("public_key=" + peer.handshake.remoteStatic.ToHex()) + send("preshared_key=" + peer.handshake.presharedKey.ToHex()) + send("protocol_version=1") + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) + } + + nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) + send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) + send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) + send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) + send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + + for _, ip := range device.allowedips.EntriesForPeer(peer) { + send("allowed_ip=" + ip.String()) + } + + } + }() + + // send lines (does not require resource locks) + + for _, line := range lines { + _, err := socket.WriteString(line + "\n") + if err != nil { + return &IPCError{ipc.IpcErrorIO} + } + } + + return nil +} + +func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { + scanner := bufio.NewScanner(socket) + logError := device.log.Error + logDebug := device.log.Debug + + var peer *Peer + + dummy := false + deviceConfig := true + + for scanner.Scan() { + + // parse line + + line := scanner.Text() + if line == "" { + return nil + } + parts := strings.Split(line, "=") + if len(parts) != 2 { + return &IPCError{ipc.IpcErrorProtocol} + } + key := parts[0] + value := parts[1] + + /* device configuration */ + + if deviceConfig { + + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromHex(value) + if err != nil { + logError.Println("Failed to set private_key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println("UAPI: Updating private key") + device.SetPrivateKey(sk) + + case "listen_port": + + // parse port number + + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to parse listen_port:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + // update port and rebind + + logDebug.Println("UAPI: Updating listen port") + + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() + + if err := device.BindUpdate(); err != nil { + logError.Println("Failed to set listen_port:", err) + return &IPCError{ipc.IpcErrorPortInUse} + } + + case "fwmark": + + // parse fwmark field + + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() + + if err != nil { + logError.Println("Invalid fwmark", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + logDebug.Println("UAPI: Updating fwmark") + + if err := device.BindSetMark(uint32(fwmark)); err != nil { + logError.Println("Failed to update fwmark:", err) + return &IPCError{ipc.IpcErrorPortInUse} + } + + case "public_key": + // switch to peer configuration + logDebug.Println("UAPI: Transition to peer configuration") + deviceConfig = false + + case "replace_peers": + if value != "true" { + logError.Println("Failed to set replace_peers, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println("UAPI: Removing all peers") + device.RemoveAllPeers() + + default: + logError.Println("Invalid UAPI device key:", key) + return &IPCError{ipc.IpcErrorInvalid} + } + } + + /* peer configuration */ + + if !deviceConfig { + + switch key { + + case "public_key": + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + logError.Println("Failed to get peer by public key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + // ignore peer with public key of device + + device.staticIdentity.RLock() + dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() + + if dummy { + peer = &Peer{} + } else { + peer = device.LookupPeer(publicKey) + } + + if peer == nil { + peer, err = device.NewPeer(publicKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println(peer, "- UAPI: Created") + } + + case "remove": + + // remove currently selected peer from device + + if value != "true" { + logError.Println("Failed to set remove, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + if !dummy { + logDebug.Println(peer, "- UAPI: Removing") + device.RemovePeer(peer.handshake.remoteStatic) + } + peer = &Peer{} + dummy = true + + case "preshared_key": + + // update PSK + + logDebug.Println(peer, "- UAPI: Updating preshared key") + + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() + + if err != nil { + logError.Println("Failed to set preshared key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + case "endpoint": + + // set endpoint destination + + logDebug.Println(peer, "- UAPI: Updating endpoint") + + err := func() error { + peer.Lock() + defer peer.Unlock() + endpoint, err := CreateEndpoint(value) + if err != nil { + return err + } + peer.endpoint = endpoint + return nil + }() + + if err != nil { + logError.Println("Failed to set endpoint:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + case "persistent_keepalive_interval": + + // update persistent keepalive interval + + logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to set persistent keepalive interval:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + old := peer.persistentKeepaliveInterval + peer.persistentKeepaliveInterval = uint16(secs) + + // send immediate keepalive if we're turning it on and before it wasn't on + + if old == 0 && secs != 0 { + if err != nil { + logError.Println("Failed to get tun device status:", err) + return &IPCError{ipc.IpcErrorIO} + } + if device.isUp.Get() && !dummy { + peer.SendKeepalive() + } + } + + case "replace_allowed_ips": + + logDebug.Println(peer, "- UAPI: Removing all allowedips") + + if value != "true" { + logError.Println("Failed to replace allowedips, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + if dummy { + continue + } + + device.allowedips.RemoveByPeer(peer) + + case "allowed_ip": + + logDebug.Println(peer, "- UAPI: Adding allowedip") + + _, network, err := net.ParseCIDR(value) + if err != nil { + logError.Println("Failed to set allowed ip:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + if dummy { + continue + } + + ones, _ := network.Mask.Size() + device.allowedips.Insert(network.IP, uint(ones), peer) + + case "protocol_version": + + if value != "1" { + logError.Println("Invalid protocol version:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + default: + logError.Println("Invalid UAPI peer key:", key) + return &IPCError{ipc.IpcErrorInvalid} + } + } + } + + return nil +} + +func (device *Device) IpcHandle(socket net.Conn) { + + // create buffered read/writer + + defer socket.Close() + + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) + + defer buffered.Flush() + + op, err := buffered.ReadString('\n') + if err != nil { + return + } + + // handle operation + + var status *IPCError + + switch op { + case "set=1\n": + device.log.Debug.Println("UAPI: Set operation") + status = device.IpcSetOperation(buffered.Reader) + + case "get=1\n": + device.log.Debug.Println("UAPI: Get operation") + status = device.IpcGetOperation(buffered.Writer) + + default: + device.log.Error.Println("Invalid UAPI operation:", op) + return + } + + // write status + + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } +} diff --git a/device/version.go b/device/version.go new file mode 100644 index 0000000..9077cdc --- /dev/null +++ b/device/version.go @@ -0,0 +1,3 @@ +package device + +const WireGuardGoVersion = "0.0.20181222" diff --git a/device_test.go b/device_test.go deleted file mode 100644 index df0ba69..0000000 --- a/device_test.go +++ /dev/null @@ -1,48 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -/* Create two device instances and simulate full WireGuard interaction - * without network dependencies - */ - -import "testing" - -func TestDevice(t *testing.T) { - - // prepare tun devices for generating traffic - - tun1, err := CreateDummyTUN("tun1") - if err != nil { - t.Error("failed to create tun:", err.Error()) - } - - tun2, err := CreateDummyTUN("tun2") - if err != nil { - t.Error("failed to create tun:", err.Error()) - } - - _ = tun1 - _ = tun2 - - // prepare endpoints - - end1, err := CreateDummyEndpoint() - if err != nil { - t.Error("failed to create endpoint:", err.Error()) - } - - end2, err := CreateDummyEndpoint() - if err != nil { - t.Error("failed to create endpoint:", err.Error()) - } - - _ = end1 - _ = end2 - - // create binds - -} diff --git a/endpoint_test.go b/endpoint_test.go deleted file mode 100644 index fe6677c..0000000 --- a/endpoint_test.go +++ /dev/null @@ -1,53 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "math/rand" - "net" -) - -type DummyEndpoint struct { - src [16]byte - dst [16]byte -} - -func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { - return nil, err - } - _, err := rand.Read(end.dst[:]) - return &end, err -} - -func (e *DummyEndpoint) ClearSrc() {} - -func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() -} - -func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() -} - -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] -} - -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] -} - -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] -} diff --git a/go.mod b/go.mod index cfff5b6..49076a6 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module golang.zx2c4.com/wireguard require ( github.com/Microsoft/go-winio v0.4.11 - golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 - golang.org/x/net v0.0.0-20190213061140-3a22650c66bd - golang.org/x/sys v0.0.0-20190213121743-983097b1a8a3 + golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 + golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 + golang.org/x/sys v0.0.0-20190302025703-b6889370fb10 ) diff --git a/go.sum b/go.sum index c1adf80..76c5f08 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,9 @@ github.com/Microsoft/go-winio v0.4.11 h1:zoIOcVf0xPN1tnMVbTtEdI+P8OofVk3NObnwOQ6nK2Q= github.com/Microsoft/go-winio v0.4.11/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= -golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE= -golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd h1:HuTn7WObtcDo9uEEU7rEqL0jYthdXAmZ6PP+meazmaU= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/sys v0.0.0-20190213121743-983097b1a8a3 h1:+KlxhGbYkFs8lMfwKn+2ojry1ID5eBSMXprS2u/wqCE= -golang.org/x/sys v0.0.0-20190213121743-983097b1a8a3/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25 h1:jsG6UpNLt9iAsb0S2AGW28DveNzzgmbXR+ENoPjUeIU= +golang.org/x/crypto v0.0.0-20190228161510-8dd112bcdc25/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95 h1:fY7Dsw114eJN4boqzVSbpVHO6rTdhq6/GnXeu+PKnzU= +golang.org/x/net v0.0.0-20190301231341-16b79f2e4e95/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190302025703-b6889370fb10 h1:xQJI9OEiErEQ++DoXOHqEpzsGMrAv2Q2jyCpi7DmfpQ= +golang.org/x/sys v0.0.0-20190302025703-b6889370fb10/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/helper_test.go b/helper_test.go deleted file mode 100644 index 3705c97..0000000 --- a/helper_test.go +++ /dev/null @@ -1,92 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "bytes" - "errors" - "golang.zx2c4.com/wireguard/tun" - "os" - "testing" -) - -/* Helpers for writing unit tests - */ - -type DummyTUN struct { - name string - mtu int - packets chan []byte - events chan tun.TUNEvent -} - -func (tun *DummyTUN) File() *os.File { - return nil -} - -func (tun *DummyTUN) Name() (string, error) { - return tun.name, nil -} - -func (tun *DummyTUN) MTU() (int, error) { - return tun.mtu, nil -} - -func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { - tun.packets <- d[offset:] - return len(d), nil -} - -func (tun *DummyTUN) Close() error { - close(tun.events) - close(tun.packets) - return nil -} - -func (tun *DummyTUN) Events() chan tun.TUNEvent { - return tun.events -} - -func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { - t, ok := <-tun.packets - if !ok { - return 0, errors.New("device closed") - } - copy(d[offset:], t) - return len(t), nil -} - -func CreateDummyTUN(name string) (tun.TUNDevice, error) { - var dummy DummyTUN - dummy.mtu = 0 - dummy.packets = make(chan []byte, 100) - dummy.events = make(chan tun.TUNEvent, 10) - return &dummy, nil -} - -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -func assertEqual(t *testing.T, a []byte, b []byte) { - if bytes.Compare(a, b) != 0 { - t.Fatal(a, "!=", b) - } -} - -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - tun, _ := CreateDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device -} diff --git a/indextable.go b/indextable.go deleted file mode 100644 index 046113c..0000000 --- a/indextable.go +++ /dev/null @@ -1,97 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "crypto/rand" - "sync" - "unsafe" -) - -type IndexTableEntry struct { - peer *Peer - handshake *Handshake - keypair *Keypair -} - -type IndexTable struct { - sync.RWMutex - table map[uint32]IndexTableEntry -} - -func randUint32() (uint32, error) { - var integer [4]byte - _, err := rand.Read(integer[:]) - return *(*uint32)(unsafe.Pointer(&integer[0])), err -} - -func (table *IndexTable) Init() { - table.Lock() - defer table.Unlock() - table.table = make(map[uint32]IndexTableEntry) -} - -func (table *IndexTable) Delete(index uint32) { - table.Lock() - defer table.Unlock() - delete(table.table, index) -} - -func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { - table.Lock() - defer table.Unlock() - entry, ok := table.table[index] - if !ok { - return - } - table.table[index] = IndexTableEntry{ - peer: entry.peer, - keypair: keypair, - handshake: nil, - } -} - -func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { - for { - // generate random index - - index, err := randUint32() - if err != nil { - return index, err - } - - // check if index used - - table.RLock() - _, ok := table.table[index] - table.RUnlock() - if ok { - continue - } - - // check again while locked - - table.Lock() - _, found := table.table[index] - if found { - table.Unlock() - continue - } - table.table[index] = IndexTableEntry{ - peer: peer, - handshake: handshake, - keypair: nil, - } - table.Unlock() - return index, nil - } -} - -func (table *IndexTable) Lookup(id uint32) IndexTableEntry { - table.RLock() - defer table.RUnlock() - return table.table[id] -} diff --git a/ip.go b/ip.go deleted file mode 100644 index e2e0ff3..0000000 --- a/ip.go +++ /dev/null @@ -1,22 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "net" -) - -const ( - IPv4offsetTotalLength = 2 - IPv4offsetSrc = 12 - IPv4offsetDst = IPv4offsetSrc + net.IPv4len -) - -const ( - IPv6offsetPayloadLength = 4 - IPv6offsetSrc = 8 - IPv6offsetDst = IPv6offsetSrc + net.IPv6len -) diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go new file mode 100644 index 0000000..f66c386 --- /dev/null +++ b/ipc/uapi_bsd.go @@ -0,0 +1,202 @@ +// +build darwin freebsd openbsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +import ( + "errors" + "fmt" + "golang.org/x/sys/unix" + "net" + "os" + "path" + "unsafe" +) + +var socketDirectory = "/var/run/wireguard" + +const ( + IpcErrorIO = -int64(unix.EIO) + IpcErrorProtocol = -int64(unix.EPROTO) + IpcErrorInvalid = -int64(unix.EINVAL) + IpcErrorPortInUse = -int64(unix.EADDRINUSE) + socketName = "%s.sock" +) + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error + kqueueFd int + keventFd int +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + err1 := unix.Close(l.kqueueFd) + err2 := unix.Close(l.keventFd) + err3 := l.listener.Close() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return err3 +} + +func (l *UAPIListener) Addr() net.Addr { + return l.listener.Addr() +} + +func UAPIListen(name string, file *os.File) (net.Listener, error) { + + // wrap file in listener + + listener, err := net.FileListener(file) + if err != nil { + return nil, err + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + if unixListener, ok := listener.(*net.UnixListener); ok { + unixListener.SetUnlinkOnClose(true) + } + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + // watch for deletion of socket + + uapi.kqueueFd, err = unix.Kqueue() + if err != nil { + return nil, err + } + uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) + if err != nil { + unix.Close(uapi.kqueueFd) + return nil, err + } + + go func(l *UAPIListener) { + event := unix.Kevent_t{ + Filter: unix.EVFILT_VNODE, + Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, + Fflags: unix.NOTE_WRITE, + } + // Allow this assignment to work with both the 32-bit and 64-bit version + // of the above struct. If you know another way, please submit a patch. + *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) + events := make([]unix.Kevent_t, 1) + n := 1 + var kerr error + for { + // start with lstat to avoid race condition + if _, err := os.Lstat(socketPath); os.IsNotExist(err) { + l.connErr <- err + return + } + if kerr != nil || n != 1 { + if kerr != nil { + l.connErr <- kerr + } else { + l.connErr <- errors.New("kqueue returned empty") + } + return + } + n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) + } + }(uapi) + + // watch for new connections + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0755) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + oldUmask := unix.Umask(0077) + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + unix.Umask(oldUmask) + + if err != nil { + return nil, err + } + + return listener.File() +} diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go new file mode 100644 index 0000000..8af3d8c --- /dev/null +++ b/ipc/uapi_linux.go @@ -0,0 +1,199 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +import ( + "errors" + "fmt" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/rwcancel" + "net" + "os" + "path" +) + +var socketDirectory = "/var/run/wireguard" + +const ( + IpcErrorIO = -int64(unix.EIO) + IpcErrorProtocol = -int64(unix.EPROTO) + IpcErrorInvalid = -int64(unix.EINVAL) + IpcErrorPortInUse = -int64(unix.EADDRINUSE) + socketName = "%s.sock" +) + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error + inotifyFd int + inotifyRWCancel *rwcancel.RWCancel +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + err1 := unix.Close(l.inotifyFd) + err2 := l.inotifyRWCancel.Cancel() + err3 := l.listener.Close() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return err3 +} + +func (l *UAPIListener) Addr() net.Addr { + return l.listener.Addr() +} + +func UAPIListen(name string, file *os.File) (net.Listener, error) { + + // wrap file in listener + + listener, err := net.FileListener(file) + if err != nil { + return nil, err + } + + if unixListener, ok := listener.(*net.UnixListener); ok { + unixListener.SetUnlinkOnClose(true) + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + // watch for deletion of socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + uapi.inotifyFd, err = unix.InotifyInit() + if err != nil { + return nil, err + } + + _, err = unix.InotifyAddWatch( + uapi.inotifyFd, + socketPath, + unix.IN_ATTRIB| + unix.IN_DELETE| + unix.IN_DELETE_SELF, + ) + + if err != nil { + return nil, err + } + + uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) + if err != nil { + unix.Close(uapi.inotifyFd) + return nil, err + } + + go func(l *UAPIListener) { + var buff [0]byte + for { + // start with lstat to avoid race condition + if _, err := os.Lstat(socketPath); os.IsNotExist(err) { + l.connErr <- err + return + } + _, err := uapi.inotifyRWCancel.Read(buff[:]) + if err != nil { + l.connErr <- err + return + } + } + }(uapi) + + // watch for new connections + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0755) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + oldUmask := unix.Umask(0077) + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + unix.Umask(oldUmask) + + if err != nil { + return nil, err + } + + return listener.File() +} diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go new file mode 100644 index 0000000..209d0d2 --- /dev/null +++ b/ipc/uapi_windows.go @@ -0,0 +1,76 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +import ( + "github.com/Microsoft/go-winio" + "net" +) + +//TODO: replace these with actual standard windows error numbers from the win package +const ( + IpcErrorIO = -int64(5) + IpcErrorProtocol = -int64(71) + IpcErrorInvalid = -int64(22) + IpcErrorPortInUse = -int64(98) +) + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error + kqueueFd int + keventFd int +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + return l.listener.Close() +} + +func (l *UAPIListener) Addr() net.Addr { + return l.listener.Addr() +} + +func UAPIListen(name string) (net.Listener, error) { + config := winio.PipeConfig{ + SecurityDescriptor: "", //TODO: we want this to be a very locked down pipe. + } + listener, err := winio.ListenPipe("\\\\.\\pipe\\wireguard\\"+name, &config) //TODO: choose sane name. + if err != nil { + return nil, err + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} diff --git a/kdf_test.go b/kdf_test.go deleted file mode 100644 index 3b9a8be..0000000 --- a/kdf_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "encoding/hex" - "golang.org/x/crypto/blake2s" - "testing" -) - -type KDFTest struct { - key string - input string - t0 string - t1 string - t2 string -} - -func assertEquals(t *testing.T, a string, b string) { - if a != b { - t.Fatal("expected", a, "=", b) - } -} - -func TestKDF(t *testing.T) { - tests := []KDFTest{ - { - key: "746573742d6b6579", - input: "746573742d696e707574", - t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", - t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", - t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", - }, - { - key: "776972656775617264", - input: "776972656775617264", - t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", - t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", - t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", - }, - { - key: "", - input: "", - t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", - t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", - t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", - }, - } - - var t0, t1, t2 [blake2s.Size]byte - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF3(&t0, &t1, &t2, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - t2s := hex.EncodeToString(t2[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - assertEquals(t, t2s, test.t2) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF2(&t0, &t1, key, input) - t0s := hex.EncodeToString(t0[:]) - t1s := hex.EncodeToString(t1[:]) - assertEquals(t, t0s, test.t0) - assertEquals(t, t1s, test.t1) - } - - for _, test := range tests { - key, _ := hex.DecodeString(test.key) - input, _ := hex.DecodeString(test.input) - KDF1(&t0, key, input) - t0s := hex.EncodeToString(t0[:]) - assertEquals(t, t0s, test.t0) - } -} diff --git a/keypair.go b/keypair.go deleted file mode 100644 index af10a58..0000000 --- a/keypair.go +++ /dev/null @@ -1,50 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "crypto/cipher" - "golang.zx2c4.com/wireguard/replay" - "sync" - "time" -) - -/* Due to limitations in Go and /x/crypto there is currently - * no way to ensure that key material is securely ereased in memory. - * - * Since this may harm the forward secrecy property, - * we plan to resolve this issue; whenever Go allows us to do so. - */ - -type Keypair struct { - sendNonce uint64 - send cipher.AEAD - receive cipher.AEAD - replayFilter replay.ReplayFilter - isInitiator bool - created time.Time - localIndex uint32 - remoteIndex uint32 -} - -type Keypairs struct { - sync.RWMutex - current *Keypair - previous *Keypair - next *Keypair -} - -func (kp *Keypairs) Current() *Keypair { - kp.RLock() - defer kp.RUnlock() - return kp.current -} - -func (device *Device) DeleteKeypair(key *Keypair) { - if key != nil { - device.indexTable.Delete(key.localIndex) - } -} diff --git a/logger.go b/logger.go deleted file mode 100644 index 00b1c7d..0000000 --- a/logger.go +++ /dev/null @@ -1,59 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "io" - "io/ioutil" - "log" - "os" -) - -const ( - LogLevelSilent = iota - LogLevelError - LogLevelInfo - LogLevelDebug -) - -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} - -func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - if level >= LogLevelError { - return output, ioutil.Discard, ioutil.Discard - } - return ioutil.Discard, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) - return logger -} diff --git a/main.go b/main.go index 08f8cc6..a3a04b8 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,8 @@ package main import ( "fmt" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" "os" "os/signal" @@ -76,7 +78,7 @@ func warning() { func main() { if len(os.Args) == 2 && os.Args[1] == "--version" { - fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", WireGuardGoVersion, runtime.GOOS, runtime.GOARCH) + fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld .\n", device.WireGuardGoVersion, runtime.GOOS, runtime.GOARCH) return } @@ -119,15 +121,15 @@ func main() { logLevel := func() int { switch os.Getenv("LOG_LEVEL") { case "debug": - return LogLevelDebug + return device.LogLevelDebug case "info": - return LogLevelInfo + return device.LogLevelInfo case "error": - return LogLevelError + return device.LogLevelError case "silent": - return LogLevelSilent + return device.LogLevelSilent } - return LogLevelInfo + return device.LogLevelInfo }() // open TUN device (or use supplied fd) @@ -135,7 +137,7 @@ func main() { tun, err := func() (tun.TUNDevice, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { - return tun.CreateTUN(interfaceName, DefaultMTU) + return tun.CreateTUN(interfaceName, device.DefaultMTU) } // construct tun device from supplied fd @@ -151,7 +153,7 @@ func main() { } file := os.NewFile(uintptr(fd), "") - return tun.CreateTUNFromFile(file, DefaultMTU) + return tun.CreateTUNFromFile(file, device.DefaultMTU) }() if err == nil { @@ -161,12 +163,12 @@ func main() { } } - logger := NewLogger( + logger := device.NewLogger( logLevel, fmt.Sprintf("(%s) ", interfaceName), ) - logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion) + logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) logger.Debug.Println("Debug log enabled") @@ -180,7 +182,7 @@ func main() { fileUAPI, err := func() (*os.File, error) { uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) if uapiFdStr == "" { - return UAPIOpen(interfaceName) + return ipc.UAPIOpen(interfaceName) } // use supplied fd @@ -206,7 +208,7 @@ func main() { env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND)) files := [3]*os.File{} - if os.Getenv("LOG_LEVEL") != "" && logLevel != LogLevelSilent { + if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent { files[0], _ = os.Open(os.DevNull) files[1] = os.Stdout files[2] = os.Stderr @@ -246,14 +248,14 @@ func main() { return } - device := NewDevice(tun, logger) + device := device.NewDevice(tun, logger) logger.Info.Println("Device started") errs := make(chan error) term := make(chan os.Signal, 1) - uapi, err := UAPIListen(interfaceName, fileUAPI) + uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) if err != nil { logger.Error.Println("Failed to listen on uapi socket:", err) os.Exit(ExitSetupFailed) @@ -266,7 +268,7 @@ func main() { errs <- err return } - go ipcHandle(device, conn) + go device.IpcHandle(conn) } }() diff --git a/main_windows.go b/main_windows.go index 7104a20..39cdead 100644 --- a/main_windows.go +++ b/main_windows.go @@ -7,6 +7,8 @@ package main import ( "fmt" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" "os" "os/signal" "syscall" @@ -25,8 +27,8 @@ func main() { } interfaceName := os.Args[1] - logger := NewLogger( - LogLevelDebug, + logger := device.NewLogger( + device.LogLevelDebug, fmt.Sprintf("(%s) ", interfaceName), ) logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion) @@ -43,11 +45,11 @@ func main() { os.Exit(ExitSetupFailed) } - device := NewDevice(tun, logger) + device := device.NewDevice(tun, logger) device.Up() logger.Info.Println("Device started") - uapi, err := UAPIListen(interfaceName) + uapi, err := ipc.UAPIListen(interfaceName) if err != nil { logger.Error.Println("Failed to listen on uapi socket:", err) os.Exit(ExitSetupFailed) @@ -63,7 +65,7 @@ func main() { errs <- err return } - go ipcHandle(device, conn) + go device.IpcHandle(conn) } }() logger.Info.Println("UAPI listener started") diff --git a/mark_default.go b/mark_default.go deleted file mode 100644 index 7149d69..0000000 --- a/mark_default.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build !linux,!openbsd,!freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -func (bind *NativeBind) SetMark(mark uint32) error { - return nil -} diff --git a/mark_unix.go b/mark_unix.go deleted file mode 100644 index 0ae62b7..0000000 --- a/mark_unix.go +++ /dev/null @@ -1,64 +0,0 @@ -// +build android openbsd freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "golang.org/x/sys/unix" - "runtime" -) - -var fwmarkIoctl int - -func init() { - switch runtime.GOOS { - case "linux", "android": - fwmarkIoctl = 36 /* unix.SO_MARK */ - case "freebsd": - fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ - case "openbsd": - fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ - } -} - -func (bind *NativeBind) SetMark(mark uint32) error { - var operr error - if fwmarkIoctl == 0 { - return nil - } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - return nil -} diff --git a/misc.go b/misc.go deleted file mode 100644 index 6786cb5..0000000 --- a/misc.go +++ /dev/null @@ -1,48 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.int32) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.int32, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.int32, flag) -} - -func min(a, b uint) uint { - if a > b { - return b - } - return a -} diff --git a/noise-helpers.go b/noise-helpers.go deleted file mode 100644 index af11f09..0000000 --- a/noise-helpers.go +++ /dev/null @@ -1,104 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/subtle" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/curve25519" - "hash" -) - -/* KDF related functions. - * HMAC-based Key Derivation Function (HKDF) - * https://tools.ietf.org/html/rfc5869 - */ - -func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Sum(sum[:0]) -} - -func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { - mac := hmac.New(func() hash.Hash { - h, _ := blake2s.New256(nil) - return h - }, key) - mac.Write(in0) - mac.Write(in1) - mac.Sum(sum[:0]) -} - -func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { - HMAC1(t0, key, input) - HMAC1(t0, t0[:], []byte{0x1}) - return -} - -func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - setZero(prk[:]) - return -} - -func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { - var prk [blake2s.Size]byte - HMAC1(&prk, key, input) - HMAC1(t0, prk[:], []byte{0x1}) - HMAC2(t1, prk[:], t0[:], []byte{0x2}) - HMAC2(t2, prk[:], t1[:], []byte{0x3}) - setZero(prk[:]) - return -} - -func isZero(val []byte) bool { - acc := 1 - for _, b := range val { - acc &= subtle.ConstantTimeByteEq(b, 0) - } - return acc == 1 -} - -/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ -func setZero(arr []byte) { - for i := range arr { - arr[i] = 0 - } -} - -func (sk *NoisePrivateKey) clamp() { - sk[0] &= 248 - sk[31] = (sk[31] & 127) | 64 -} - -func newPrivateKey() (sk NoisePrivateKey, err error) { - _, err = rand.Read(sk[:]) - sk.clamp() - return -} - -func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarBaseMult(apk, ask) - return -} - -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { - apk := (*[NoisePublicKeySize]byte)(&pk) - ask := (*[NoisePrivateKeySize]byte)(sk) - curve25519.ScalarMult(&ss, ask, apk) - return ss -} diff --git a/noise-protocol.go b/noise-protocol.go deleted file mode 100644 index fb43413..0000000 --- a/noise-protocol.go +++ /dev/null @@ -1,600 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "errors" - "golang.org/x/crypto/blake2s" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/poly1305" - "golang.zx2c4.com/wireguard/tai64n" - "sync" - "time" -) - -const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed -) - -const ( - NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" - WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" - WGLabelMAC1 = "mac1----" - WGLabelCookie = "cookie--" -) - -const ( - MessageInitiationType = 1 - MessageResponseType = 2 - MessageCookieReplyType = 3 - MessageTransportType = 4 -) - -const ( - MessageInitiationSize = 148 // size of handshake initation message - MessageResponseSize = 92 // size of response message - MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceeding content in transport message - MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport - MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message -) - -const ( - MessageTransportOffsetReceiver = 4 - MessageTransportOffsetCounter = 8 - MessageTransportOffsetContent = 16 -) - -/* Type is an 8-bit field, followed by 3 nul bytes, - * by marshalling the messages in little-endian byteorder - * we can treat these as a 32-bit unsigned int (for now) - * - */ - -type MessageInitiation struct { - Type uint32 - Sender uint32 - Ephemeral NoisePublicKey - Static [NoisePublicKeySize + poly1305.TagSize]byte - Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageResponse struct { - Type uint32 - Sender uint32 - Receiver uint32 - Ephemeral NoisePublicKey - Empty [poly1305.TagSize]byte - MAC1 [blake2s.Size128]byte - MAC2 [blake2s.Size128]byte -} - -type MessageTransport struct { - Type uint32 - Receiver uint32 - Counter uint64 - Content []byte -} - -type MessageCookieReply struct { - Type uint32 - Receiver uint32 - Nonce [chacha20poly1305.NonceSizeX]byte - Cookie [blake2s.Size128 + poly1305.TagSize]byte -} - -type Handshake struct { - state int - mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk - localEphemeral NoisePrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic NoisePublicKey // long term key - remoteEphemeral NoisePublicKey // ephemeral public key - precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret - lastTimestamp tai64n.Timestamp - lastInitiationConsumption time.Time - lastSentHandshake time.Time -} - -var ( - InitialChainKey [blake2s.Size]byte - InitialHash [blake2s.Size]byte - ZeroNonce [chacha20poly1305.NonceSize]byte -) - -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { - KDF1(dst, c[:], data) -} - -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { - hash, _ := blake2s.New256(nil) - hash.Write(h[:]) - hash.Write(data) - hash.Sum(dst[:0]) - hash.Reset() -} - -func (h *Handshake) Clear() { - setZero(h.localEphemeral[:]) - setZero(h.remoteEphemeral[:]) - setZero(h.chainKey[:]) - setZero(h.hash[:]) - h.localIndex = 0 - h.state = HandshakeZeroed -} - -func (h *Handshake) mixHash(data []byte) { - mixHash(&h.hash, &h.hash, data) -} - -func (h *Handshake) mixKey(data []byte) { - mixKey(&h.chainKey, &h.chainKey, data) -} - -/* Do basic precomputations - */ -func init() { - InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) -} - -func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("static shared secret is zero") - } - - // create ephemeral key - - var err error - handshake.hash = InitialHash - handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - - // assign index - - device.indexTable.Delete(handshake.localIndex) - handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) - - if err != nil { - return nil, err - } - - handshake.mixHash(handshake.remoteStatic[:]) - - msg := MessageInitiation{ - Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, - } - - handshake.mixKey(msg.Ephemeral[:]) - handshake.mixHash(msg.Ephemeral[:]) - - // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) - }() - handshake.mixHash(msg.Static[:]) - - // encrypt timestamp - - timestamp := tai64n.Now() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() - - handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated - return &msg, nil -} - -func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - if msg.Type != MessageInitiationType { - return nil - } - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - - mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) - mixHash(&hash, &hash, msg.Ephemeral[:]) - mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) - - // decrypt static key - - var err error - var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() - if err != nil { - return nil - } - mixHash(&hash, &hash, msg.Static[:]) - - // lookup peer - - peer := device.LookupPeer(peerPK) - if peer == nil { - return nil - } - - handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } - - // verify identity - - var timestamp tai64n.Timestamp - var key [chacha20poly1305.KeySize]byte - - handshake.mutex.RLock() - KDF2( - &chainKey, - &key, - chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) - if err != nil { - handshake.mutex.RUnlock() - return nil - } - mixHash(&hash, &hash, msg.Timestamp[:]) - - // protect against replay & flood - - var ok bool - ok = timestamp.After(handshake.lastTimestamp) - ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate - handshake.mutex.RUnlock() - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.remoteEphemeral = msg.Ephemeral - handshake.lastTimestamp = timestamp - handshake.lastInitiationConsumption = time.Now() - handshake.state = HandshakeInitiationConsumed - - handshake.mutex.Unlock() - - setZero(hash[:]) - setZero(chainKey[:]) - - return peer -} - -func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - if handshake.state != HandshakeInitiationConsumed { - return nil, errors.New("handshake initiation must be consumed first") - } - - // assign index - - var err error - device.indexTable.Delete(handshake.localIndex) - handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) - if err != nil { - return nil, err - } - - var msg MessageResponse - msg.Type = MessageResponseType - msg.Sender = handshake.localIndex - msg.Receiver = handshake.remoteIndex - - // create ephemeral key - - handshake.localEphemeral, err = newPrivateKey() - if err != nil { - return nil, err - } - msg.Ephemeral = handshake.localEphemeral.publicKey() - handshake.mixHash(msg.Ephemeral[:]) - handshake.mixKey(msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() - - // add preshared key - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - - KDF3( - &handshake.chainKey, - &tau, - &key, - handshake.chainKey[:], - handshake.presharedKey[:], - ) - - handshake.mixHash(tau[:]) - - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() - - handshake.state = HandshakeResponseCreated - - return &msg, nil -} - -func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { - if msg.Type != MessageResponseType { - return nil - } - - // lookup handshake by receiver - - lookup := device.indexTable.Lookup(msg.Receiver) - handshake := lookup.handshake - if handshake == nil { - return nil - } - - var ( - hash [blake2s.Size]byte - chainKey [blake2s.Size]byte - ) - - ok := func() bool { - - // lock handshake state - - handshake.mutex.RLock() - defer handshake.mutex.RUnlock() - - if handshake.state != HandshakeInitiationCreated { - return false - } - - // lock private key for reading - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - - // finish 3-way DH - - mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) - mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() - - // add preshared key (psk) - - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - KDF3( - &chainKey, - &tau, - &key, - chainKey[:], - handshake.presharedKey[:], - ) - mixHash(&hash, &hash, tau[:]) - - // authenticate transcript - - aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) - if err != nil { - return false - } - mixHash(&hash, &hash, msg.Empty[:]) - return true - }() - - if !ok { - return nil - } - - // update handshake state - - handshake.mutex.Lock() - - handshake.hash = hash - handshake.chainKey = chainKey - handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed - - handshake.mutex.Unlock() - - setZero(hash[:]) - setZero(chainKey[:]) - - return lookup.peer -} - -/* Derives a new keypair from the current handshake state - * - */ -func (peer *Peer) BeginSymmetricSession() error { - device := peer.device - handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - - // derive keys - - var isInitiator bool - var sendKey [chacha20poly1305.KeySize]byte - var recvKey [chacha20poly1305.KeySize]byte - - if handshake.state == HandshakeResponseConsumed { - KDF2( - &sendKey, - &recvKey, - handshake.chainKey[:], - nil, - ) - isInitiator = true - } else if handshake.state == HandshakeResponseCreated { - KDF2( - &recvKey, - &sendKey, - handshake.chainKey[:], - nil, - ) - isInitiator = false - } else { - return errors.New("invalid state for keypair derivation") - } - - // zero handshake - - setZero(handshake.chainKey[:]) - setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. - setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed - - // create AEAD instances - - keypair := new(Keypair) - keypair.send, _ = chacha20poly1305.New(sendKey[:]) - keypair.receive, _ = chacha20poly1305.New(recvKey[:]) - - setZero(sendKey[:]) - setZero(recvKey[:]) - - keypair.created = time.Now() - keypair.sendNonce = 0 - keypair.replayFilter.Init() - keypair.isInitiator = isInitiator - keypair.localIndex = peer.handshake.localIndex - keypair.remoteIndex = peer.handshake.remoteIndex - - // remap index - - device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) - handshake.localIndex = 0 - - // rotate key pairs - - keypairs := &peer.keypairs - keypairs.Lock() - defer keypairs.Unlock() - - previous := keypairs.previous - next := keypairs.next - current := keypairs.current - - if isInitiator { - if next != nil { - keypairs.next = nil - keypairs.previous = next - device.DeleteKeypair(current) - } else { - keypairs.previous = current - } - device.DeleteKeypair(previous) - keypairs.current = keypair - } else { - keypairs.next = keypair - device.DeleteKeypair(next) - keypairs.previous = nil - device.DeleteKeypair(previous) - } - - return nil -} - -func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { - keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { - return false - } - keypairs.Lock() - defer keypairs.Unlock() - if keypairs.next != receivedKeypair { - return false - } - old := keypairs.previous - keypairs.previous = keypairs.current - peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil - return true -} diff --git a/noise-types.go b/noise-types.go deleted file mode 100644 index 902905e..0000000 --- a/noise-types.go +++ /dev/null @@ -1,81 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "crypto/subtle" - "encoding/hex" - "errors" - "golang.org/x/crypto/chacha20poly1305" -) - -const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 -) - -type ( - NoisePublicKey [NoisePublicKeySize]byte - NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte - NoiseNonce uint64 // padded to 12-bytes -) - -func loadExactHex(dst []byte, src string) error { - slice, err := hex.DecodeString(src) - if err != nil { - return err - } - if len(slice) != len(dst) { - return errors.New("hex string does not fit the slice") - } - copy(dst, slice) - return nil -} - -func (key NoisePrivateKey) IsZero() bool { - var zero NoisePrivateKey - return key.Equals(zero) -} - -func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoisePrivateKey) FromHex(src string) (err error) { - err = loadExactHex(key[:], src) - key.clamp() - return -} - -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key *NoisePublicKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - -func (key NoisePublicKey) IsZero() bool { - var zero NoisePublicKey - return key.Equals(zero) -} - -func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { - return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 -} - -func (key *NoiseSymmetricKey) FromHex(src string) error { - return loadExactHex(key[:], src) -} - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/noise_test.go b/noise_test.go deleted file mode 100644 index 116057a..0000000 --- a/noise_test.go +++ /dev/null @@ -1,144 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "bytes" - "encoding/binary" - "testing" -) - -func TestCurveWrappers(t *testing.T) { - sk1, err := newPrivateKey() - assertNil(t, err) - - sk2, err := newPrivateKey() - assertNil(t, err) - - pk1 := sk1.publicKey() - pk2 := sk2.publicKey() - - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) - - if ss1 != ss2 { - t.Fatal("Failed to compute shared secet") - } -} - -func TestNoiseHandshake(t *testing.T) { - dev1 := randDevice(t) - dev2 := randDevice(t) - - defer dev1.Close() - defer dev2.Close() - - peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) - - assertEqual( - t, - peer1.handshake.precomputedStaticStatic[:], - peer2.handshake.precomputedStaticStatic[:], - ) - - /* simulate handshake */ - - // initiation message - - t.Log("exchange initiation message") - - msg1, err := dev1.CreateMessageInitiation(peer2) - assertNil(t, err) - - packet := make([]byte, 0, 256) - writer := bytes.NewBuffer(packet) - err = binary.Write(writer, binary.LittleEndian, msg1) - assertNil(t, err) - peer := dev2.ConsumeMessageInitiation(msg1) - if peer == nil { - t.Fatal("handshake failed at initiation message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // response message - - t.Log("exchange response message") - - msg2, err := dev2.CreateMessageResponse(peer1) - assertNil(t, err) - - peer = dev1.ConsumeMessageResponse(msg2) - if peer == nil { - t.Fatal("handshake failed at response message") - } - - assertEqual( - t, - peer1.handshake.chainKey[:], - peer2.handshake.chainKey[:], - ) - - assertEqual( - t, - peer1.handshake.hash[:], - peer2.handshake.hash[:], - ) - - // key pairs - - t.Log("deriving keys") - - err = peer1.BeginSymmetricSession() - if err != nil { - t.Fatal("failed to derive keypair for peer 1", err) - } - - err = peer2.BeginSymmetricSession() - if err != nil { - t.Fatal("failed to derive keypair for peer 2", err) - } - - key1 := peer1.keypairs.next - key2 := peer2.keypairs.current - - // encrypting / decryption test - - t.Log("test key pairs") - - func() { - testMsg := []byte("wireguard test message 1") - var err error - var out []byte - var nonce [12]byte - out = key1.send.Seal(out, nonce[:], testMsg, nil) - out, err = key2.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() - - func() { - testMsg := []byte("wireguard test message 2") - var err error - var out []byte - var nonce [12]byte - out = key2.send.Seal(out, nonce[:], testMsg, nil) - out, err = key1.receive.Open(out[:0], nonce[:], out, nil) - assertNil(t, err) - assertEqual(t, out, testMsg) - }() -} diff --git a/peer.go b/peer.go deleted file mode 100644 index f021565..0000000 --- a/peer.go +++ /dev/null @@ -1,270 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "encoding/base64" - "errors" - "fmt" - "sync" - "time" -) - -const ( - PeerRoutineNumber = 3 -) - -type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint Endpoint - persistentKeepaliveInterval uint16 - - // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch - } - - timers struct { - retransmitHandshake *Timer - sendKeepalive *Timer - newHandshake *Timer - zeroKeyMaterial *Timer - persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool - } - - signals struct { - newKeypairArrived chan struct{} - flushNonceQueue chan struct{} - } - - queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - packetInNonceQueueIsAwaitingKey AtomicBool - } - - routines struct { - sync.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer - } - - cookieGenerator CookieGenerator -} - -func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - - if device.isClosed.Get() { - return nil, errors.New("device closed") - } - - // lock resources - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - - device.peers.Lock() - defer device.peers.Unlock() - - // check if over limit - - if len(device.peers.keyMap) >= MaxPeers { - return nil, errors.New("too many peers") - } - - // create peer - - peer := new(Peer) - peer.Lock() - defer peer.Unlock() - - peer.cookieGenerator.Init(pk) - peer.device = device - peer.isRunning.Set(false) - - // map public key - - _, ok := device.peers.keyMap[pk] - if ok { - return nil, errors.New("adding existing peer") - } - device.peers.keyMap[pk] = peer - - // pre-compute DH - - handshake := &peer.handshake - handshake.mutex.Lock() - handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) - handshake.mutex.Unlock() - - // reset endpoint - - peer.endpoint = nil - - // start peer - - if peer.device.isUp.Get() { - peer.Start() - } - - return peer, nil -} - -func (peer *Peer) SendBuffer(buffer []byte) error { - peer.device.net.RLock() - defer peer.device.net.RUnlock() - - if peer.device.net.bind == nil { - return errors.New("no bind") - } - - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { - return errors.New("no known endpoint for peer") - } - - return peer.device.net.bind.Send(buffer, peer.endpoint) -} - -func (peer *Peer) String() string { - base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] - } - return fmt.Sprintf("peer(%s)", abbreviatedKey) -} - -func (peer *Peer) Start() { - - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { - return - } - - // prevent simultaneous start/stop operations - - peer.routines.Lock() - defer peer.routines.Unlock() - - if peer.isRunning.Get() { - return - } - - device := peer.device - device.log.Debug.Println(peer, "- Starting...") - - // reset routine state - - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - - // prepare queues - - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - - peer.timersInit() - peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - peer.signals.newKeypairArrived = make(chan struct{}, 1) - peer.signals.flushNonceQueue = make(chan struct{}, 1) - - // wait for routines to start - - go peer.RoutineNonce() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - peer.routines.starting.Wait() - peer.isRunning.Set(true) -} - -func (peer *Peer) ZeroAndFlushAll() { - device := peer.device - - // clear key pairs - - keypairs := &peer.keypairs - keypairs.Lock() - device.DeleteKeypair(keypairs.previous) - device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) - keypairs.previous = nil - keypairs.current = nil - keypairs.next = nil - keypairs.Unlock() - - // clear handshake state - - handshake := &peer.handshake - handshake.mutex.Lock() - device.indexTable.Delete(handshake.localIndex) - handshake.Clear() - handshake.mutex.Unlock() - - peer.FlushNonceQueue() -} - -func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations - - if !peer.isRunning.Swap(false) { - return - } - - peer.routines.starting.Wait() - - peer.routines.Lock() - defer peer.routines.Unlock() - - peer.device.log.Debug.Println(peer, "- Stopping...") - - peer.timersStop() - - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) - - peer.ZeroAndFlushAll() -} - -var roamingDisabled bool - -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { - if roamingDisabled { - return - } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() -} diff --git a/pools.go b/pools.go deleted file mode 100644 index 8a9af0b..0000000 --- a/pools.go +++ /dev/null @@ -1,89 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import "sync" - -func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) - } - } -} - -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan - } -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } -} - -func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan - } -} - -func (device *Device) PutInboundElement(msg *QueueInboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(msg) - } else { - device.pool.inboundElementReuseChan <- msg - } -} - -func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } -} - -func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(msg) - } else { - device.pool.outboundElementReuseChan <- msg - } -} diff --git a/queueconstants.go b/queueconstants.go deleted file mode 100644 index 0dcdd33..0000000 --- a/queueconstants.go +++ /dev/null @@ -1,16 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -/* Implementation specific constants */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth -) diff --git a/receive.go b/receive.go deleted file mode 100644 index fb848eb..0000000 --- a/receive.go +++ /dev/null @@ -1,641 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "bytes" - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "strconv" - "sync" - "sync/atomic" - "time" -) - -type QueueHandshakeElement struct { - msgType uint32 - packet []byte - endpoint Endpoint - buffer *[MaxMessageSize]byte -} - -type QueueInboundElement struct { - dropped int32 - sync.Mutex - buffer *[MaxMessageSize]byte - packet []byte - counter uint64 - keypair *Keypair - endpoint Endpoint -} - -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { - select { - case inboundQueue <- element: - select { - case decryptionQueue <- element: - return true - default: - element.Drop() - element.Unlock() - return false - } - default: - device.PutInboundElement(element) - return false - } -} - -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { - select { - case queue <- element: - return true - default: - return false - } -} - -/* Called when a new authenticated message has been received - * - * NOTE: Not thread safe, but called by sequential receiver! - */ -func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { - return - } - keypair := peer.keypairs.Current() - if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) - peer.SendHandshakeInitiation(false) - } -} - -/* Receives incoming datagrams for the device - * - * Every time the bind is updated a new routine is started for - * IPv4 and IPv6 (separately) - */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { - - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") - device.net.stopping.Done() - }() - - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") - device.net.starting.Done() - - // receive datagrams until conn is closed - - buffer := device.GetMessageBuffer() - - var ( - err error - size int - endpoint Endpoint - ) - - for { - - // read next datagram - - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") - } - - if err != nil { - device.PutMessageBuffer(buffer) - return - } - - if size < MinMessageSize { - continue - } - - // check size of packet - - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool - - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportSize { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } - - // check keypair expiry - - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } - - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { - buffer = device.GetMessageBuffer() - } - } - - continue - - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize - - case MessageResponseType: - okay = len(packet) == MessageResponseSize - - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize - - default: - logDebug.Println("Received message with unknown type") - } - - if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { - buffer = device.GetMessageBuffer() - } - } - } -} - -func (device *Device) RoutineDecryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: decryption worker - stopped") - device.state.stopping.Done() - }() - logDebug.Println("Routine: decryption worker - started") - device.state.starting.Done() - - for { - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } - - // split message into fields - - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] - content := elem.packet[MessageTransportOffsetContent:] - - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - - // decrypt and release to consumer - - var err error - elem.counter = binary.LittleEndian.Uint64(counter) - elem.packet, err = elem.keypair.receive.Open( - content[:0], - nonce[:], - content, - nil, - ) - if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - } - elem.Unlock() - } - } -} - -/* Handles incoming packets related to handshake - */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem QueueHandshakeElement - var ok bool - - defer func() { - logDebug.Println("Routine: handshake worker - stopped") - device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } - }() - - logDebug.Println("Routine: handshake worker - started") - device.state.starting.Done() - - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } - - // handle cookie fields and ratelimiting - - switch elem.msgType { - - case MessageCookieReplyType: - - // unmarshal packet - - var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) - if err != nil { - logDebug.Println("Failed to decode cookie reply") - return - } - - // lookup peer from index - - entry := device.indexTable.Lookup(reply.Receiver) - - if entry.peer == nil { - continue - } - - // consume reply - - if peer := entry.peer; peer.isRunning.Get() { - logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) - if !peer.cookieGenerator.ConsumeReply(&reply) { - logDebug.Println("Could not decrypt invalid cookie response") - } - } - - continue - - case MessageInitiationType, MessageResponseType: - - // check mac fields and maybe ratelimit - - if !device.cookieChecker.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue - } - - // endpoints destination address is the source of the datagram - - if device.IsUnderLoad() { - - // verify MAC2 field - - if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { - device.SendHandshakeCookie(&elem) - continue - } - - // check ratelimiter - - if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue - } - } - - default: - logError.Println("Invalid packet ended up in the handshake queue") - continue - } - - // handle handshake initiation/response content - - switch elem.msgType { - case MessageInitiationType: - - // unmarshal - - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode initiation message") - continue - } - - // consume initiation - - peer := device.ConsumeMessageInitiation(&msg) - if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update timers - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - logDebug.Println(peer, "- Received handshake initiation") - - peer.SendHandshakeResponse() - - case MessageResponseType: - - // unmarshal - - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode response message") - continue - } - - // consume response - - peer := device.ConsumeMessageResponse(&msg) - if peer == nil { - logInfo.Println( - "Received invalid response message from", - elem.endpoint.DstToString(), - ) - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - logDebug.Println(peer, "- Received handshake response") - - // update timers - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - - // derive keypair - - err = peer.BeginSymmetricSession() - - if err != nil { - logError.Println(peer, "- Failed to derive keypair:", err) - continue - } - - peer.timersSessionDerived() - peer.timersHandshakeComplete() - peer.SendKeepalive() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - } -} - -func (peer *Peer) RoutineSequentialReceiver() { - - device := peer.device - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem *QueueInboundElement - var ok bool - - defer func() { - logDebug.Println(peer, "- Routine: sequential receiver - stopped") - peer.routines.stopping.Done() - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - } - }() - - logDebug.Println(peer, "- Routine: sequential receiver - started") - - peer.routines.starting.Done() - - for { - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - elem = nil - } - - select { - - case <-peer.routines.stop: - return - - case elem, ok = <-peer.queue.inbound: - - if !ok { - return - } - - // wait for decryption - - elem.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - - // check for keepalive - - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue - } - peer.timersDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { - continue - } - - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue - } - - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { - continue - } - - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] - - // verify IPv6 source - - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - peer, - "sent packet with disallowed IPv6 source", - ) - continue - } - - default: - logInfo.Println("Packet with invalid IP version from", peer) - continue - } - - // write to tun device - - offset := MessageTransportOffsetContent - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) - } - } - } -} diff --git a/send.go b/send.go deleted file mode 100644 index b7cac04..0000000 --- a/send.go +++ /dev/null @@ -1,618 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "bytes" - "encoding/binary" - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "net" - "sync" - "sync/atomic" - "time" -) - -/* Outbound flow - * - * 1. TUN queue - * 2. Routing (sequential) - * 3. Nonce assignment (sequential) - * 4. Encryption (parallel) - * 5. Transmission (sequential) - * - * The functions in this file occur (roughly) in the order in - * which the packets are processed. - * - * Locking, Producers and Consumers - * - * The order of packets (per peer) must be maintained, - * but encryption of packets happen out-of-order: - * - * The sequential consumers will attempt to take the lock, - * workers release lock when they have completed work (encryption) on the packet. - * - * If the element is inserted into the "encryption queue", - * the content is preceded by enough "junk" to contain the transport header - * (to allow the construction of transport messages in-place) - */ - -type QueueOutboundElement struct { - dropped int32 - sync.Mutex - buffer *[MaxMessageSize]byte // slice holding the packet data - packet []byte // slice of "buffer" (always!) - nonce uint64 // nonce for encryption - keypair *Keypair // keypair for encryption - peer *Peer // related peer -} - -func (device *Device) NewOutboundElement() *QueueOutboundElement { - elem := device.GetOutboundElement() - elem.dropped = AtomicFalse - elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} - elem.nonce = 0 - elem.keypair = nil - elem.peer = nil - return elem -} - -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - device.PutMessageBuffer(old.buffer) - device.PutOutboundElement(old) - default: - } - } - } -} - -func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { - select { - case outboundQueue <- element: - select { - case encryptionQueue <- element: - return - default: - element.Drop() - element.peer.device.PutMessageBuffer(element.buffer) - element.Unlock() - } - default: - element.peer.device.PutMessageBuffer(element.buffer) - element.peer.device.PutOutboundElement(element) - } -} - -/* Queues a keepalive if no packets are queued for peer - */ -func (peer *Peer) SendKeepalive() bool { - if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - peer.device.log.Debug.Println(peer, "- Sending keepalive packet") - return true - default: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) - return false - } -} - -func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { - if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - } - - peer.handshake.mutex.RLock() - if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { - peer.handshake.mutex.RUnlock() - return nil - } - peer.handshake.mutex.RUnlock() - - peer.handshake.mutex.Lock() - if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { - peer.handshake.mutex.Unlock() - return nil - } - peer.handshake.lastSentHandshake = time.Now() - peer.handshake.mutex.Unlock() - - peer.device.log.Debug.Println(peer, "- Sending handshake initiation") - - msg, err := peer.device.CreateMessageInitiation(peer) - if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) - return err - } - - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, msg) - packet := writer.Bytes() - peer.cookieGenerator.AddMacs(packet) - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - err = peer.SendBuffer(packet) - if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) - } - peer.timersHandshakeInitiated() - - return err -} - -func (peer *Peer) SendHandshakeResponse() error { - peer.handshake.mutex.Lock() - peer.handshake.lastSentHandshake = time.Now() - peer.handshake.mutex.Unlock() - - peer.device.log.Debug.Println(peer, "- Sending handshake response") - - response, err := peer.device.CreateMessageResponse(peer) - if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create response message:", err) - return err - } - - var buff [MessageResponseSize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.cookieGenerator.AddMacs(packet) - - err = peer.BeginSymmetricSession() - if err != nil { - peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) - return err - } - - peer.timersSessionDerived() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - err = peer.SendBuffer(packet) - if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) - } - return err -} - -func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { - - device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) - - sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) - reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) - if err != nil { - device.log.Error.Println("Failed to create cookie reply:", err) - return err - } - - var buff [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buff[:0]) - binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) - if err != nil { - device.log.Error.Println("Failed to send cookie reply:", err) - } - return err -} - -func (peer *Peer) keepKeyFreshSending() { - keypair := peer.keypairs.Current() - if keypair == nil { - return - } - nonce := atomic.LoadUint64(&keypair.sendNonce) - if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) { - peer.SendHandshakeInitiation(false) - } -} - -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ -func (device *Device) RoutineReadFromTUN() { - - logDebug := device.log.Debug - logError := device.log.Error - - defer func() { - logDebug.Println("Routine: TUN reader - stopped") - device.state.stopping.Done() - }() - - logDebug.Println("Routine: TUN reader - started") - device.state.starting.Done() - - var elem *QueueOutboundElement - - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - elem = device.NewOutboundElement() - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) - - if err != nil { - if !device.isClosed.Get() { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() - } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - - if size == 0 || size > MaxContentSize { - continue - } - - elem.packet = elem.buffer[offset : offset+size] - - // lookup peer - - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue - } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.LookupIPv4(dst) - - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - continue - } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") - } - - if peer == nil { - continue - } - - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { - peer.SendHandshakeInitiation(false) - } - addToNonceQueue(peer.queue.nonce, elem, device) - elem = nil - } - } -} - -func (peer *Peer) FlushNonceQueue() { - select { - case peer.signals.flushNonceQueue <- struct{}{}: - default: - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keypair *Keypair - - device := peer.device - logDebug := device.log.Debug - - flush := func() { - for { - select { - case elem := <-peer.queue.nonce: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - default: - return - } - } - } - - defer func() { - flush() - logDebug.Println(peer, "- Routine: nonce worker - stopped") - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - peer.routines.stopping.Done() - }() - - peer.routines.starting.Done() - logDebug.Println(peer, "- Routine: nonce worker - started") - - for { - NextPacket: - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - - select { - case <-peer.routines.stop: - return - - case <-peer.signals.flushNonceQueue: - flush() - goto NextPacket - - case elem, ok := <-peer.queue.nonce: - - if !ok { - return - } - - // make sure to always pick the newest key - - for { - - // check validity of newest key pair - - keypair = peer.keypairs.Current() - if keypair != nil && keypair.sendNonce < RejectAfterMessages { - if time.Now().Sub(keypair.created) < RejectAfterTime { - break - } - } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) - - // no suitable key pair, request for new handshake - - select { - case <-peer.signals.newKeypairArrived: - default: - } - - peer.SendHandshakeInitiation(false) - - // wait for key to be established - - logDebug.Println(peer, "- Awaiting keypair") - - select { - case <-peer.signals.newKeypairArrived: - logDebug.Println(peer, "- Obtained awaited keypair") - - case <-peer.signals.flushNonceQueue: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - flush() - goto NextPacket - - case <-peer.routines.stop: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - - // populate work element - - elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 - - // double check in case of race condition added by future code - - if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - goto NextPacket - } - - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.Lock() - - // add to parallel and sequential queue - addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) - } - } -} - -/* Encrypts the elements in the queue - * and marks them for sequential consumption (by releasing the mutex) - * - * Obs. One instance per core - */ -func (device *Device) RoutineEncryption() { - - var nonce [chacha20poly1305.NonceSize]byte - - logDebug := device.log.Debug - - defer func() { - for { - select { - case elem, ok := <-device.queue.encryption: - if ok && !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - elem.Unlock() - } - default: - goto out - } - } - out: - logDebug.Println("Routine: encryption worker - stopped") - device.state.stopping.Done() - }() - - logDebug.Println("Routine: encryption worker - started") - device.state.starting.Done() - - for { - - // fetch next element - - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.encryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } - - // populate header fields - - header := elem.buffer[:MessageTransportHeaderSize] - - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] - - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - - // pad content to multiple of 16 - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - lastUnit := len(elem.packet) % mtu - paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) - if paddedSize > mtu { - paddedSize = mtu - } - for i := len(elem.packet); i < paddedSize; i++ { - elem.packet = append(elem.packet, 0) - } - - // encrypt content and release to consumer - - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() - } - } -} - -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - - device := peer.device - - logDebug := device.log.Debug - logError := device.log.Error - - defer func() { - for { - select { - case elem, ok := <-peer.queue.outbound: - if ok { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - elem.Drop() - } - device.PutOutboundElement(elem) - } - default: - goto out - } - } - out: - logDebug.Println(peer, "- Routine: sequential sender - stopped") - peer.routines.stopping.Done() - }() - - logDebug.Println(peer, "- Routine: sequential sender - started") - - peer.routines.starting.Done() - - for { - select { - - case <-peer.routines.stop: - return - - case elem, ok := <-peer.queue.outbound: - - if !ok { - return - } - - elem.Lock() - if elem.IsDropped() { - device.PutOutboundElement(elem) - continue - } - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - // send message and return buffer to pool - - length := uint64(len(elem.packet)) - err := peer.SendBuffer(elem.packet) - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - if err != nil { - logError.Println(peer, "- Failed to send data packet", err) - continue - } - atomic.AddUint64(&peer.stats.txBytes, length) - - if len(elem.packet) != MessageKeepaliveSize { - peer.timersDataSent() - } - peer.keepKeyFreshSending() - } - } -} diff --git a/timers.go b/timers.go deleted file mode 100644 index 9c16f13..0000000 --- a/timers.go +++ /dev/null @@ -1,227 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This is based heavily on timers.c from the kernel implementation. - */ - -package main - -import ( - "math/rand" - "sync" - "sync/atomic" - "time" -) - -/* This Timer structure and related functions should roughly copy the interface of - * the Linux kernel's struct timer_list. - */ - -type Timer struct { - *time.Timer - modifyingLock sync.RWMutex - runningLock sync.Mutex - isPending bool -} - -func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { - timer := &Timer{} - timer.Timer = time.AfterFunc(time.Hour, func() { - timer.runningLock.Lock() - - timer.modifyingLock.Lock() - if !timer.isPending { - timer.modifyingLock.Unlock() - timer.runningLock.Unlock() - return - } - timer.isPending = false - timer.modifyingLock.Unlock() - - expirationFunction(peer) - timer.runningLock.Unlock() - }) - timer.Stop() - return timer -} - -func (timer *Timer) Mod(d time.Duration) { - timer.modifyingLock.Lock() - timer.isPending = true - timer.Reset(d) - timer.modifyingLock.Unlock() -} - -func (timer *Timer) Del() { - timer.modifyingLock.Lock() - timer.isPending = false - timer.Stop() - timer.modifyingLock.Unlock() -} - -func (timer *Timer) DelSync() { - timer.Del() - timer.runningLock.Lock() - timer.Del() - timer.runningLock.Unlock() -} - -func (timer *Timer) IsPending() bool { - timer.modifyingLock.RLock() - defer timer.modifyingLock.RUnlock() - return timer.isPending -} - -func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 -} - -func expiredRetransmitHandshake(peer *Peer) { - if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) - - if peer.timersActive() { - peer.timers.sendKeepalive.Del() - } - - /* We drop all packets without a keypair and don't try again, - * if we try unsuccessfully for too long to make a handshake. - */ - peer.FlushNonceQueue() - - /* We set a timer for destroying any residue that might be left - * of a partial exchange. - */ - if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { - peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) - } - } else { - atomic.AddUint32(&peer.timers.handshakeAttempts, 1) - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) - - /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() - - peer.SendHandshakeInitiation(true) - } -} - -func expiredSendKeepalive(peer *Peer) { - peer.SendKeepalive() - if peer.timers.needAnotherKeepalive.Get() { - peer.timers.needAnotherKeepalive.Set(false) - if peer.timersActive() { - peer.timers.sendKeepalive.Mod(KeepaliveTimeout) - } - } -} - -func expiredNewHandshake(peer *Peer) { - peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) - /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() - peer.SendHandshakeInitiation(false) - -} - -func expiredZeroKeyMaterial(peer *Peer) { - peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) - peer.ZeroAndFlushAll() -} - -func expiredPersistentKeepalive(peer *Peer) { - if peer.persistentKeepaliveInterval > 0 { - peer.SendKeepalive() - } -} - -/* Should be called after an authenticated data packet is sent. */ -func (peer *Peer) timersDataSent() { - if peer.timersActive() && !peer.timers.newHandshake.IsPending() { - peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) - } -} - -/* Should be called after an authenticated data packet is received. */ -func (peer *Peer) timersDataReceived() { - if peer.timersActive() { - if !peer.timers.sendKeepalive.IsPending() { - peer.timers.sendKeepalive.Mod(KeepaliveTimeout) - } else { - peer.timers.needAnotherKeepalive.Set(true) - } - } -} - -/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ -func (peer *Peer) timersAnyAuthenticatedPacketSent() { - if peer.timersActive() { - peer.timers.sendKeepalive.Del() - } -} - -/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ -func (peer *Peer) timersAnyAuthenticatedPacketReceived() { - if peer.timersActive() { - peer.timers.newHandshake.Del() - } -} - -/* Should be called after a handshake initiation message is sent. */ -func (peer *Peer) timersHandshakeInitiated() { - if peer.timersActive() { - peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) - } -} - -/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ -func (peer *Peer) timersHandshakeComplete() { - if peer.timersActive() { - peer.timers.retransmitHandshake.Del() - } - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) -} - -/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ -func (peer *Peer) timersSessionDerived() { - if peer.timersActive() { - peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) - } -} - -/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ -func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { - peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) - } -} - -func (peer *Peer) timersInit() { - peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) - peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) - peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) - peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) - peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - peer.timers.needAnotherKeepalive.Set(false) -} - -func (peer *Peer) timersStop() { - peer.timers.retransmitHandshake.DelSync() - peer.timers.sendKeepalive.DelSync() - peer.timers.newHandshake.DelSync() - peer.timers.zeroKeyMaterial.DelSync() - peer.timers.persistentKeepalive.DelSync() -} diff --git a/tun.go b/tun.go deleted file mode 100644 index 52bfb68..0000000 --- a/tun.go +++ /dev/null @@ -1,55 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "golang.zx2c4.com/wireguard/tun" - "sync/atomic" -) - -const DefaultMTU = 1420 - -func (device *Device) RoutineTUNEventReader() { - setUp := false - logDebug := device.log.Debug - logInfo := device.log.Info - logError := device.log.Error - - logDebug.Println("Routine: event worker - started") - device.state.starting.Done() - - for event := range device.tun.device.Events() { - if event&tun.TUNEventMTUUpdate != 0 { - mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) - if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) - } - } - - if event&tun.TUNEventUp != 0 && !setUp { - logInfo.Println("Interface set up") - setUp = true - device.Up() - } - - if event&tun.TUNEventDown != 0 && setUp { - logInfo.Println("Interface set down") - setUp = false - device.Down() - } - } - - logDebug.Println("Routine: event worker - stopped") - device.state.stopping.Done() -} diff --git a/tun/helper_test.go b/tun/helper_test.go new file mode 100644 index 0000000..3e86fc8 --- /dev/null +++ b/tun/helper_test.go @@ -0,0 +1,92 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "errors" + "golang.zx2c4.com/wireguard/tun" + "os" + "testing" +) + +/* Helpers for writing unit tests + */ + +type DummyTUN struct { + name string + mtu int + packets chan []byte + events chan tun.TUNEvent +} + +func (tun *DummyTUN) File() *os.File { + return nil +} + +func (tun *DummyTUN) Name() (string, error) { + return tun.name, nil +} + +func (tun *DummyTUN) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *DummyTUN) Write(d []byte, offset int) (int, error) { + tun.packets <- d[offset:] + return len(d), nil +} + +func (tun *DummyTUN) Close() error { + close(tun.events) + close(tun.packets) + return nil +} + +func (tun *DummyTUN) Events() chan tun.TUNEvent { + return tun.events +} + +func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { + t, ok := <-tun.packets + if !ok { + return 0, errors.New("device closed") + } + copy(d[offset:], t) + return len(t), nil +} + +func CreateDummyTUN(name string) (tun.TUNDevice, error) { + var dummy DummyTUN + dummy.mtu = 0 + dummy.packets = make(chan []byte, 100) + dummy.events = make(chan tun.TUNEvent, 10) + return &dummy, nil +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a []byte, b []byte) { + if bytes.Compare(a, b) != 0 { + t.Fatal(a, "!=", b) + } +} + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun, _ := CreateDummyTUN("dummy") + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun, logger) + device.SetPrivateKey(sk) + return device +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 94efe48..d767d79 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -353,4 +353,4 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) GUID() windows.GUID { return *(*windows.GUID)(tun.wt) -} \ No newline at end of file +} diff --git a/uapi.go b/uapi.go deleted file mode 100644 index 4a370b3..0000000 --- a/uapi.go +++ /dev/null @@ -1,425 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "bufio" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync/atomic" - "time" -) - -type IPCError struct { - int64 -} - -func (s *IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) -} - -func (s *IPCError) ErrorCode() int64 { - return s.int64 -} - -func ipcGetOperation(device *Device, socket *bufio.Writer) *IPCError { - - device.log.Debug.Println("UAPI: Processing get operation") - - // create lines - - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) - } - - func() { - - // lock required resources - - device.net.RLock() - defer device.net.RUnlock() - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - - device.peers.RLock() - defer device.peers.RUnlock() - - // serialize device related values - - if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.ToHex()) - } - - if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) - } - - if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) - } - - // serialize each peer state - - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - send("protocol_version=1") - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) - } - - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) - secs := nano / time.Second.Nanoseconds() - nano %= time.Second.Nanoseconds() - - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) - send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - - for _, ip := range device.allowedips.EntriesForPeer(peer) { - send("allowed_ip=" + ip.String()) - } - - } - }() - - // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ipcErrorIO} - } - } - - return nil -} - -func ipcSetOperation(device *Device, socket *bufio.Reader) *IPCError { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug - - var peer *Peer - - dummy := false - deviceConfig := true - - for scanner.Scan() { - - // parse line - - line := scanner.Text() - if line == "" { - return nil - } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{ipcErrorProtocol} - } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipcErrorInvalid} - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipcErrorPortInUse} - } - - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") - deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipcErrorInvalid} - } - } - - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipcErrorInvalid} - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } - - if peer == nil { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipcErrorInvalid} - } - logDebug.Println(peer, "- UAPI: Created") - } - - case "remove": - - // remove currently selected peer from device - - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipcErrorInvalid} - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true - - case "preshared_key": - - // update PSK - - logDebug.Println(peer, "- UAPI: Updating preshared key") - - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() - - if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipcErrorInvalid} - } - - case "endpoint": - - // set endpoint destination - - logDebug.Println(peer, "- UAPI: Updating endpoint") - - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() - - if err != nil { - logError.Println("Failed to set endpoint:", value) - return &IPCError{ipcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipcErrorInvalid} - } - - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } - - case "replace_allowed_ips": - - logDebug.Println(peer, "- UAPI: Removing all allowedips") - - if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipcErrorInvalid} - } - - if dummy { - continue - } - - device.allowedips.RemoveByPeer(peer) - - case "allowed_ip": - - logDebug.Println(peer, "- UAPI: Adding allowedip") - - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipcErrorInvalid} - } - - if dummy { - continue - } - - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) - - case "protocol_version": - - if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipcErrorInvalid} - } - - default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipcErrorInvalid} - } - } - } - - return nil -} - -func ipcHandle(device *Device, socket net.Conn) { - - // create buffered read/writer - - defer socket.Close() - - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) - - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - device.log.Debug.Println("UAPI: Set operation") - status = ipcSetOperation(device, buffered.Reader) - - case "get=1\n": - device.log.Debug.Println("UAPI: Get operation") - status = ipcGetOperation(device, buffered.Writer) - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } - - // write status - - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } -} diff --git a/uapi_bsd.go b/uapi_bsd.go deleted file mode 100644 index d75f4f2..0000000 --- a/uapi_bsd.go +++ /dev/null @@ -1,202 +0,0 @@ -// +build darwin freebsd openbsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "errors" - "fmt" - "golang.org/x/sys/unix" - "net" - "os" - "path" - "unsafe" -) - -var socketDirectory = "/var/run/wireguard" - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error - kqueueFd int - keventFd int -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - err1 := unix.Close(l.kqueueFd) - err2 := unix.Close(l.keventFd) - err3 := l.listener.Close() - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -func (l *UAPIListener) Addr() net.Addr { - return l.listener.Addr() -} - -func UAPIListen(name string, file *os.File) (net.Listener, error) { - - // wrap file in listener - - listener, err := net.FileListener(file) - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - if unixListener, ok := listener.(*net.UnixListener); ok { - unixListener.SetUnlinkOnClose(true) - } - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - // watch for deletion of socket - - uapi.kqueueFd, err = unix.Kqueue() - if err != nil { - return nil, err - } - uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0) - if err != nil { - unix.Close(uapi.kqueueFd) - return nil, err - } - - go func(l *UAPIListener) { - event := unix.Kevent_t{ - Filter: unix.EVFILT_VNODE, - Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, - Fflags: unix.NOTE_WRITE, - } - // Allow this assignment to work with both the 32-bit and 64-bit version - // of the above struct. If you know another way, please submit a patch. - *(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd) - events := make([]unix.Kevent_t, 1) - n := 1 - var kerr error - for { - // start with lstat to avoid race condition - if _, err := os.Lstat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - if kerr != nil || n != 1 { - if kerr != nil { - l.connErr <- kerr - } else { - l.connErr <- errors.New("kqueue returned empty") - } - return - } - n, kerr = unix.Kevent(uapi.kqueueFd, []unix.Kevent_t{event}, events, nil) - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0755) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - oldUmask := unix.Umask(0077) - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - unix.Umask(oldUmask) - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/uapi_linux.go b/uapi_linux.go deleted file mode 100644 index d4b89fc..0000000 --- a/uapi_linux.go +++ /dev/null @@ -1,199 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "errors" - "fmt" - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" - "net" - "os" - "path" -) - -var socketDirectory = "/var/run/wireguard" - -const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketName = "%s.sock" -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error - inotifyFd int - inotifyRWCancel *rwcancel.RWCancel -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - err1 := unix.Close(l.inotifyFd) - err2 := l.inotifyRWCancel.Cancel() - err3 := l.listener.Close() - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -func (l *UAPIListener) Addr() net.Addr { - return l.listener.Addr() -} - -func UAPIListen(name string, file *os.File) (net.Listener, error) { - - // wrap file in listener - - listener, err := net.FileListener(file) - if err != nil { - return nil, err - } - - if unixListener, ok := listener.(*net.UnixListener); ok { - unixListener.SetUnlinkOnClose(true) - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - // watch for deletion of socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - uapi.inotifyFd, err = unix.InotifyInit() - if err != nil { - return nil, err - } - - _, err = unix.InotifyAddWatch( - uapi.inotifyFd, - socketPath, - unix.IN_ATTRIB| - unix.IN_DELETE| - unix.IN_DELETE_SELF, - ) - - if err != nil { - return nil, err - } - - uapi.inotifyRWCancel, err = rwcancel.NewRWCancel(uapi.inotifyFd) - if err != nil { - unix.Close(uapi.inotifyFd) - return nil, err - } - - go func(l *UAPIListener) { - var buff [0]byte - for { - // start with lstat to avoid race condition - if _, err := os.Lstat(socketPath); os.IsNotExist(err) { - l.connErr <- err - return - } - _, err := uapi.inotifyRWCancel.Read(buff[:]) - if err != nil { - l.connErr <- err - return - } - } - }(uapi) - - // watch for new connections - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0755) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - oldUmask := unix.Umask(0077) - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - unix.Umask(oldUmask) - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/uapi_windows.go b/uapi_windows.go deleted file mode 100644 index 64917f5..0000000 --- a/uapi_windows.go +++ /dev/null @@ -1,76 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package main - -import ( - "github.com/Microsoft/go-winio" - "net" -) - -//TODO: replace these with actual standard windows error numbers from the win package -const ( - ipcErrorIO = -int64(5) - ipcErrorProtocol = -int64(71) - ipcErrorInvalid = -int64(22) - ipcErrorPortInUse = -int64(98) -) - -type UAPIListener struct { - listener net.Listener // unix socket listener - connNew chan net.Conn - connErr chan error - kqueueFd int - keventFd int -} - -func (l *UAPIListener) Accept() (net.Conn, error) { - for { - select { - case conn := <-l.connNew: - return conn, nil - - case err := <-l.connErr: - return nil, err - } - } -} - -func (l *UAPIListener) Close() error { - return l.listener.Close() -} - -func (l *UAPIListener) Addr() net.Addr { - return l.listener.Addr() -} - -func UAPIListen(name string) (net.Listener, error) { - config := winio.PipeConfig{ - SecurityDescriptor: "", //TODO: we want this to be a very locked down pipe. - } - listener, err := winio.ListenPipe("\\\\.\\pipe\\wireguard\\"+name, &config) //TODO: choose sane name. - if err != nil { - return nil, err - } - - uapi := &UAPIListener{ - listener: listener, - connNew: make(chan net.Conn, 1), - connErr: make(chan error, 1), - } - - go func(l *UAPIListener) { - for { - conn, err := l.listener.Accept() - if err != nil { - l.connErr <- err - break - } - l.connNew <- conn - } - }(uapi) - - return uapi, nil -} -- cgit v1.2.3-59-g8ed1b