diff options
Diffstat (limited to 'device')
48 files changed, 3070 insertions, 3609 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index efc27c0..fa46f97 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,173 +1,201 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( + "container/list" + "encoding/binary" "errors" "math/bits" "net" + "net/netip" "sync" "unsafe" ) -type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -func isLittleEndian() bool { - one := uint32(1) - return *(*byte)(unsafe.Pointer(&one)) != 0 +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 } -func swapU32(i uint32) uint32 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes32(i) -} - -func swapU64(i uint64) uint64 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes64(i) +type trieEntry struct { + peer *Peer + child [2]*trieEntry + parent parentIndirection + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits []byte + perPeerElem *list.Element } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { - a := (*uint32)(unsafe.Pointer(&ip1[0])) - b := (*uint32)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + a := binary.BigEndian.Uint32(ip1) + b := binary.BigEndian.Uint32(ip2) + x := a ^ b + return uint8(bits.LeadingZeros32(x)) } else if size == net.IPv6len { - a := (*uint64)(unsafe.Pointer(&ip1[0])) - b := (*uint64)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b + a := binary.BigEndian.Uint64(ip1) + b := binary.BigEndian.Uint64(ip2) + x := a ^ b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(x)) } - a = (*uint64)(unsafe.Pointer(&ip1[8])) - b = (*uint64)(unsafe.Pointer(&ip2[8])) - x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + a = binary.BigEndian.Uint64(ip1[8:]) + b = binary.BigEndian.Uint64(ip2[8:]) + x = a ^ b + return 64 + uint8(bits.LeadingZeros64(x)) } else { panic("Wrong size bit string") } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) +func (node *trieEntry) addToPeerEntries() { + node.perPeerElem = node.peer.trieEntries.PushBack(node) +} - if node.peer != p { - return node +func (node *trieEntry) removeFromPeerEntries() { + if node.perPeerElem != nil { + node.peer.trieEntries.Remove(node.perPeerElem) + node.perPeerElem = nil } +} - // remove peer & merge +func (node *trieEntry) choose(ip []byte) byte { + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 +} - node.peer = nil - if node.child[0] == nil { - return node.child[1] +func (node *trieEntry) maskSelf() { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + for i := 0; i < len(mask); i++ { + node.bits[i] &= mask[i] } - return node.child[0] } -func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +func (node *trieEntry) zeroizePointers() { + // Make the garbage collector's life slightly easier + node.peer = nil + node.child[0] = nil + node.child[1] = nil + node.parent.parentBit = nil } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { - - // at leaf - - if node == nil { - return &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return } + bit := node.choose(ip) + node = node.child[bit] } + return +} - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { + node := &trieEntry{ + peer: peer, + parent: trie, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node.maskSelf() + node.addToPeerEntries() + *trie.parentBit = node + return + } + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - - // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + peer: peer, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } + newNode.maskSelf() + newNode.addToPeerEntries() - cidr = min(cidr, common) - - // check for shorter prefix + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) + if common < cidr { + cidr = common + } + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), + } + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer { return found } -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed -} - -func (table *AllowedIPs) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil + for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { + node := elem.Value.(*trieEntry) + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { + return + } + } } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() + } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.lookup(address) + switch len(ip) { + case net.IPv6len: + return table.IPv6.lookup(ip) + case net.IPv4len: + return table.IPv4.lookup(ip) + default: + panic(errors.New("looking up unknown address type")) + } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 59c10f7..07065c3 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,25 +1,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" + "net" + "net/netip" "sort" "testing" ) const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 + NumberOfPeers = 100 + NumberOfPeerRemovals = 4 + NumberOfAddresses = 250 + NumberOfTests = 10000 ) type SlowNode struct { peer *Peer - cidr uint + cidr uint8 bits []byte } @@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { +func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer @@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ } } + return r[:n] } -func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry - var slow SlowRouter +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { + for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + for n := 0; n < NumberOfAddresses; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + var p int + for p = 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.Lookup(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.Lookup(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } + } + if p >= len(peers) || p >= NumberOfPeerRemovals { + break } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + for ; p < len(peers); p++ { + allowedIPs.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 075ff06..cde068e 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,40 +8,17 @@ package device import ( "math/rand" "net" + "net/netip" "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) + match uint8 } func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, @@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} rand.Seed(1) const AddressLength = 4 - for n := 0; n < peerNumber; n += 1 { + for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < addressNumber; n += 1 { + for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + root.insert(addr[:], cidr, peers[index]) } - for n := 0; n < b.N; n += 1 { + for n := 0; n < b.N; n++ { var addr [AddressLength]byte rand.Read(addr[:]) trie.lookup(addr[:]) @@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } diff --git a/device/bind_test.go b/device/bind_test.go index 0c2e2cf..302a521 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -1,23 +1,24 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -import "errors" +import ( + "errors" + + "golang.zx2c4.com/wireguard/conn" +) type DummyDatagram struct { msg []byte - endpoint Endpoint - world bool // better type + endpoint conn.Endpoint } type DummyBind struct { in6 chan DummyDatagram - ou6 chan DummyDatagram in4 chan DummyDatagram - ou4 chan DummyDatagram closed bool } @@ -25,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } @@ -50,6 +51,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end Endpoint) error { +func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { return nil } diff --git a/device/boundif_android.go b/device/boundif_android.go deleted file mode 100644 index 6d0fecf..0000000 --- a/device/boundif_android.go +++ /dev/null @@ -1,44 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import "errors" - -func (device *Device) PeekLookAtSocketFd4() (fd int, err error) { - nb, ok := device.net.bind.(*nativeBind) - if !ok { - return 0, errors.New("no socket exists") - } - sysconn, err := nb.ipv4.SyscallConn() - if err != nil { - return - } - err = sysconn.Control(func(f uintptr) { - fd = int(f) - }) - if err != nil { - return - } - return -} - -func (device *Device) PeekLookAtSocketFd6() (fd int, err error) { - nb, ok := device.net.bind.(*nativeBind) - if !ok { - return 0, errors.New("no socket exists") - } - sysconn, err := nb.ipv6.SyscallConn() - if err != nil { - return - } - err = sysconn.Control(func(f uintptr) { - fd = int(f) - }) - if err != nil { - return - } - return -} diff --git a/device/boundif_windows.go b/device/boundif_windows.go deleted file mode 100644 index 6908415..0000000 --- a/device/boundif_windows.go +++ /dev/null @@ -1,64 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "encoding/binary" - "errors" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - sockoptIP_UNICAST_IF = 31 - sockoptIPV6_UNICAST_IF = 31 -) - -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, interfaceIndex) - interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - device.net.bind.(*nativeBind).blackhole4 = blackhole - return nil -} - -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - device.net.bind.(*nativeBind).blackhole6 = blackhole - return nil -} diff --git a/device/channels.go b/device/channels.go new file mode 100644 index 0000000..e526f6b --- /dev/null +++ b/device/channels.go @@ -0,0 +1,137 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "runtime" + "sync" +) + +// An outboundQueue is a channel of QueueOutboundElements awaiting encryption. +// An outboundQueue is ref-counted using its wg field. +// An outboundQueue created with newOutboundQueue has one reference. +// Every additional writer must call wg.Add(1). +// Every completed writer must call wg.Done(). +// When no further writers will be added, +// call wg.Done to remove the initial reference. +// When the refcount hits 0, the queue's channel is closed. +type outboundQueue struct { + c chan *QueueOutboundElementsContainer + wg sync.WaitGroup +} + +func newOutboundQueue() *outboundQueue { + q := &outboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A inboundQueue is similar to an outboundQueue; see those docs. +type inboundQueue struct { + c chan *QueueInboundElementsContainer + wg sync.WaitGroup +} + +func newInboundQueue() *inboundQueue { + q := &inboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A handshakeQueue is similar to an outboundQueue; see those docs. +type handshakeQueue struct { + c chan QueueHandshakeElement + wg sync.WaitGroup +} + +func newHandshakeQueue() *handshakeQueue { + q := &handshakeQueue{ + c: make(chan QueueHandshakeElement, QueueHandshakeSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +type autodrainingInboundQueue struct { + c chan *QueueInboundElementsContainer +} + +// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { + q := &autodrainingInboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + runtime.SetFinalizer(q, device.flushInboundQueue) + return q +} + +func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) + default: + return + } + } +} + +type autodrainingOutboundQueue struct { + c chan *QueueOutboundElementsContainer +} + +// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +// All sends to the channel must be best-effort, because there may be no receivers. +func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { + q := &autodrainingOutboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + runtime.SetFinalizer(q, device.flushOutboundQueue) + return q +} + +func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsContainer) + default: + return + } + } +} diff --git a/device/conn.go b/device/conn.go deleted file mode 100644 index 7b341f6..0000000 --- a/device/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "net" - "strings" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" -) - -const ( - ConnRoutineNumber = 2 -) - -/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic - */ -type Bind interface { - SetMark(value uint32) error - ReceiveIPv6(buff []byte) (int, Endpoint, error) - ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error - Close() error -} - -/* An Endpoint maintains the source/destination caching for a peer - * - * dst : the remote address of a peer ("endpoint" in uapi terminology) - * src : the local address from which datagrams originate going to the peer - */ -type Endpoint interface { - ClearSrc() // clears the source address - SrcToString() string // returns the local source address (ip:port) - DstToString() string // returns the destination address (ip:port) - DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP -} - -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address - - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err - } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] - } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) - } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err - } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 - } - return addr, err -} - -func unsafeCloseBind(device *Device) error { - var err error - netc := &device.net - if netc.bind != nil { - err = netc.bind.Close() - netc.bind = nil - } - netc.stopping.Wait() - return err -} - -func (device *Device) BindSetMark(mark uint32) error { - - device.net.Lock() - defer device.net.Unlock() - - // check if modified - - if device.net.fwmark == mark { - return nil - } - - // update fwmark on existing bind - - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { - if err := device.net.bind.SetMark(mark); err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - return nil -} - -func (device *Device) BindUpdate() error { - - device.net.Lock() - defer device.net.Unlock() - - // close existing sockets - - if err := unsafeCloseBind(device); err != nil { - return err - } - - // open new sockets - - if device.isUp.Get() { - - // bind to new port - - var err error - netc := &device.net - netc.bind, netc.port, err = CreateBind(netc.port, device) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(ConnRoutineNumber) - device.net.stopping.Add(ConnRoutineNumber) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() - - device.log.Debug.Println("UDP bind has been updated") - } - - return nil -} - -func (device *Device) BindClose() error { - device.net.Lock() - err := unsafeCloseBind(device) - device.net.Unlock() - return err -} diff --git a/device/conn_default.go b/device/conn_default.go deleted file mode 100644 index 661f57d..0000000 --- a/device/conn_default.go +++ /dev/null @@ -1,178 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "net" - "os" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - - // listen - - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // retrieve port - - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func extractErrno(err error) error { - opErr, ok := err.(*net.OpError) - if !ok { - return nil - } - syscallErr, ok := opErr.Err.(*os.SyscallError) - if !ok { - return nil - } - return syscallErr.Err -} - -func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { - var err error - var bind nativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole4 { - return nil - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole6 { - return nil - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/device/conn_linux.go b/device/conn_linux.go deleted file mode 100644 index f74ad51..0000000 --- a/device/conn_linux.go +++ /dev/null @@ -1,757 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - * - * This implements userspace semantics of "sticky sockets", modeled after - * WireGuard's kernelspace implementation. This is more or less a straight port - * of the sticky-sockets.c example code: - * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c - * - * Currently there is no way to achieve this within the net package: - * See e.g. https://github.com/golang/go/issues/17930 - * So this code is remains platform dependent. - */ - -package device - -import ( - "errors" - "net" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" - "golang.zx2c4.com/wireguard/rwcancel" -) - -const ( - FD_ERR = -1 -) - -type IPv4Source struct { - src [4]byte - ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type nativeBind struct { - sock4 int - sock6 int - netlinkSock int - netlinkCancel *rwcancel.RWCancel - lastMark uint32 -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) - if err != nil { - return -1, err - } - saddr := &unix.SockaddrNetlink{ - Family: unix.AF_NETLINK, - Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), - } - err = unix.Bind(sock, saddr) - if err != nil { - unix.Close(sock) - return -1, err - } - return sock, nil - -} - -func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) { - var err error - var bind nativeBind - var newPort uint16 - - bind.netlinkSock, err = createNetlinkRouteSocket() - if err != nil { - return nil, 0, err - } - bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) - if err != nil { - unix.Close(bind.netlinkSock) - return nil, 0, err - } - - go bind.routineRouteListener(device) - - // attempt ipv6 bind, update port if succesful - - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - return nil, 0, err - } - } else { - port = newPort - } - - // attempt ipv4 bind, update port if succesful - - bind.sock4, newPort, err = create4(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - bind.netlinkCancel.Cancel() - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *nativeBind) SetMark(value uint32) error { - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - -func (bind *nativeBind) Close() error { - var err1, err2, err3 error - if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) - } - if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) - } - err3 = bind.netlinkCancel.Cancel() - - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return syscall.EAFNOSUPPORT - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return syscall.EAFNOSUPPORT - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().src[0], - end.src4().src[1], - end.src4().src[2], - end.src4().src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().src, - Ifindex: end.src4().ifindex, - }, - } - - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().src = cmsg.pktinfo.Spec_dst - end.src4().ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // contruct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func (bind *nativeBind) routineRouteListener(device *Device) { - type peerEndpointPtr struct { - peer *Peer - endpoint *Endpoint - } - var reqPeer map[uint32]peerEndpointPtr - var reqPeerLock sync.Mutex - - defer unix.Close(bind.netlinkSock) - - for msg := make([]byte, 1<<16); ; { - var err error - var msgn int - for { - msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) - if err == nil || !rwcancel.RetryAfterError(err) { - break - } - if !bind.netlinkCancel.ReadyRead() { - return - } - } - if err != nil { - return - } - - for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { - - hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) - - if uint(hdr.Len) > uint(len(remain)) { - break - } - - switch hdr.Type { - case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: - if hdr.Seq <= MaxPeers && hdr.Seq > 0 { - if uint(len(remain)) < uint(hdr.Len) { - break - } - if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { - attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] - for { - if uint(len(attr)) < uint(unix.SizeofRtAttr) { - break - } - attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) - if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { - break - } - if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { - ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) - reqPeerLock.Lock() - if reqPeer == nil { - reqPeerLock.Unlock() - break - } - pePtr, ok := reqPeer[hdr.Seq] - reqPeerLock.Unlock() - if !ok { - break - } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() - break - } - if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { - pePtr.peer.Unlock() - break - } - pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() - } - attr = attr[attrhdr.Len:] - } - } - break - } - reqPeerLock.Lock() - reqPeer = make(map[uint32]peerEndpointPtr) - reqPeerLock.Unlock() - go func() { - device.peers.RLock() - i := uint32(1) - for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { - peer.RUnlock() - continue - } - if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { - peer.RUnlock() - break - } - nlmsg := struct { - hdr unix.NlMsghdr - msg unix.RtMsg - dsthdr unix.RtAttr - dst [4]byte - srchdr unix.RtAttr - src [4]byte - markhdr unix.RtAttr - mark uint32 - }{ - unix.NlMsghdr{ - Type: uint16(unix.RTM_GETROUTE), - Flags: unix.NLM_F_REQUEST, - Seq: i, - }, - unix.RtMsg{ - Family: unix.AF_INET, - Dst_len: 32, - Src_len: 32, - }, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_DST, - }, - peer.endpoint.(*NativeEndpoint).dst4().Addr, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_SRC, - }, - peer.endpoint.(*NativeEndpoint).src4().src, - unix.RtAttr{ - Len: 8, - Type: unix.RTA_MARK, - }, - uint32(bind.lastMark), - } - nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) - reqPeerLock.Lock() - reqPeer[i] = peerEndpointPtr{ - peer: peer, - endpoint: &peer.endpoint, - } - reqPeerLock.Unlock() - peer.RUnlock() - i++ - _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) - if err != nil { - break - } - } - device.peers.RUnlock() - }() - } - remain = remain[hdr.Len:] - } - } -} diff --git a/device/constants.go b/device/constants.go index e316f32..59854a1 100644 --- a/device/constants.go +++ b/device/constants.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -12,8 +12,8 @@ import ( /* Specification constants */ const ( - RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterMessages = (1 << 60) + RejectAfterMessages = (1 << 64) - (1 << 13) - 1 RekeyAfterTime = time.Second * 120 RekeyAttemptTime = time.Second * 90 RekeyTimeout = time.Second * 5 @@ -35,7 +35,6 @@ const ( /* Implementation constants */ const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 UnderLoadAfterTime = time.Second // how long does the device remain under load after detected MaxPeers = 1 << 16 // maximum number of configured peers ) diff --git a/device/cookie.go b/device/cookie.go index f134128..876f05d 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool { return hmac.Equal(mac1[:], msg[smac1:smac2]) } -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { +func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { st.RLock() defer st.RUnlock() @@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply( recv uint32, src []byte, ) (*MessageCookieReply, error) { - st.RLock() // refresh cookie secret @@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) - if err != nil { return false } @@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { } func (st *CookieGenerator) AddMacs(msg []byte) { - size := len(msg) smac2 := size - blake2s.Size128 diff --git a/device/cookie_test.go b/device/cookie_test.go index 79a6a86..4f1e50a 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -10,7 +10,6 @@ import ( ) func TestCookieMAC1(t *testing.T) { - // setup generator / checker var ( @@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) { msg[5] ^= 0x20 - srcBad1 := []byte{192, 168, 13, 37, 40, 01} + srcBad1 := []byte{192, 168, 13, 37, 40, 1} if checker.CheckMAC2(msg, srcBad1) { t.Fatal("MAC2 generation/verification failed") } - srcBad2 := []byte{192, 168, 13, 38, 40, 01} + srcBad2 := []byte{192, 168, 13, 38, 40, 1} if checker.CheckMAC2(msg, srcBad2) { t.Fatal("MAC2 generation/verification failed") } diff --git a/device/device.go b/device/device.go index 569c5a8..83c33ee 100644 --- a/device/device.go +++ b/device/device.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -11,37 +11,40 @@ import ( "sync/atomic" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - state struct { - starting sync.WaitGroup + // state holds the device's state. It is accessed atomically. + // Use the device.deviceState method to read it. + // device.deviceState does not acquire the mutex, so it captures only a snapshot. + // During state transitions, the state variable is updated before the device itself. + // The state is thus either the current state of the device or + // the intended future state of the device. + // For example, while executing a call to Up, state will be deviceStateUp. + // There is no guarantee that that intended future state of the device + // will become the actual state; Up can fail. + // The device can also change state multiple times between time of check and time of use. + // Unsynchronized uses of state must therefore be advisory/best-effort only. + state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience + // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup + // mu protects state changes. sync.Mutex - changing AtomicBool - current bool } net struct { - starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + brokenRoaming bool } staticIdentity struct { @@ -51,153 +54,176 @@ type Device struct { } peers struct { - sync.RWMutex - keyMap map[NoisePublicKey]*Peer + sync.RWMutex // protects keyMap + keyMap map[NoisePublicKey]*Peer } - // unprotected / "self-synchronising resources" + rate struct { + underLoadUntil atomic.Int64 + limiter ratelimiter.Ratelimiter + } allowedips AllowedIPs indexTable IndexTable cookieChecker CookieChecker - rate struct { - underLoadUntil atomic.Value - limiter ratelimiter.Ratelimiter - } - pool struct { - messageBufferPool *sync.Pool - messageBufferReuseChan chan *[MaxMessageSize]byte - inboundElementPool *sync.Pool - inboundElementReuseChan chan *QueueInboundElement - outboundElementPool *sync.Pool - outboundElementReuseChan chan *QueueOutboundElement + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signals struct { - stop chan struct{} + encryption *outboundQueue + decryption *inboundQueue + handshake *handshakeQueue } tun struct { device tun.Device - mtu int32 + mtu atomic.Int32 } + + ipcMutex sync.RWMutex + closed chan struct{} + log *Logger } -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold device.peers.Mutex - */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { +// deviceState represents the state of a Device. +// There are three states: down, up, closed. +// Transitions: +// +// down -----+ +// ↑↓ ↓ +// up -> closed +type deviceState uint32 - // stop routing and processing of packets +//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState +const ( + deviceStateDown deviceState = iota + deviceStateUp + deviceStateClosed +) + +// deviceState returns device.state.state as a deviceState +// See those docs for how to interpret this value. +func (device *Device) deviceState() deviceState { + return deviceState(device.state.state.Load()) +} + +// isClosed reports whether the device is closed (or is closing). +// See device.state.state comments for how to interpret this value. +func (device *Device) isClosed() bool { + return device.deviceState() == deviceStateClosed +} + +// isUp reports whether the device is up (or is attempting to come up). +// See device.state.state comments for how to interpret this value. +func (device *Device) isUp() bool { + return device.deviceState() == deviceStateUp +} +// Must hold device.peers.Lock() +func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { + // stop routing and processing of packets device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map - delete(device.peers.keyMap, key) } -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - // compare to current state of device - +// changeState attempts to change the device state to match want. +func (device *Device) changeState(want deviceState) (err error) { device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() - return + defer device.state.Unlock() + old := device.deviceState() + if old == deviceStateClosed { + // once closed, always closed + device.log.Verbosef("Interface closed, ignored requested state %s", want) + return nil } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.log.Error.Printf("Unable to update bind: %v\n", err) - device.isUp.Set(false) + switch want { + case old: + return nil + case deviceStateUp: + device.state.state.Store(uint32(deviceStateUp)) + err = device.upLocked() + if err == nil { break } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if peer.persistentKeepaliveInterval > 0 { - peer.SendKeepalive() - } + fallthrough // up failed; bring the device all the way back down + case deviceStateDown: + device.state.state.Store(uint32(deviceStateDown)) + errDown := device.downLocked() + if err == nil { + err = errDown } - device.peers.RUnlock() - - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() - } - device.peers.RUnlock() } + device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) + return +} - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() +// upLocked attempts to bring the device up and reports whether it succeeded. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) upLocked() error { + if err := device.BindUpdate(); err != nil { + device.log.Errorf("Unable to update bind: %v", err) + return err + } - // check for state change in the mean time + // The IPC set operation waits for peers to be created before calling Start() on them, + // so if there's a concurrent IPC set request happening, we should wait for it to complete. + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - deviceUpdateState(device) + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval.Load() > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + return nil } -func (device *Device) Up() { - - // closed device cannot be brought up +// downLocked attempts to bring the device down. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) downLocked() error { + err := device.BindClose() + if err != nil { + device.log.Errorf("Bind close failed: %v", err) + } - if device.isClosed.Get() { - return + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() } + device.peers.RUnlock() + return err +} - device.isUp.Set(true) - deviceUpdateState(device) +func (device *Device) Up() error { + return device.changeState(deviceStateUp) } -func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) +func (device *Device) Down() error { + return device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { - // check if currently under load - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 if underLoad { - device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) return true } - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) + return device.rate.underLoadUntil.Load() > now.UnixNano() } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { @@ -224,7 +250,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { if peer.handshake.remoteStatic.Equals(publicKey) { - unsafeRemovePeer(device, peer, key) + peer.handshake.mutex.RUnlock() + removePeerLocked(device, peer, key) + peer.handshake.mutex.RLock() } } @@ -236,23 +264,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // do static-static DH pre-computations - rmKey := device.staticIdentity.privateKey.IsZero() - expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) - for key, peer := range device.peers.keyMap { + for _, peer := range device.peers.keyMap { handshake := &peer.handshake - - if rmKey { - handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} - } else { - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) - } - - if isZero(handshake.precomputedStaticStatic[:]) { - unsafeRemovePeer(device, peer, key) - } else { - expiredPeers = append(expiredPeers, peer) - } + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + expiredPeers = append(expiredPeers, peer) } for _, peer := range lockedPeers { @@ -265,68 +281,63 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - + device.state.state.Store(uint32(deviceStateDown)) + device.closed = make(chan struct{}) device.log = logger - + device.net.bind = bind device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { - logger.Error.Println("Trouble determining MTU, assuming default:", err) + device.log.Errorf("Trouble determining MTU, assuming default: %v", err) mtu = DefaultMTU } - device.tun.mtu = int32(mtu) - + device.tun.mtu.Store(int32(mtu)) device.peers.keyMap = make(map[NoisePublicKey]*Peer) - device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - device.indexTable.Init() - device.allowedips.Reset() device.PopulatePools() // create queues - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signals.stop = make(chan struct{}) - - // prepare net - - device.net.port = 0 - device.net.bind = nil + device.queue.handshake = newHandshakeQueue() + device.queue.encryption = newOutboundQueue() + device.queue.decryption = newInboundQueue() // start workers cpus := runtime.NumCPU() - device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - for i := 0; i < cpus; i += 1 { - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() + device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake + for i := 0; i < cpus; i++ { + go device.RoutineEncryption(i + 1) + go device.RoutineDecryption(i + 1) + go device.RoutineHandshake(i + 1) } + device.state.stopping.Add(1) // RoutineReadFromTUN + device.queue.encryption.wg.Add(1) // RoutineReadFromTUN go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() - device.state.starting.Wait() - return device } +// BatchSize returns the BatchSize for the device as a whole which is the max of +// the bind batch size and the tun batch size. The batch size reported by device +// is the size used to construct memory pools, and is the allowed batch size for +// the lifetime of the device. +func (device *Device) BatchSize() int { + size := device.net.bind.BatchSize() + dSize := device.tun.device.BatchSize() + if size < dSize { + size = dSize + } + return size +} + func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() @@ -341,7 +352,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) { peer, ok := device.peers.keyMap[key] if ok { - unsafeRemovePeer(device, peer, key) + removePeerLocked(device, peer, key) } } @@ -350,67 +361,50 @@ func (device *Device) RemoveAllPeers() { defer device.peers.Unlock() for key, peer := range device.peers.keyMap { - unsafeRemovePeer(device, peer, key) + removePeerLocked(device, peer, key) } device.peers.keyMap = make(map[NoisePublicKey]*Peer) } -func (device *Device) FlushPacketQueues() { - for { - select { - case elem, ok := <-device.queue.decryption: - if ok { - elem.Drop() - } - case elem, ok := <-device.queue.encryption: - if ok { - elem.Drop() - } - case <-device.queue.handshake: - default: - return - } - } - -} - func (device *Device) Close() { - if device.isClosed.Swap(true) { - return - } - - device.state.starting.Wait() - - device.log.Info.Println("Device closing") - device.state.changing.Set(true) device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() + if device.isClosed() { + return + } + device.state.state.Store(uint32(deviceStateClosed)) + device.log.Verbosef("Device closing") device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) - - close(device.signals.stop) + device.downLocked() + // Remove peers before closing queues, + // because peers assume that queues are active. device.RemoveAllPeers() + // We kept a reference to the encryption and decryption queues, + // in case we started any new peers that might write to them. + // No new peers are coming; we are done with these queues. + device.queue.encryption.wg.Done() + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.state.stopping.Wait() - device.FlushPacketQueues() device.rate.limiter.Close() - device.state.changing.Set(false) - device.log.Info.Println("Interface closed") + device.log.Verbosef("Device closed") + close(device.closed) } func (device *Device) Wait() chan struct{} { - return device.signals.stop + return device.closed } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { - if device.isClosed.Get() { + if !device.isUp() { return } @@ -425,3 +419,118 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +// closeBindLocked closes the device's net.bind. +// The caller must hold the net mutex. +func closeBindLocked(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + } + netc.stopping.Wait() + return err +} + +func (device *Device) Bind() conn.Bind { + device.net.Lock() + defer device.net.Unlock() + return device.net.bind +} + +func (device *Device) BindSetMark(mark uint32) error { + device.net.Lock() + defer device.net.Unlock() + + // check if modified + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + device.net.fwmark = mark + if device.isUp() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.markEndpointSrcForClearing() + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + if err := closeBindLocked(device); err != nil { + return err + } + + // open new sockets + if !device.isUp() { + return nil + } + + // bind to new port + var err error + var recvFns []conn.ReceiveFunc + netc := &device.net + + recvFns, netc.port, err = netc.bind.Open(netc.port) + if err != nil { + netc.port = 0 + return err + } + + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.port = 0 + return err + } + + // set fwmark + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.markEndpointSrcForClearing() + } + device.peers.RUnlock() + + // start receiving routines + device.net.stopping.Add(len(recvFns)) + device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + batchSize := netc.bind.BatchSize() + for _, fn := range recvFns { + go device.RoutineReceiveIncoming(batchSize, fn) + } + + device.log.Verbosef("UDP bind has been updated") + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := closeBindLocked(device) + device.net.Unlock() + return err +} diff --git a/device/device_test.go b/device/device_test.go index 14cc605..fff172b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,238 +1,476 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "bufio" "bytes" - "encoding/binary" + "encoding/hex" + "fmt" "io" - "net" + "math/rand" + "net/netip" "os" - "strings" + "runtime" + "runtime/pprof" + "sync" + "sync/atomic" "testing" "time" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/conn/bindtest" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/tuntest" ) -func TestTwoDevicePing(t *testing.T) { - // TODO(crawshaw): pick unused ports on localhost - cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58 -listen_port=53511 -replace_peers=true -public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.2/32 -endpoint=127.0.0.1:53512` - tun1 := NewChannelTUN() - dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) - dev1.Up() - defer dev1.Close() - if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil { - t.Fatal(err) - } - - cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768 -listen_port=53512 -replace_peers=true -public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.1/32 -endpoint=127.0.0.1:53511` - tun2 := NewChannelTUN() - dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) - dev2.Up() - defer dev2.Close() - if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil { - t.Fatal(err) +// uapiCfg returns a string that contains cfg formatted use with IpcSet. +// cfg is a series of alternating key/value strings. +// uapiCfg exists because editors and humans like to insert +// whitespace into configs, which can cause failures, some of which are silent. +// For example, a leading blank newline causes the remainder +// of the config to be silently ignored. +func uapiCfg(cfg ...string) string { + if len(cfg)%2 != 0 { + panic("odd number of args to uapiReader") } - - t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun2.Outbound <- msg2to1 - select { - case msgRecv := <-tun1.Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") - } - case <-time.After(300 * time.Millisecond): - t.Error("ping did not transit") + buf := new(bytes.Buffer) + for i, s := range cfg { + buf.WriteString(s) + sep := byte('\n') + if i%2 == 0 { + sep = '=' } - }) + buf.WriteByte(sep) + } + return buf.String() +} - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun1.Outbound <- msg1to2 - select { - case msgRecv := <-tun2.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") - } - case <-time.After(300 * time.Millisecond): - t.Error("return ping did not transit") - } - }) +// genConfigs generates a pair of configs that connect to each other. +// The configs use distinct, probably-usable ports. +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { + var key1, key2 NoisePrivateKey + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + pub1, pub2 := key1.publicKey(), key2.publicKey() + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.2/32", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", "127.0.0.1:%d", + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.1/32", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", "127.0.0.1:%d", + ) + return +} + +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. +type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip netip.Addr } -func ping(dst, src net.IP) []byte { - localPort := uint16(1337) - seq := uint16(0) +type SendDirection bool - payload := make([]byte, 4) - binary.BigEndian.PutUint16(payload[0:], localPort) - binary.BigEndian.PutUint16(payload[2:], seq) +const ( + Ping SendDirection = true + Pong SendDirection = false +) - return genICMPv4(payload, dst, src) +func (d SendDirection) String() string { + if d == Ping { + return "ping" + } + return "pong" } -// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. -func checksum(buf []byte, initial uint16) uint16 { - v := uint32(initial) - for i := 0; i < len(buf)-1; i += 2 { - v += uint32(binary.BigEndian.Uint16(buf[i:])) +func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { + tb.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 } - if len(buf)%2 == 1 { - v += uint32(buf[len(buf)-1]) << 8 + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = fmt.Errorf("%s did not transit correctly", ping) + } + case <-timer.C: + err = fmt.Errorf("%s did not transit", ping) + case <-done: } - for v > 0xffff { - v = (v >> 16) + (v & 0xffff) + if err != nil { + // The error may have occurred because the test is done. + select { + case <-done: + return + default: + } + // Real error. + tb.Error(err) } - return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { - const ( - icmpv4ProtocolNumber = 1 - icmpv4Echo = 8 - icmpv4ChecksumOffset = 2 - icmpv4Size = 8 - ipv4Size = 20 - ipv4TotalLenOffset = 2 - ipv4ChecksumOffset = 10 - ttl = 65 - ) +// genTestPair creates a testPair. +func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { + cfg, endpointCfg := genConfigs(tb) + var binds [2]conn.Bind + if realSocket { + binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() + } else { + binds = bindtest.NewChannelBinds() + } + // Bring up a ChannelTun for each config. + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) + level := LogLevelVerbose + if _, ok := tb.(*testing.B); ok && !testing.Verbose() { + level = LogLevelError + } + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + if err := p.dev.IpcSet(cfg[i]); err != nil { + tb.Errorf("failed to configure device %d: %v", i, err) + p.dev.Close() + continue + } + if err := p.dev.Up(); err != nil { + tb.Errorf("failed to bring up device %d: %v", i, err) + p.dev.Close() + continue + } + endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) + } + for i := range pair { + p := &pair[i] + if err := p.dev.IpcSet(endpointCfg[i]); err != nil { + tb.Errorf("failed to configure device endpoint %d: %v", i, err) + p.dev.Close() + continue + } + // The device is ready. Close it when the test completes. + tb.Cleanup(p.dev.Close) + } + return +} + +func TestTwoDevicePing(t *testing.T) { + goroutineLeakCheck(t) + pair := genTestPair(t, true) + t.Run("ping 1.0.0.1", func(t *testing.T) { + pair.Send(t, Ping, nil) + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + pair.Send(t, Pong, nil) + }) +} + +func TestUpDown(t *testing.T) { + goroutineLeakCheck(t) + const itrials = 50 + const otrials = 10 + + for n := 0; n < otrials; n++ { + pair := genTestPair(t, false) + for i := range pair { + for k := range pair[i].dev.peers.keyMap { + pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) + } + } + var wg sync.WaitGroup + wg.Add(len(pair)) + for i := range pair { + go func(d *Device) { + defer wg.Done() + for i := 0; i < itrials; i++ { + if err := d.Up(); err != nil { + t.Errorf("failed up bring up device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + if err := d.Down(); err != nil { + t.Errorf("failed to bring down device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + } + }(pair[i].dev) + } + wg.Wait() + for i := range pair { + pair[i].dev.Up() + pair[i].dev.Close() + } + } +} - hdr := make([]byte, ipv4Size+icmpv4Size) +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t, true) + done := make(chan struct{}) - ip := hdr[0:ipv4Size] - icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. + i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- + } + } + }() + warmup.Wait() - // https://tools.ietf.org/html/rfc792 - icmpv4[0] = icmpv4Echo // type - icmpv4[1] = 0 // code - chksum := ^checksum(icmpv4, checksum(payload, 0)) - binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) + applyCfg := func(cfg string) { + err := pair[0].dev.IpcSet(cfg) + if err != nil { + t.Fatal(err) + } + } - // https://tools.ietf.org/html/rfc760 section 3.1 - length := uint16(len(hdr) + len(payload)) - ip[0] = (4 << 4) | (ipv4Size / 4) - binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) - ip[8] = ttl - ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) - chksum = ^checksum(ip[:], 0) - binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) + // Change persistent_keepalive_interval concurrently with tunnel use. + t.Run("persistentKeepaliveInterval", func(t *testing.T) { + var pub NoisePublicKey + for key := range pair[0].dev.peers.keyMap { + pub = key + break + } + cfg := uapiCfg( + "public_key", hex.EncodeToString(pub[:]), + "persistent_keepalive_interval", "1", + ) + for i := 0; i < 1000; i++ { + applyCfg(cfg) + } + }) - var v []byte - v = append(v, hdr...) - v = append(v, payload...) - return []byte(v) -} + // Change private keys concurrently with tunnel use. + t.Run("privateKey", func(t *testing.T) { + bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") + good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) + // Set iters to a large number like 1000 to flush out data races quickly. + // Don't leave it large. That can cause logical races + // in which the handshake is interleaved with key changes + // such that the private key appears to be unchanging but + // other state gets reset, which can cause handshake failures like + // "Received packet with invalid mac1". + const iters = 1 + for i := 0; i < iters; i++ { + applyCfg(bad) + applyCfg(good) + } + }) -// TODO(crawshaw): find a reusable home for this. package devicetest? -type ChannelTUN struct { - Inbound chan []byte // incoming packets, closed on TUN close - Outbound chan []byte // outbound packets, blocks forever on TUN close + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() + } + } + }) - closed chan struct{} - events chan tun.Event - tun chTun + close(done) } -func NewChannelTUN() *ChannelTUN { - c := &ChannelTUN{ - Inbound: make(chan []byte), - Outbound: make(chan []byte), - closed: make(chan struct{}), - events: make(chan tun.Event, 1), +func BenchmarkLatency(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) } - c.tun.c = c - c.events <- tun.EventUp - return c } -func (c *ChannelTUN) TUN() tun.Device { - return &c.tun -} +func BenchmarkThroughput(b *testing.B) { + pair := genTestPair(b, true) -type chTun struct { - c *ChannelTUN -} + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) -func (t *chTun) File() *os.File { return nil } + // Measure how long it takes to receive b.N packets, + // starting when we receive the first packet. + var recv atomic.Uint64 + var elapsed time.Duration + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var start time.Time + for { + <-pair[0].tun.Inbound + new := recv.Add(1) + if new == 1 { + start = time.Now() + } + // Careful! Don't change this to else if; b.N can be equal to 1. + if new == uint64(b.N) { + elapsed = time.Since(start) + return + } + } + }() -func (t *chTun) Read(data []byte, offset int) (int, error) { - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil + // Send packets as fast as we can until we've received enough. + ping := tuntest.Ping(pair[0].ip, pair[1].ip) + pingc := pair[1].tun.Outbound + var sent uint64 + for recv.Load() != uint64(b.N) { + sent++ + pingc <- ping } + wg.Wait() + + b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") + b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") } -// Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { - if offset == -1 { - close(t.c.closed) - close(t.c.events) - return 0, io.EOF +func BenchmarkUAPIGet(b *testing.B) { + pair := genTestPair(b, true) + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair[0].dev.IpcGetOperation(io.Discard) } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case t.c.Inbound <- msg: - return len(data) - offset, nil +} + +func goroutineLeakCheck(t *testing.T) { + goroutines := func() (int, []byte) { + p := pprof.Lookup("goroutine") + b := new(bytes.Buffer) + p.WriteTo(b, 1) + return p.Count(), b.Bytes() } + + startGoroutines, startStacks := goroutines() + t.Cleanup(func() { + if t.Failed() { + return + } + // Give goroutines time to exit, if they need it. + for i := 0; i < 10000; i++ { + if runtime.NumGoroutine() <= startGoroutines { + return + } + time.Sleep(1 * time.Millisecond) + } + endGoroutines, endStacks := goroutines() + t.Logf("starting stacks:\n%s\n", startStacks) + t.Logf("ending stacks:\n%s\n", endStacks) + t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) + }) } -func (t *chTun) Flush() error { return nil } -func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } -func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } -func (t *chTun) Events() chan tun.Event { return t.c.events } -func (t *chTun) Close() error { - t.Write(nil, -1) - return nil +type fakeBindSized struct { + size int } -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } -func assertEqual(t *testing.T, a, b []byte) { - if !bytes.Equal(a, b) { - t.Fatal(a, "!=", b) - } +type fakeTUNDeviceSized struct { + size int } -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) } - tun := newDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device } diff --git a/device/devicestate_string.go b/device/devicestate_string.go new file mode 100644 index 0000000..6577dd4 --- /dev/null +++ b/device/devicestate_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. + +package device + +import "strconv" + +const _deviceState_name = "DownUpClosed" + +var _deviceState_index = [...]uint8{0, 4, 6, 12} + +func (i deviceState) String() string { + if i >= deviceState(len(_deviceState_index)-1) { + return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] +} diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 1896790..93a4998 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -1,53 +1,49 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" - "net" + "net/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/indextable.go b/device/indextable.go index 4cba970..00ade7d 100644 --- a/device/indextable.go +++ b/device/indextable.go @@ -1,14 +1,14 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/rand" + "encoding/binary" "sync" - "unsafe" ) type IndexTableEntry struct { @@ -25,7 +25,8 @@ type IndexTable struct { func randUint32() (uint32, error) { var integer [4]byte _, err := rand.Read(integer[:]) - return *(*uint32)(unsafe.Pointer(&integer[0])), err + // Arbitrary endianness; both are intrinsified by the Go compiler. + return binary.LittleEndian.Uint32(integer[:]), err } func (table *IndexTable) Init() { diff --git a/device/ip.go b/device/ip.go index 9d4fb74..eaf2363 100644 --- a/device/ip.go +++ b/device/ip.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/kdf_test.go b/device/kdf_test.go index cb8dbab..f9c76d6 100644 --- a/device/kdf_test.go +++ b/device/kdf_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -20,7 +20,7 @@ type KDFTest struct { t2 string } -func assertEquals(t *testing.T, a string, b string) { +func assertEquals(t *testing.T, a, b string) { if a != b { t.Fatal("expected", a, "=", b) } diff --git a/device/keypair.go b/device/keypair.go index 9c78fa9..e3540d7 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,6 +8,7 @@ package device import ( "crypto/cipher" "sync" + "sync/atomic" "time" "golang.zx2c4.com/wireguard/replay" @@ -21,10 +22,10 @@ import ( */ type Keypair struct { - sendNonce uint64 + sendNonce atomic.Uint64 send cipher.AEAD receive cipher.AEAD - replayFilter replay.ReplayFilter + replayFilter replay.Filter isInitiator bool created time.Time localIndex uint32 @@ -35,7 +36,7 @@ type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair - next *Keypair + next atomic.Pointer[Keypair] } func (kp *Keypairs) Current() *Keypair { diff --git a/device/logger.go b/device/logger.go index 7c8b704..22b0df0 100644 --- a/device/logger.go +++ b/device/logger.go @@ -1,59 +1,48 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "io" - "io/ioutil" "log" "os" ) +// A Logger provides logging for a Device. +// The functions are Printf-style functions. +// They must be safe for concurrent use. +// They do not require a trailing newline in the format. +// If nil, that level of logging will be silent. +type Logger struct { + Verbosef func(format string, args ...any) + Errorf func(format string, args ...any) +} + +// Log levels for use with NewLogger. const ( LogLevelSilent = iota LogLevelError - LogLevelInfo - LogLevelDebug + LogLevelVerbose ) -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} +// Function for use in Logger for discarding logged lines. +func DiscardLogf(format string, args ...any) {} +// NewLogger constructs a Logger that writes to stdout. +// It logs at the specified log level and above. +// It decorates log lines with the log level, date, time, and prepend. func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - if level >= LogLevelError { - return output, ioutil.Discard, ioutil.Discard - } - return ioutil.Discard, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) + logger := &Logger{DiscardLogf, DiscardLogf} + logf := func(prefix string) func(string, ...any) { + return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf + } + if level >= LogLevelVerbose { + logger.Verbosef = logf("DEBUG") + } + if level >= LogLevelError { + logger.Errorf = logf("ERROR") + } return logger } diff --git a/device/mark_default.go b/device/mark_default.go deleted file mode 100644 index 7de2524..0000000 --- a/device/mark_default.go +++ /dev/null @@ -1,12 +0,0 @@ -// +build !linux,!openbsd,!freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -func (bind *nativeBind) SetMark(mark uint32) error { - return nil -} diff --git a/device/mark_unix.go b/device/mark_unix.go deleted file mode 100644 index 669b328..0000000 --- a/device/mark_unix.go +++ /dev/null @@ -1,65 +0,0 @@ -// +build android openbsd freebsd - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "runtime" - - "golang.org/x/sys/unix" -) - -var fwmarkIoctl int - -func init() { - switch runtime.GOOS { - case "linux", "android": - fwmarkIoctl = 36 /* unix.SO_MARK */ - case "freebsd": - fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ - case "openbsd": - fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ - } -} - -func (bind *nativeBind) SetMark(mark uint32) error { - var operr error - if fwmarkIoctl == 0 { - return nil - } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err = fd.Control(func(fd uintptr) { - operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) - }) - if err == nil { - err = operr - } - if err != nil { - return err - } - } - return nil -} diff --git a/device/misc.go b/device/misc.go deleted file mode 100644 index a38d1c1..0000000 --- a/device/misc.go +++ /dev/null @@ -1,48 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.int32) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.int32, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.int32, flag) -} - -func min(a, b uint) uint { - if a > b { - return b - } - return a -} diff --git a/device/mobilequirks.go b/device/mobilequirks.go new file mode 100644 index 0000000..0a0080e --- /dev/null +++ b/device/mobilequirks.go @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, +// though it will try to deal with it, and race maybe, if called after. +func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { + device.net.brokenRoaming = true + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() + } + device.peers.RUnlock() +} diff --git a/device/noise-helpers.go b/device/noise-helpers.go index f5e4b4b..c2f356b 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,6 +9,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/subtle" + "errors" "hash" "golang.org/x/crypto/blake2s" @@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { return } -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { +var errInvalidPublicKey = errors.New("invalid public key") + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) curve25519.ScalarMult(&ss, ask, apk) - return ss + if isZero(ss[:]) { + return ss, errInvalidPublicKey + } + return ss, nil } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index 88c6aae..e8f6145 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,29 +1,50 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "errors" + "fmt" "sync" "time" "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" ) +type handshakeState int + const ( - HandshakeZeroed = iota - HandshakeInitiationCreated - HandshakeInitiationConsumed - HandshakeResponseCreated - HandshakeResponseConsumed + handshakeZeroed = handshakeState(iota) + handshakeInitiationCreated + handshakeInitiationConsumed + handshakeResponseCreated + handshakeResponseConsumed ) +func (hs handshakeState) String() string { + switch hs { + case handshakeZeroed: + return "handshakeZeroed" + case handshakeInitiationCreated: + return "handshakeInitiationCreated" + case handshakeInitiationConsumed: + return "handshakeInitiationConsumed" + case handshakeResponseCreated: + return "handshakeResponseCreated" + case handshakeResponseConsumed: + return "handshakeResponseConsumed" + default: + return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs)) + } +} + const ( NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" @@ -39,13 +60,13 @@ const ( ) const ( - MessageInitiationSize = 148 // size of handshake initation message + MessageInitiationSize = 148 // size of handshake initiation message MessageResponseSize = 92 // size of response message MessageCookieReplySize = 64 // size of cookie reply message - MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportHeaderSize = 16 // size of data preceding content in transport message MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageKeepaliveSize = MessageTransportSize // size of keepalive - MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message + MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message ) const ( @@ -95,11 +116,11 @@ type MessageCookieReply struct { } type Handshake struct { - state int + state handshakeState mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk + presharedKey NoisePresharedKey // psk localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending @@ -117,11 +138,11 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { +func mixKey(dst, c *[blake2s.Size]byte, data []byte) { KDF1(dst, c[:], data) } -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { +func mixHash(dst, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) @@ -135,7 +156,7 @@ func (h *Handshake) Clear() { setZero(h.chainKey[:]) setZero(h.hash[:]) h.localIndex = 0 - h.state = HandshakeZeroed + h.state = handshakeZeroed } func (h *Handshake) mixHash(data []byte) { @@ -154,7 +175,6 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -162,12 +182,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mutex.Lock() defer handshake.mutex.Unlock() - if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errors.New("static shared secret is zero") - } - // create ephemeral key - var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey @@ -176,59 +191,56 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e return nil, err } - // assign index - - device.indexTable.Delete(handshake.localIndex) - handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) - - if err != nil { - return nil, err - } - handshake.mixHash(handshake.remoteStatic[:]) msg := MessageInitiation{ Type: MessageInitiationType, Ephemeral: handshake.localEphemeral.publicKey(), - Sender: handshake.localIndex, } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - - func() { - var key [chacha20poly1305.KeySize]byte - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - ss[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) handshake.mixHash(msg.Static[:]) // encrypt timestamp - + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errInvalidPublicKey + } + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) timestamp := tai64n.Now() - func() { - var key [chacha20poly1305.KeySize]byte - KDF2( - &handshake.chainKey, - &key, - handshake.chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) - }() + aead, _ = chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + + // assign index + device.indexTable.Delete(handshake.localIndex) + msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + handshake.localIndex = msg.Sender handshake.mixHash(msg.Timestamp[:]) - handshake.state = HandshakeInitiationCreated + handshake.state = handshakeInitiationCreated return &msg, nil } @@ -250,16 +262,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - - var err error var peerPK NoisePublicKey - func() { - var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - KDF2(&chainKey, &key, chainKey[:], ss[:]) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) - }() + var key [chacha20poly1305.KeySize]byte + ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return nil + } + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) if err != nil { return nil } @@ -268,28 +279,29 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer peer := device.LookupPeer(peerPK) - if peer == nil { + if peer == nil || !peer.isRunning.Load() { return nil } handshake := &peer.handshake - if isZero(handshake.precomputedStaticStatic[:]) { - return nil - } // verify identity var timestamp tai64n.Timestamp - var key [chacha20poly1305.KeySize]byte handshake.mutex.RLock() + + if isZero(handshake.precomputedStaticStatic[:]) { + handshake.mutex.RUnlock() + return nil + } KDF2( &chainKey, &key, chainKey[:], handshake.precomputedStaticStatic[:], ) - aead, _ := chacha20poly1305.New(key[:]) + aead, _ = chacha20poly1305.New(key[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) if err != nil { handshake.mutex.RUnlock() @@ -299,11 +311,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // protect against replay & flood - var ok bool - ok = timestamp.After(handshake.lastTimestamp) - ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate + replay := !timestamp.After(handshake.lastTimestamp) + flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() - if !ok { + if replay { + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) + return nil + } + if flood { + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) return nil } @@ -322,7 +338,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { if now.After(handshake.lastInitiationConsumption) { handshake.lastInitiationConsumption = now } - handshake.state = HandshakeInitiationConsumed + handshake.state = handshakeInitiationConsumed handshake.mutex.Unlock() @@ -337,7 +353,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mutex.Lock() defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitiationConsumed { + if handshake.state != handshakeInitiationConsumed { return nil, errors.New("handshake initiation must be consumed first") } @@ -365,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) + ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) // add preshared key @@ -387,13 +407,11 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(tau[:]) - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) - handshake.state = HandshakeResponseCreated + handshake.state = handshakeResponseCreated return &msg, nil } @@ -417,13 +435,12 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) ok := func() bool { - // lock handshake state handshake.mutex.RLock() defer handshake.mutex.RUnlock() - if handshake.state != HandshakeInitiationCreated { + if handshake.state != handshakeInitiationCreated { return false } @@ -437,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) - func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) // add preshared key (psk) @@ -465,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } @@ -484,7 +503,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender - handshake.state = HandshakeResponseConsumed + handshake.state = handshakeResponseConsumed handshake.mutex.Unlock() @@ -509,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error { var sendKey [chacha20poly1305.KeySize]byte var recvKey [chacha20poly1305.KeySize]byte - if handshake.state == HandshakeResponseConsumed { + if handshake.state == handshakeResponseConsumed { KDF2( &sendKey, &recvKey, @@ -517,7 +536,7 @@ func (peer *Peer) BeginSymmetricSession() error { nil, ) isInitiator = true - } else if handshake.state == HandshakeResponseCreated { + } else if handshake.state == handshakeResponseCreated { KDF2( &recvKey, &sendKey, @@ -526,7 +545,7 @@ func (peer *Peer) BeginSymmetricSession() error { ) isInitiator = false } else { - return errors.New("invalid state for keypair derivation") + return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state) } // zero handshake @@ -534,7 +553,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(handshake.chainKey[:]) setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. setZero(handshake.localEphemeral[:]) - peer.handshake.state = HandshakeZeroed + peer.handshake.state = handshakeZeroed // create AEAD instances @@ -546,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(recvKey[:]) keypair.created = time.Now() - keypair.sendNonce = 0 - keypair.replayFilter.Init() + keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex keypair.remoteIndex = peer.handshake.remoteIndex @@ -564,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -578,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -589,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { + + if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.next.Load() + keypairs.next.Store(nil) return true } diff --git a/device/noise-types.go b/device/noise-types.go index 6b1f16f..e850359 100644 --- a/device/noise-types.go +++ b/device/noise-types.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,19 +9,18 @@ import ( "crypto/subtle" "encoding/hex" "errors" - - "golang.org/x/crypto/chacha20poly1305" ) const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 + NoisePresharedKeySize = 32 ) type ( NoisePublicKey [NoisePublicKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoisePresharedKey [NoisePresharedKeySize]byte NoiseNonce uint64 // padded to 12-bytes ) @@ -52,18 +51,19 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) { return } -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) +func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { + err = loadExactHex(key[:], src) + if key.IsZero() { + return + } + key.clamp() + return } func (key *NoisePublicKey) FromHex(src string) error { return loadExactHex(key[:], src) } -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - func (key NoisePublicKey) IsZero() bool { var zero NoisePublicKey return key.Equals(zero) @@ -73,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 } -func (key *NoiseSymmetricKey) FromHex(src string) error { +func (key *NoisePresharedKey) FromHex(src string) error { return loadExactHex(key[:], src) } - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/device/noise_test.go b/device/noise_test.go index 6ba3f2e..2dd5324 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,6 +9,9 @@ import ( "bytes" "encoding/binary" "testing" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun/tuntest" ) func TestCurveWrappers(t *testing.T) { @@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) { pk1 := sk1.publicKey() pk2 := sk2.publicKey() - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) + ss1, err1 := sk1.sharedSecret(pk2) + ss2, err2 := sk2.sharedSecret(pk1) - if ss1 != ss2 { + if ss1 != ss2 || err1 != nil || err2 != nil { t.Fatal("Failed to compute shared secet") } } +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun := tuntest.NewChannelTUN() + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) + device.SetPrivateKey(sk) + return device +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a, b []byte) { + if !bytes.Equal(a, b) { + t.Fatal(a, "!=", b) + } +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + if err != nil { + t.Fatal(err) + } + peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + if err != nil { + t.Fatal(err) + } + peer1.Start() + peer2.Start() assertEqual( t, @@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next + key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 91d975a..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,37 +1,35 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "encoding/base64" + "container/list" "errors" - "fmt" "sync" "sync/atomic" "time" -) -const ( - PeerRoutineNumber = 3 + "golang.zx2c4.com/wireguard/conn" ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint Endpoint - persistentKeepaliveInterval uint16 - - // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch + isRunning atomic.Bool + keypairs Keypairs + handshake Handshake + device *Device + stopping sync.WaitGroup // routines pending stop + txBytes atomic.Uint64 // bytes send to peer (endpoint) + rxBytes atomic.Uint64 // bytes received from peer + lastHandshakeNano atomic.Int64 // nano seconds since epoch + + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool } timers struct { @@ -40,40 +38,32 @@ type Peer struct { newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool + handshakeAttempts atomic.Uint32 + needAnotherKeepalive atomic.Bool + sentLastMinuteHandshake atomic.Bool } - signals struct { - newKeypairArrived chan struct{} - flushNonceQueue chan struct{} + state struct { + sync.Mutex // protects against concurrent Start/Stop } queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - packetInNonceQueueIsAwaitingKey AtomicBool - } - - routines struct { - sync.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } - cookieGenerator CookieGenerator + cookieGenerator CookieGenerator + trieEntries list.List + persistentKeepaliveInterval atomic.Uint32 } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - if device.isClosed.Get() { + if device.isClosed() { return nil, errors.New("device closed") } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -81,136 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device - peer.isRunning.Set(false) + peer.queue.outbound = newAutodrainingOutboundQueue(device) + peer.queue.inbound = newAutodrainingInboundQueue(device) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) - ssIsZero := isZero(handshake.precomputedStaticStatic[:]) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() // reset endpoint + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() - peer.endpoint = nil - - // conditionally add - - if !ssIsZero { - device.peers.keyMap[pk] = peer - } else { - return nil, nil - } - - // start peer + // init timers + peer.timersInit() - if peer.device.isUp.Get() { - peer.Start() - } + // add + device.peers.keyMap[pk] = peer return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() - if peer.device.net.bind == nil { - return errors.New("no bind") + if peer.device.isClosed() { + return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { - atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } func (peer *Peer) String() string { - base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + // The awful goo that follows is identical to: + // + // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] + // return fmt.Sprintf("peer(%s)", abbreviatedKey) + // + // except that it is considerably more efficient. + src := peer.handshake.remoteStatic + b64 := func(input byte) byte { + return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) } - return fmt.Sprintf("peer(%s)", abbreviatedKey) + b := []byte("peer(____…____)") + const first = len("peer(") + const second = len("peer(____…") + b[first+0] = b64((src[0] >> 2) & 63) + b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) + b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) + b[first+3] = b64(src[2] & 63) + b[second+0] = b64(src[29] & 63) + b[second+1] = b64((src[30] >> 2) & 63) + b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) + b[second+3] = b64((src[31] << 2) & 63) + return string(b) } func (peer *Peer) Start() { - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return } // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() - peer.routines.Lock() - defer peer.routines.Unlock() - - if peer.isRunning.Get() { + if peer.isRunning.Load() { return } device := peer.device - device.log.Debug.Println(peer, "- Starting...") + device.log.Verbosef("%v - Starting", peer) // reset routine state + peer.stopping.Wait() + peer.stopping.Add(2) - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - - // prepare queues + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.handshake.mutex.Unlock() - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes - peer.timersInit() - peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - peer.signals.newKeypairArrived = make(chan struct{}, 1) - peer.signals.flushNonceQueue = make(chan struct{}, 1) + peer.timersStart() - // wait for routines to start + device.flushInboundQueue(peer.queue.inbound) + device.flushOutboundQueue(peer.queue.outbound) - go peer.RoutineNonce() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) - peer.routines.starting.Wait() - peer.isRunning.Set(true) + peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { @@ -222,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state @@ -236,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() { handshake.Clear() handshake.mutex.Unlock() - peer.FlushNonceQueue() + peer.FlushStagedPackets() } func (peer *Peer) ExpireCurrentKeypairs() { @@ -244,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() { handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() - handshake.mutex.Unlock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - keypairs.current.sendNonce = RejectAfterMessages + keypairs.current.sendNonce.Store(RejectAfterMessages) } - if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + if next := keypairs.next.Load(); next != nil { + next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.starting.Wait() - - peer.routines.Lock() - defer peer.routines.Unlock() - - peer.device.log.Debug.Println(peer, "- Stopping...") + peer.device.log.Verbosef("%v - Stopping", peer) peer.timersStop() - - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) + // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. + peer.queue.inbound.c <- nil + peer.queue.outbound.c <- nil + peer.stopping.Wait() + peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.ZeroAndFlushAll() } -var RoamingDisabled bool +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { - if RoamingDisabled { +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/pools.go b/device/pools.go index 98f4ef1..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,89 +1,120 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -import "sync" +import ( + "sync" + "sync/atomic" +) -func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count atomic.Uint32 + max uint32 +} + +func NewWaitPool(max uint32, new func() any) *WaitPool { + p := &WaitPool{pool: sync.Pool{New: new}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +func (p *WaitPool) Get() any { + if p.max != 0 { + p.lock.Lock() + for p.count.Load() >= p.max { + p.cond.Wait() } + p.count.Add(1) + p.lock.Unlock() } + return p.pool.Get() } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan +func (p *WaitPool) Put(x any) { + p.pool.Put(x) + if p.max == 0 { + return } + p.count.Add(^uint32(0)) + p.cond.Signal() } -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } +func (device *Device) PopulatePools() { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} + }) + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new([MaxMessageSize]byte) + }) + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueInboundElement) + }) + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueOutboundElement) + }) } -func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) +} + +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElement(msg *QueueInboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(msg) - } else { - device.pool.inboundElementReuseChan <- msg +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + return device.pool.inboundElements.Get().(*QueueInboundElement) +} + +func (device *Device) PutInboundElement(elem *QueueInboundElement) { + elem.clearPointers() + device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } + return device.pool.outboundElements.Get().(*QueueOutboundElement) } -func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(msg) - } else { - device.pool.outboundElementReuseChan <- msg - } +func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + elem.clearPointers() + device.pool.outboundElements.Put(elem) } diff --git a/device/pools_test.go b/device/pools_test.go new file mode 100644 index 0000000..82d7493 --- /dev/null +++ b/device/pools_test.go @@ -0,0 +1,139 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestWaitPool(t *testing.T) { + t.Skip("Currently disabled") + var wg sync.WaitGroup + var trials atomic.Int32 + startTrials := int32(100000) + if raceEnabled { + // This test can be very slow with -race. + startTrials /= 10 + } + trials.Store(startTrials) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + t.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + var max atomic.Uint32 + updateMax := func() { + count := p.count.Load() + if count > p.max { + t.Errorf("count (%d) > max (%d)", count, p.max) + } + for { + old := max.Load() + if count <= old { + break + } + if max.CompareAndSwap(old, count) { + break + } + } + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + updateMax() + x := p.Get() + updateMax() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + updateMax() + p.Put(x) + updateMax() + } + }() + } + wg.Wait() + if max.Load() != p.max { + t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) + } +} + +func BenchmarkWaitPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkWaitPoolEmpty(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(0, func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkSyncPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := sync.Pool{New: func() any { return make([]byte, 16) }} + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index f5c042d..25f700a 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -1,16 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + /* Reduce memory consumption for Android */ const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index cf86ba1..ea763d0 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -1,13 +1,16 @@ -// +build !android,!ios +//go:build !android && !ios && !windows /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index 589b0aa..acd3cec 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -1,18 +1,21 @@ -// +build ios +//go:build ios /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 1700 - PreallocatedBuffersPerPool = 1024 +// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. +// These are vars instead of consts, because heavier network extensions might want to reduce +// them further. +var ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + PreallocatedBuffersPerPool uint32 = 1024 ) + +const MaxSegmentSize = 1700 diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go new file mode 100644 index 0000000..1eee32b --- /dev/null +++ b/device/queueconstants_windows.go @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = 2048 - 32 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go new file mode 100644 index 0000000..bb5c450 --- /dev/null +++ b/device/race_disabled_test.go @@ -0,0 +1,10 @@ +//go:build !race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = false diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go new file mode 100644 index 0000000..4e9daea --- /dev/null +++ b/device/race_enabled_test.go @@ -0,0 +1,10 @@ +//go:build race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = true diff --git a/device/receive.go b/device/receive.go index 7d0693e..1ab3e29 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,66 +8,46 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" - "strconv" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } type QueueInboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 keypair *Keypair - endpoint Endpoint -} - -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) + endpoint conn.Endpoint } -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { - select { - case inboundQueue <- element: - select { - case decryptionQueue <- element: - return true - default: - element.Drop() - element.Unlock() - return false - } - default: - device.PutInboundElement(element) - return false - } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement } -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { - select { - case queue <- element: - return true - default: - return false - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueInboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.endpoint = nil } /* Called when a new authenticated message has been received @@ -75,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { + if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) + peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } @@ -90,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { - - logDebug := device.log.Debug +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { + recvName := recv.PrettyName() defer func() { - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.net.stopping.Done() }() - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") - device.net.starting.Done() + device.log.Verbosef("Routine: receive incoming %s - started", recvName) // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( - err error - size int - endpoint Endpoint + bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + bufs = make([][]byte, maxBatchSize) + err error + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) + deathSpiral int + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for { - - // read next datagram + for i := range bufsArrs { + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + } - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") + defer func() { + for i := 0; i < maxBatchSize; i++ { + if bufsArrs[i] != nil { + device.PutMessageBuffer(bufsArrs[i]) + } } + }() + for { + count, err = recv(bufs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) + if errors.Is(err, net.ErrClosed) { + return + } + device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) + if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { + return + } + if deathSpiral < 10 { + deathSpiral++ + time.Sleep(time.Second / 3) + continue + } return } + deathSpiral = 0 - if size < MinMessageSize { - continue - } - - // check size of packet + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // check size of packet - var okay bool + packet := bufsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - switch msgType { + switch msgType { - // check if transport + // check if transport - case MessageTransportType: + case MessageTransportType: - // check size + // check size - if len(packet) < MessageTransportSize { - continue - } + if len(packet) < MessageTransportSize { + continue + } - // lookup key pair + // lookup key pair - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - // check keypair expiry + // check keypair expiry - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { - buffer = device.GetMessageBuffer() + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = bufsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() + elemsByPeer[peer] = elemsForPeer } - } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + continue - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + default: + device.log.Verbosef("Received message with unknown type") + continue + } - default: - logDebug.Println("Received message with unknown type") + select { + case device.queue.handshake.c <- QueueHandshakeElement{ + msgType: msgType, + buffer: bufsArrs[i], + packet: packet, + endpoint: endpoints[i], + }: + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + default: + } } - - if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { - buffer = device.GetMessageBuffer() + for peer, elemsContainer := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) } + delete(elemsByPeer, peer) } } } -func (device *Device) RoutineDecryption() { - +func (device *Device) RoutineDecryption(id int) { var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: decryption worker - stopped") - device.state.stopping.Done() - }() - logDebug.Println("Routine: decryption worker - started") - device.state.starting.Done() - - for { - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) + device.log.Verbosef("Routine: decryption worker %d - started", id) + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - // decrypt and release to consumer - var err error elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], @@ -279,51 +260,23 @@ func (device *Device) RoutineDecryption() { nil, ) if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) + elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } /* Handles incoming packets related to handshake */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem QueueHandshakeElement - var ok bool - +func (device *Device) RoutineHandshake(id int) { defer func() { - logDebug.Println("Routine: handshake worker - stopped") - device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } + device.log.Verbosef("Routine: handshake worker %d - stopped", id) + device.queue.encryption.wg.Done() }() + device.log.Verbosef("Routine: handshake worker %d - started", id) - logDebug.Println("Routine: handshake worker - started") - device.state.starting.Done() - - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } + for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -337,8 +290,8 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { - logDebug.Println("Failed to decode cookie reply") - return + device.log.Verbosef("Failed to decode cookie reply") + goto skip } // lookup peer from index @@ -346,27 +299,27 @@ func (device *Device) RoutineHandshake() { entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { - continue + goto skip } // consume reply - if peer := entry.peer; peer.isRunning.Get() { - logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if peer := entry.peer; peer.isRunning.Load() { + device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { - logDebug.Println("Could not decrypt invalid cookie response") + device.log.Verbosef("Could not decrypt invalid cookie response") } } - continue + goto skip case MessageInitiationType, MessageResponseType: // check mac fields and maybe ratelimit if !device.cookieChecker.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue + device.log.Verbosef("Received packet with invalid mac1") + goto skip } // endpoints destination address is the source of the datagram @@ -377,19 +330,19 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) - continue + goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue + goto skip } } default: - logError.Println("Invalid packet ended up in the handshake queue") - continue + device.log.Errorf("Invalid packet ended up in the handshake queue") + goto skip } // handle handshake initiation/response content @@ -403,19 +356,16 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode initiation message") - continue + device.log.Errorf("Failed to decode initiation message") + goto skip } // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) + goto skip } // update timers @@ -426,8 +376,8 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake initiation") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake initiation", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() @@ -439,26 +389,23 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode response message") - continue + device.log.Errorf("Failed to decode response message") + goto skip } // consume response peer := device.ConsumeMessageResponse(&msg) if peer == nil { - logInfo.Println( - "Received invalid response message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) + goto skip } // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake response") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake response", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) // update timers @@ -470,178 +417,124 @@ func (device *Device) RoutineHandshake() { err = peer.BeginSymmetricSession() if err != nil { - logError.Println(peer, "- Failed to derive keypair:", err) - continue + device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) + goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } } + skip: + device.PutMessageBuffer(elem.buffer) } } -func (peer *Peer) RoutineSequentialReceiver() { - +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem *QueueInboundElement - defer func() { - logDebug.Println(peer, "- Routine: sequential receiver - stopped") - peer.routines.stopping.Done() - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - } + device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - logDebug.Println(peer, "- Routine: sequential receiver - started") - - peer.routines.starting.Done() - - for { - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - elem = nil - } + bufs := make([][]byte, 0, maxBatchSize) - var elemOk bool - select { - case <-peer.routines.stop: + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption - - elem.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - - // check for keepalive - - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue } - peer.timersDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { + if elem.packet == nil { + // decryption failed continue } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue } - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue + validTailPacket = i + if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) + peer.timersHandshakeComplete() + peer.SendStagedPackets() } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } + dataPacketReceived = true - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - // verify IPv6 source + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer, - ) + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) continue } - default: - logInfo.Println("Packet with invalid IP version from", peer) - continue + bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if len(peer.queue.inbound) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Error.Printf("Unable to flush packets: %v", err) + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } + if len(bufs) > 0 { + _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } + bufs = bufs[:0] + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 72633be..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,14 +8,17 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" + "os" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -43,8 +46,6 @@ import ( */ type QueueOutboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -52,80 +53,52 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() - elem.dropped = AtomicFalse elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 - elem.keypair = nil - elem.peer = nil + // keypair and peer were cleared (if necessary) by clearPointers. return elem } -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - device.PutMessageBuffer(old.buffer) - device.PutOutboundElement(old) - default: - } - } - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueOutboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.peer = nil } -func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { - select { - case outboundQueue <- element: +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() { + if len(peer.queue.staged) == 0 && peer.isRunning.Load() { + elem := peer.device.NewOutboundElement() + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case encryptionQueue <- element: - return + case peer.queue.staged <- elemsContainer: + peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - element.Drop() - element.peer.device.PutMessageBuffer(element.buffer) - element.Unlock() + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsContainer(elemsContainer) } - default: - element.peer.device.PutMessageBuffer(element.buffer) - element.peer.device.PutOutboundElement(element) - } -} - -/* Queues a keepalive if no packets are queued for peer - */ -func (peer *Peer) SendKeepalive() bool { - if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - peer.device.log.Debug.Println(peer, "- Sending keepalive packet") - return true - default: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) - return false } + peer.SendStagedPackets() } func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() @@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake initiation") + peer.device.log.Verbosef("%v - Sending handshake initiation", peer) msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) + peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) return err } - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageInitiationSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) @@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) + peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } peer.timersHandshakeInitiated() @@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake response") + peer.device.log.Verbosef("%v - Sending handshake response", peer) response, err := peer.device.CreateMessageResponse(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create response message:", err) + peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var buff [MessageResponseSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageResponseSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, response) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { - peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) + peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -198,28 +171,29 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) + peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } return err } func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { - - device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) + device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) if err != nil { - device.log.Error.Println("Failed to create cookie reply:", err) + device.log.Errorf("Failed to create cookie reply: %v", err) return err } - var buff [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -228,280 +202,255 @@ func (peer *Peer) keepKeyFreshSending() { if keypair == nil { return } - nonce := atomic.LoadUint64(&keypair.sendNonce) + nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - logDebug.Println("Routine: TUN reader - stopped") + device.log.Verbosef("Routine: TUN reader - stopped") device.state.stopping.Done() + device.queue.encryption.wg.Done() }() - logDebug.Println("Routine: TUN reader - started") - device.state.starting.Done() - - var elem *QueueOutboundElement + device.log.Verbosef("Routine: TUN reader - started") + + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + bufs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) + }() - if err != nil { - if !device.isClosed.Get() { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(bufs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = bufs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.LookupIPv4(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsContainer() + elemsByPeer[peer] = elemsForPeer + } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range elemsForPeer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsForPeer) + } + delete(elemsByPeer, peer) } - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { - peer.SendHandshakeInitiation(false) + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue } - addToNonceQueue(peer.queue.nonce, elem, device) - elem = nil + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() + } + return } } } -func (peer *Peer) FlushNonceQueue() { - select { - case peer.signals.flushNonceQueue <- struct{}{}: - default: - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keypair *Keypair - - device := peer.device - logDebug := device.log.Debug - - flush := func() { - for { - select { - case elem := <-peer.queue.nonce: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - default: - return +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { + for { + select { + case peer.queue.staged <- elems: + return + default: + } + select { + case tooOld := <-peer.queue.staged: + for _, elem := range tooOld.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(tooOld) + default: } } +} - defer func() { - flush() - logDebug.Println(peer, "- Routine: nonce worker - stopped") - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - peer.routines.stopping.Done() - }() +func (peer *Peer) SendStagedPackets() { +top: + if len(peer.queue.staged) == 0 || !peer.device.isUp() { + return + } - peer.routines.starting.Done() - logDebug.Println(peer, "- Routine: nonce worker - started") + keypair := peer.keypairs.Current() + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + peer.SendHandshakeInitiation(false) + return + } for { - NextPacket: - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - + var elemsContainerOOO *QueueOutboundElementsContainer select { - case <-peer.routines.stop: - return - - case <-peer.signals.flushNonceQueue: - flush() - goto NextPacket - - case elem, ok := <-peer.queue.nonce: - - if !ok { - return - } - - // make sure to always pick the newest key - - for { - - // check validity of newest key pair - - keypair = peer.keypairs.Current() - if keypair != nil && keypair.sendNonce < RejectAfterMessages { - if time.Since(keypair.created) < RejectAfterTime { - break + case elemsContainer := <-peer.queue.staged: + i := 0 + for _, elem := range elemsContainer.elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) - - // no suitable key pair, request for new handshake - - select { - case <-peer.signals.newKeypairArrived: - default: + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) + continue + } else { + elemsContainer.elems[i] = elem + i++ } - peer.SendHandshakeInitiation(false) - - // wait for key to be established - - logDebug.Println(peer, "- Awaiting keypair") + elem.keypair = keypair + } + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - select { - case <-peer.signals.newKeypairArrived: - logDebug.Println(peer, "- Obtained awaited keypair") + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans + } - case <-peer.signals.flushNonceQueue: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - flush() - goto NextPacket + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) + goto top + } - case <-peer.routines.stop: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return + // add to parallel and sequential queue + if peer.isRunning.Load() { + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(elemsContainer) } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - - // populate work element - elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 - - // double check in case of race condition added by future code - - if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - goto NextPacket + if elemsContainerOOO != nil { + goto top } + default: + return + } + } +} - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.Lock() - - // add to parallel and sequential queue - addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) +func (peer *Peer) FlushStagedPackets() { + for { + select { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(elemsContainer) + default: + return } } } +func calculatePaddingSize(packetSize, mtu int) int { + lastUnit := packetSize + if mtu == 0 { + return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit + } + if lastUnit > mtu { + lastUnit %= mtu + } + paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) + if paddedSize > mtu { + paddedSize = mtu + } + return paddedSize - lastUnit +} + /* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * * Obs. One instance per core */ -func (device *Device) RoutineEncryption() { - +func (device *Device) RoutineEncryption(id int) { + var paddingZeros [PaddingMultiple]byte var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - - defer func() { - for { - select { - case elem, ok := <-device.queue.encryption: - if ok && !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - elem.Unlock() - } - default: - goto out - } - } - out: - logDebug.Println("Routine: encryption worker - stopped") - device.state.stopping.Done() - }() - - logDebug.Println("Routine: encryption worker - started") - device.state.starting.Done() - - for { - - // fetch next element - - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.encryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) + device.log.Verbosef("Routine: encryption worker %d - started", id) + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] @@ -513,16 +462,8 @@ func (device *Device) RoutineEncryption() { binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - lastUnit := len(elem.packet) % mtu - paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) - if paddedSize > mtu { - paddedSize = mtu - } - for i := len(elem.packet); i < paddedSize; i++ { - elem.packet = append(elem.packet, 0) - } + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer @@ -533,82 +474,73 @@ func (device *Device) RoutineEncryption() { elem.packet, nil, ) - elem.Unlock() } + elemsContainer.Unlock() } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - for { - select { - case elem, ok := <-peer.queue.outbound: - if ok { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - elem.Drop() - } - device.PutOutboundElement(elem) - } - default: - goto out - } - } - out: - logDebug.Println(peer, "- Routine: sequential sender - stopped") - peer.routines.stopping.Done() + defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential sender - started", peer) - logDebug.Println(peer, "- Routine: sequential sender - started") - - peer.routines.starting.Done() - - for { - select { + bufs := make([][]byte, 0, maxBatchSize) - case <-peer.routines.stop: + for elemsContainer := range peer.queue.outbound.c { + bufs = bufs[:0] + if elemsContainer == nil { return - - case elem, ok := <-peer.queue.outbound: - - if !ok { - return - } - - elem.Lock() - if elem.IsDropped() { + } + if !peer.isRunning.Load() { + // peer has been stopped; return re-usable elems to the shared pool. + // This is an optimization only. It is possible for the peer to be stopped + // immediately after this check, in which case, elem will get processed. + // The timers and SendBuffers code are resilient to a few stragglers. + // TODO: rework peer shutdown order to ensure + // that we never accidentally keep timers alive longer than necessary. + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - continue } - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) + continue + } + dataSent := false + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { - peer.timersDataSent() + dataSent = true } + bufs = append(bufs, elem.packet) + } + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err := peer.SendBuffers(bufs) + if dataSent { + peer.timersDataSent() + } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - if err != nil { - logError.Println(peer, "- Failed to send data packet", err) - continue + } + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr } - - peer.keepKeyFreshSending() } + if err != nil { + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) + continue + } + + peer.keepKeyFreshSending() } } diff --git a/device/sticky_default.go b/device/sticky_default.go new file mode 100644 index 0000000..1038256 --- /dev/null +++ b/device/sticky_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +package device + +import ( + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + return nil, nil +} diff --git a/device/sticky_linux.go b/device/sticky_linux.go new file mode 100644 index 0000000..6057ff1 --- /dev/null +++ b/device/sticky_linux.go @@ -0,0 +1,224 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "sync" + "unsafe" + + "golang.org/x/sys/unix" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/rwcancel" +) + +func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + if !conn.StdNetSupportsStickySockets { + return nil, nil + } + if _, ok := bind.(*conn.StdNetBind); !ok { + return nil, nil + } + + netlinkSock, err := createNetlinkRouteSocket() + if err != nil { + return nil, err + } + netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock) + if err != nil { + unix.Close(netlinkSock) + return nil, err + } + + go device.routineRouteListener(bind, netlinkSock, netlinkCancel) + + return netlinkCancel, nil +} + +func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) { + type peerEndpointPtr struct { + peer *Peer + endpoint *conn.Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer netlinkCancel.Close() + defer unix.Close(netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() + break + } + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() + break + } + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() + continue + } + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) + if nativeEP == nil { + peer.endpoint.Unlock() + continue + } + if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { + peer.endpoint.Unlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + nativeEP.DstIP().As4(), + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + nativeEP.SrcIP().As4(), + unix.RtAttr{ + Len: 8, + Type: unix.RTA_MARK, + }, + device.net.fwmark, + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint.val, + } + reqPeerLock.Unlock() + peer.endpoint.Unlock() + i++ + _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: unix.RTMGRP_IPV4_ROUTE, + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil +} diff --git a/device/timers.go b/device/timers.go index 18ee736..d4a4ed4 100644 --- a/device/timers.go +++ b/device/timers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ @@ -8,16 +8,16 @@ package device import ( - "math/rand" "sync" - "sync/atomic" "time" + _ "unsafe" ) -/* This Timer structure and related functions should roughly copy the interface of - * the Linux kernel's struct timer_list. - */ +//go:linkname fastrandn runtime.fastrandn +func fastrandn(n uint32) uint32 +// A Timer manages time-based aspects of the WireGuard protocol. +// Timer roughly copies the interface of the Linux kernel's struct timer_list. type Timer struct { *time.Timer modifyingLock sync.RWMutex @@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { timer := &Timer{} timer.Timer = time.AfterFunc(time.Hour, func() { timer.runningLock.Lock() + defer timer.runningLock.Unlock() timer.modifyingLock.Lock() if !timer.isPending { timer.modifyingLock.Unlock() - timer.runningLock.Unlock() return } timer.isPending = false timer.modifyingLock.Unlock() expirationFunction(peer) - timer.runningLock.Unlock() }) timer.Stop() return timer @@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool { } func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 + return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() } func expiredRetransmitHandshake(peer *Peer) { - if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { + peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) if peer.timersActive() { peer.timers.sendKeepalive.Del() @@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) { /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ - peer.FlushNonceQueue() + peer.FlushStagedPackets() /* We set a timer for destroying any residue that might be left * of a partial exchange. @@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } else { - atomic.AddUint32(&peer.timers.handshakeAttempts, 1) - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + peer.timers.handshakeAttempts.Add(1) + peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) { func expiredSendKeepalive(peer *Peer) { peer.SendKeepalive() - if peer.timers.needAnotherKeepalive.Get() { - peer.timers.needAnotherKeepalive.Set(false) + if peer.timers.needAnotherKeepalive.Load() { + peer.timers.needAnotherKeepalive.Store(false) if peer.timersActive() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } @@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) { } func expiredNewHandshake(peer *Peer) { - peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) - } func expiredZeroKeyMaterial(peer *Peer) { - peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.ZeroAndFlushAll() } func expiredPersistentKeepalive(peer *Peer) { - if peer.persistentKeepaliveInterval > 0 { + if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } @@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) { /* Should be called after an authenticated data packet is sent. */ func (peer *Peer) timersDataSent() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() { - peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() { if !peer.timers.sendKeepalive.IsPending() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } else { - peer.timers.needAnotherKeepalive.Set(true) + peer.timers.needAnotherKeepalive.Store(true) } } } @@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() { /* Should be called after a handshake initiation message is sent. */ func (peer *Peer) timersHandshakeInitiated() { if peer.timersActive() { - peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() { if peer.timersActive() { peer.timers.retransmitHandshake.Del() } - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.lastHandshakeNano.Store(time.Now().UnixNano()) } /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ @@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() { /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { - peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + keepalive := peer.persistentKeepaliveInterval.Load() + if keepalive > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } } @@ -213,9 +204,12 @@ func (peer *Peer) timersInit() { peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStart() { + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.timers.needAnotherKeepalive.Store(false) } func (peer *Peer) timersStop() { diff --git a/device/tun.go b/device/tun.go index 0a3fc79..2a2ace9 100644 --- a/device/tun.go +++ b/device/tun.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "sync/atomic" + "fmt" "golang.zx2c4.com/wireguard/tun" ) @@ -14,43 +14,40 @@ import ( const DefaultMTU = 1420 func (device *Device) RoutineTUNEventReader() { - setUp := false - logDebug := device.log.Debug - logInfo := device.log.Info - logError := device.log.Error - - logDebug.Println("Routine: event worker - started") - device.state.starting.Done() + device.log.Verbosef("Routine: event worker - started") for event := range device.tun.device.Events() { if event&tun.EventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + device.log.Errorf("Failed to load updated MTU of device: %v", err) + continue + } + if mtu < 0 { + device.log.Errorf("MTU not updated to negative value: %v", mtu) + continue + } + var tooLarge string + if mtu > MaxContentSize { + tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) + mtu = MaxContentSize + } + old := device.tun.mtu.Swap(int32(mtu)) + if int(old) != mtu { + device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) } } - if event&tun.EventUp != 0 && !setUp { - logInfo.Println("Interface set up") - setUp = true + if event&tun.EventUp != 0 { + device.log.Verbosef("Interface up requested") device.Up() } - if event&tun.EventDown != 0 && setUp { - logInfo.Println("Interface set down") - setUp = false + if event&tun.EventDown != 0 { + device.log.Verbosef("Interface down requested") device.Down() } } - logDebug.Println("Routine: event worker - stopped") - device.state.stopping.Done() + device.log.Verbosef("Routine: event worker - stopped") } diff --git a/device/tun_test.go b/device/tun_test.go deleted file mode 100644 index 5614771..0000000 --- a/device/tun_test.go +++ /dev/null @@ -1,56 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "os" - - "golang.zx2c4.com/wireguard/tun" -) - -// newDummyTUN creates a dummy TUN device with the specified name. -func newDummyTUN(name string) tun.Device { - return &dummyTUN{ - name: name, - packets: make(chan []byte, 100), - events: make(chan tun.Event, 10), - } -} - -// A dummyTUN is a tun.Device which is used in unit tests. -type dummyTUN struct { - name string - mtu int - packets chan []byte - events chan tun.Event -} - -func (d *dummyTUN) Events() chan tun.Event { return d.events } -func (*dummyTUN) File() *os.File { return nil } -func (*dummyTUN) Flush() error { return nil } -func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil } -func (d *dummyTUN) Name() (string, error) { return d.name, nil } - -func (d *dummyTUN) Close() error { - close(d.events) - close(d.packets) - return nil -} - -func (d *dummyTUN) Read(b []byte, offset int) (int, error) { - buf, ok := <-d.packets - if !ok { - return 0, errors.New("device closed") - } - copy(b[offset:], buf) - return len(buf), nil -} - -func (d *dummyTUN) Write(b []byte, offset int) (int, error) { - d.packets <- b[offset:] - return len(b), nil -} diff --git a/device/uapi.go b/device/uapi.go index 999eeb5..d81dae3 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,43 +1,77 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bufio" + "bytes" + "errors" "fmt" "io" "net" + "net/netip" "strconv" "strings" - "sync/atomic" + "sync" "time" "golang.zx2c4.com/wireguard/ipc" ) type IPCError struct { - int64 + code int64 // error code + err error // underlying/wrapped error } func (s IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) + return fmt.Sprintf("IPC error %d: %v", s.code, s.err) +} + +func (s IPCError) Unwrap() error { + return s.err } func (s IPCError) ErrorCode() int64 { - return s.int64 + return s.code +} + +func ipcErrorf(code int64, msg string, args ...any) *IPCError { + return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } -func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) +var byteBufferPool = &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcGetOperation(w io.Writer) error { + device.ipcMutex.RLock() + defer device.ipcMutex.RUnlock() + + buf := byteBufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer byteBufferPool.Put(buf) + sendf := func(format string, args ...any) { + fmt.Fprintf(buf, format, args...) + buf.WriteByte('\n') + } + keyf := func(prefix string, key *[32]byte) { + buf.Grow(len(key)*2 + 2 + len(prefix)) + buf.WriteString(prefix) + buf.WriteByte('=') + const hex = "0123456789abcdef" + for i := 0; i < len(key); i++ { + buf.WriteByte(hex[key[i]>>4]) + buf.WriteByte(hex[key[i]&0xf]) + } + buf.WriteByte('\n') } func() { - // lock required resources device.net.RLock() @@ -52,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.ToHex()) + keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) } if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) + sendf("listen_port=%d", device.net.port) } if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + sendf("fwmark=%d", device.net.fwmark) } - // serialize each peer state - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - send("protocol_version=1") - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) + // Serialize peer state. + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) } + peer.endpoint.Unlock() - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) - send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - - for _, ip := range device.allowedips.EntriesForPeer(peer) { - send("allowed_ip=" + ip.String()) - } + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ipc.IpcErrorIO} - } + if _, err := w.Write(buf.Bytes()); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } return nil } -func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcSetOperation(r io.Reader) (err error) { + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - var peer *Peer + defer func() { + if err != nil { + device.log.Errorf("%v", err) + } + }() - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true + scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. + peer.handlePostConfig() return nil } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{ipc.IpcErrorProtocol} + key, value, ok := strings.Cut(line, "=") + if !ok { + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipc.IpcErrorInvalid} } + peer.handlePostConfig() + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err + } + } + peer.handlePostConfig() - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - if peer == nil { - dummy = true - peer = &Peer{} - } else { - logDebug.Println(peer, "- UAPI: Created") - } - } + if err := scanner.Err(); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) + } + return nil +} - case "update_only": +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Verbosef("UAPI: Updating private key") + device.SetPrivateKey(sk) - // allow disabling of creation + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set update only, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } + // update port and rebind + device.log.Verbosef("UAPI: Updating listen port") - case "remove": + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() - // remove currently selected peer from device + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true + case "fwmark": + mark, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } - case "preshared_key": + device.log.Verbosef("UAPI: Updating fwmark") + if err := device.BindSetMark(uint32(mark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } - // update PSK + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Verbosef("UAPI: Removing all peers") + device.RemoveAllPeers() - logDebug.Println(peer, "- UAPI: Updating preshared key") + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() + return nil +} - if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer + pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on +} - case "endpoint": +func (peer *ipcSetPeer) handlePostConfig() { + if peer.Peer == nil || peer.dummy { + return + } + if peer.created { + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil + } + if peer.device.isUp() { + peer.Start() + if peer.pkaOn { + peer.SendKeepalive() + } + peer.SendStagedPackets() + } +} - // set endpoint destination +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } - logDebug.Println(peer, "- UAPI: Updating endpoint") + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } - if err != nil { - logError.Println("Failed to set endpoint:", err, ":", value) - return &IPCError{ipc.IpcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipc.IpcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Verbosef("%v - UAPI: Created", peer.Peer) + } + return nil +} - case "replace_allowed_ips": +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } - logDebug.Println(peer, "- UAPI: Removing all allowedips") + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Verbosef("%v - UAPI: Removing", peer.Peer) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true - if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "preshared_key": + device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) - if dummy { - continue - } + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() - device.allowedips.RemoveByPeer(peer) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } - case "allowed_ip": + case "endpoint": + device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) + endpoint, err := device.net.bind.ParseEndpoint(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint - logDebug.Println(peer, "- UAPI: Adding allowedip") + case "persistent_keepalive_interval": + device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipc.IpcErrorInvalid} - } + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } - if dummy { - continue - } + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) + // Send immediate keepalive if we're turning it on and before it wasn't on. + peer.pkaOn = old == 0 && secs != 0 - case "protocol_version": + case "replace_allowed_ips": + device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) - if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "allowed_ip": + device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + prefix, err := netip.ParsePrefix(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + device.allowedips.Insert(prefix, peer.Peer) - default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipc.IpcErrorInvalid} - } + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } return nil } -func (device *Device) IpcHandle(socket net.Conn) { +func (device *Device) IpcGet() (string, error) { + buf := new(strings.Builder) + if err := device.IpcGetOperation(buf); err != nil { + return "", err + } + return buf.String(), nil +} - // create buffered read/writer +func (device *Device) IpcSet(uapiConf string) error { + return device.IpcSetOperation(strings.NewReader(uapiConf)) +} +func (device *Device) IpcHandle(socket net.Conn) { defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -407,35 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) { return bufio.NewReadWriter(reader, writer) }(socket) - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - status = device.IpcSetOperation(buffered.Reader) - - case "get=1\n": - status = device.IpcGetOperation(buffered.Writer) - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } + for { + op, err := buffered.ReadString('\n') + if err != nil { + return + } - // write status + // handle operation + switch op { + case "set=1\n": + err = device.IpcSetOperation(buffered.Reader) + case "get=1\n": + var nextByte byte + nextByte, err = buffered.ReadByte() + if err != nil { + return + } + if nextByte != '\n' { + err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + break + } + err = device.IpcGetOperation(buffered.Writer) + default: + device.log.Errorf("invalid UAPI operation: %v", op) + return + } - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") + // write status + var status *IPCError + if err != nil && !errors.As(err, &status) { + // shouldn't happen + status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) + } + if status != nil { + device.log.Errorf("%v", status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() } } diff --git a/device/version.go b/device/version.go deleted file mode 100644 index 326b9a9..0000000 --- a/device/version.go +++ /dev/null @@ -1,3 +0,0 @@ -package device - -const WireGuardGoVersion = "0.0.20191012" |