diff options
Diffstat (limited to 'device')
43 files changed, 2684 insertions, 2276 deletions
diff --git a/device/allowedips.go b/device/allowedips.go index 143bda3..fa46f97 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,173 +1,201 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( + "container/list" + "encoding/binary" "errors" "math/bits" "net" + "net/netip" "sync" "unsafe" ) -type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -func isLittleEndian() bool { - one := uint32(1) - return *(*byte)(unsafe.Pointer(&one)) != 0 +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 } -func swapU32(i uint32) uint32 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes32(i) -} - -func swapU64(i uint64) uint64 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes64(i) +type trieEntry struct { + peer *Peer + child [2]*trieEntry + parent parentIndirection + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits []byte + perPeerElem *list.Element } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { - a := (*uint32)(unsafe.Pointer(&ip1[0])) - b := (*uint32)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + a := binary.BigEndian.Uint32(ip1) + b := binary.BigEndian.Uint32(ip2) + x := a ^ b + return uint8(bits.LeadingZeros32(x)) } else if size == net.IPv6len { - a := (*uint64)(unsafe.Pointer(&ip1[0])) - b := (*uint64)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b + a := binary.BigEndian.Uint64(ip1) + b := binary.BigEndian.Uint64(ip2) + x := a ^ b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(x)) } - a = (*uint64)(unsafe.Pointer(&ip1[8])) - b = (*uint64)(unsafe.Pointer(&ip2[8])) - x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + a = binary.BigEndian.Uint64(ip1[8:]) + b = binary.BigEndian.Uint64(ip2[8:]) + x = a ^ b + return 64 + uint8(bits.LeadingZeros64(x)) } else { panic("Wrong size bit string") } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) +func (node *trieEntry) addToPeerEntries() { + node.perPeerElem = node.peer.trieEntries.PushBack(node) +} - if node.peer != p { - return node +func (node *trieEntry) removeFromPeerEntries() { + if node.perPeerElem != nil { + node.peer.trieEntries.Remove(node.perPeerElem) + node.perPeerElem = nil } +} - // remove peer & merge +func (node *trieEntry) choose(ip []byte) byte { + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 +} - node.peer = nil - if node.child[0] == nil { - return node.child[1] +func (node *trieEntry) maskSelf() { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + for i := 0; i < len(mask); i++ { + node.bits[i] &= mask[i] } - return node.child[0] } -func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +func (node *trieEntry) zeroizePointers() { + // Make the garbage collector's life slightly easier + node.peer = nil + node.child[0] = nil + node.child[1] = nil + node.parent.parentBit = nil } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { - - // at leaf - - if node == nil { - return &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return } + bit := node.choose(ip) + node = node.child[bit] } + return +} - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { + node := &trieEntry{ + peer: peer, + parent: trie, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node.maskSelf() + node.addToPeerEntries() + *trie.parentBit = node + return + } + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - - // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + peer: peer, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } + newNode.maskSelf() + newNode.addToPeerEntries() - cidr = min(cidr, common) - - // check for shorter prefix + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) + if common < cidr { + cidr = common + } + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), + } + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer { return found } -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed -} - -func (table *AllowedIPs) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil + for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { + node := elem.Value.(*trieEntry) + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { + return + } + } } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() + } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.lookup(address) + switch len(ip) { + case net.IPv6len: + return table.IPv6.lookup(ip) + case net.IPv4len: + return table.IPv4.lookup(ip) + default: + panic(errors.New("looking up unknown address type")) + } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 3947830..07065c3 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,25 +1,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" + "net" + "net/netip" "sort" "testing" ) const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 + NumberOfPeers = 100 + NumberOfPeerRemovals = 4 + NumberOfAddresses = 250 + NumberOfTests = 10000 ) type SlowNode struct { peer *Peer - cidr uint + cidr uint8 bits []byte } @@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { +func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer @@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ } } + return r[:n] } -func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry - var slow SlowRouter +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { + for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + for n := 0; n < NumberOfAddresses; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + var p int + for p = 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.Lookup(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.Lookup(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } + } + if p >= len(peers) || p >= NumberOfPeerRemovals { + break } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + for ; p < len(peers); p++ { + allowedIPs.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 005df48..cde068e 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,40 +8,17 @@ package device import ( "math/rand" "net" + "net/netip" "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) + match uint8 } func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, @@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} rand.Seed(1) const AddressLength = 4 - for n := 0; n < peerNumber; n += 1 { + for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < addressNumber; n += 1 { + for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + root.insert(addr[:], cidr, peers[index]) } - for n := 0; n < b.N; n += 1 { + for n := 0; n < b.N; n++ { var addr [AddressLength]byte rand.Read(addr[:]) trie.lookup(addr[:]) @@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } diff --git a/device/bind_test.go b/device/bind_test.go index 339adbe..302a521 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -14,14 +14,11 @@ import ( type DummyDatagram struct { msg []byte endpoint conn.Endpoint - world bool // better type } type DummyBind struct { in6 chan DummyDatagram - ou6 chan DummyDatagram in4 chan DummyDatagram - ou4 chan DummyDatagram closed bool } @@ -29,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } @@ -54,6 +51,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { +func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go deleted file mode 100644 index 896c7d2..0000000 --- a/device/bindsocketshim.go +++ /dev/null @@ -1,36 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - - "golang.zx2c4.com/wireguard/conn" -) - -// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - if iface, ok := device.net.bind.(conn.BindSocketToInterface); ok { - return iface.BindSocketToInterface4(interfaceIndex, blackhole) - } - return nil -} - -// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - if iface, ok := device.net.bind.(conn.BindSocketToInterface); ok { - return iface.BindSocketToInterface6(interfaceIndex, blackhole) - } - return nil -} diff --git a/device/channels.go b/device/channels.go new file mode 100644 index 0000000..e526f6b --- /dev/null +++ b/device/channels.go @@ -0,0 +1,137 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "runtime" + "sync" +) + +// An outboundQueue is a channel of QueueOutboundElements awaiting encryption. +// An outboundQueue is ref-counted using its wg field. +// An outboundQueue created with newOutboundQueue has one reference. +// Every additional writer must call wg.Add(1). +// Every completed writer must call wg.Done(). +// When no further writers will be added, +// call wg.Done to remove the initial reference. +// When the refcount hits 0, the queue's channel is closed. +type outboundQueue struct { + c chan *QueueOutboundElementsContainer + wg sync.WaitGroup +} + +func newOutboundQueue() *outboundQueue { + q := &outboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A inboundQueue is similar to an outboundQueue; see those docs. +type inboundQueue struct { + c chan *QueueInboundElementsContainer + wg sync.WaitGroup +} + +func newInboundQueue() *inboundQueue { + q := &inboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A handshakeQueue is similar to an outboundQueue; see those docs. +type handshakeQueue struct { + c chan QueueHandshakeElement + wg sync.WaitGroup +} + +func newHandshakeQueue() *handshakeQueue { + q := &handshakeQueue{ + c: make(chan QueueHandshakeElement, QueueHandshakeSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +type autodrainingInboundQueue struct { + c chan *QueueInboundElementsContainer +} + +// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { + q := &autodrainingInboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + runtime.SetFinalizer(q, device.flushInboundQueue) + return q +} + +func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) + default: + return + } + } +} + +type autodrainingOutboundQueue struct { + c chan *QueueOutboundElementsContainer +} + +// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +// All sends to the channel must be best-effort, because there may be no receivers. +func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { + q := &autodrainingOutboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + runtime.SetFinalizer(q, device.flushOutboundQueue) + return q +} + +func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsContainer) + default: + return + } + } +} diff --git a/device/constants.go b/device/constants.go index 1a4b8ea..59854a1 100644 --- a/device/constants.go +++ b/device/constants.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -35,7 +35,6 @@ const ( /* Implementation constants */ const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 UnderLoadAfterTime = time.Second // how long does the device remain under load after detected MaxPeers = 1 << 16 // maximum number of configured peers ) diff --git a/device/cookie.go b/device/cookie.go index c658ca3..876f05d 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool { return hmac.Equal(mac1[:], msg[smac1:smac2]) } -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { +func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { st.RLock() defer st.RUnlock() @@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply( recv uint32, src []byte, ) (*MessageCookieReply, error) { - st.RLock() // refresh cookie secret @@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) - if err != nil { return false } @@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { } func (st *CookieGenerator) AddMacs(msg []byte) { - size := len(msg) smac2 := size - blake2s.Size128 diff --git a/device/cookie_test.go b/device/cookie_test.go index 7e4c362..4f1e50a 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -10,7 +10,6 @@ import ( ) func TestCookieMAC1(t *testing.T) { - // setup generator / checker var ( @@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) { msg[5] ^= 0x20 - srcBad1 := []byte{192, 168, 13, 37, 40, 01} + srcBad1 := []byte{192, 168, 13, 37, 40, 1} if checker.CheckMAC2(msg, srcBad1) { t.Fatal("MAC2 generation/verification failed") } - srcBad2 := []byte{192, 168, 13, 38, 40, 01} + srcBad2 := []byte{192, 168, 13, 38, 40, 1} if checker.CheckMAC2(msg, srcBad2) { t.Fatal("MAC2 generation/verification failed") } diff --git a/device/device.go b/device/device.go index c64432e..83c33ee 100644 --- a/device/device.go +++ b/device/device.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -11,8 +11,6 @@ import ( "sync/atomic" "time" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" @@ -20,28 +18,33 @@ import ( ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - state struct { - starting sync.WaitGroup + // state holds the device's state. It is accessed atomically. + // Use the device.deviceState method to read it. + // device.deviceState does not acquire the mutex, so it captures only a snapshot. + // During state transitions, the state variable is updated before the device itself. + // The state is thus either the current state of the device or + // the intended future state of the device. + // For example, while executing a call to Up, state will be deviceStateUp. + // There is no guarantee that that intended future state of the device + // will become the actual state; Up can fail. + // The device can also change state multiple times between time of check and time of use. + // Unsynchronized uses of state must therefore be advisory/best-effort only. + state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience + // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup + // mu protects state changes. sync.Mutex - changing AtomicBool - current bool } net struct { - starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex bind conn.Bind // bind interface netlinkCancel *rwcancel.RWCancel port uint16 // listening port fwmark uint32 // mark value (0 = disabled) + brokenRoaming bool } staticIdentity struct { @@ -51,153 +54,176 @@ type Device struct { } peers struct { - sync.RWMutex - keyMap map[NoisePublicKey]*Peer + sync.RWMutex // protects keyMap + keyMap map[NoisePublicKey]*Peer } - // unprotected / "self-synchronising resources" + rate struct { + underLoadUntil atomic.Int64 + limiter ratelimiter.Ratelimiter + } allowedips AllowedIPs indexTable IndexTable cookieChecker CookieChecker - rate struct { - underLoadUntil atomic.Value - limiter ratelimiter.Ratelimiter - } - pool struct { - messageBufferPool *sync.Pool - messageBufferReuseChan chan *[MaxMessageSize]byte - inboundElementPool *sync.Pool - inboundElementReuseChan chan *QueueInboundElement - outboundElementPool *sync.Pool - outboundElementReuseChan chan *QueueOutboundElement + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signals struct { - stop chan struct{} + encryption *outboundQueue + decryption *inboundQueue + handshake *handshakeQueue } tun struct { device tun.Device - mtu int32 + mtu atomic.Int32 } + + ipcMutex sync.RWMutex + closed chan struct{} + log *Logger } -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold device.peers.Mutex - */ -func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { +// deviceState represents the state of a Device. +// There are three states: down, up, closed. +// Transitions: +// +// down -----+ +// ↑↓ ↓ +// up -> closed +type deviceState uint32 + +//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState +const ( + deviceStateDown deviceState = iota + deviceStateUp + deviceStateClosed +) - // stop routing and processing of packets +// deviceState returns device.state.state as a deviceState +// See those docs for how to interpret this value. +func (device *Device) deviceState() deviceState { + return deviceState(device.state.state.Load()) +} + +// isClosed reports whether the device is closed (or is closing). +// See device.state.state comments for how to interpret this value. +func (device *Device) isClosed() bool { + return device.deviceState() == deviceStateClosed +} + +// isUp reports whether the device is up (or is attempting to come up). +// See device.state.state comments for how to interpret this value. +func (device *Device) isUp() bool { + return device.deviceState() == deviceStateUp +} +// Must hold device.peers.Lock() +func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { + // stop routing and processing of packets device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map - delete(device.peers.keyMap, key) } -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - // compare to current state of device - +// changeState attempts to change the device state to match want. +func (device *Device) changeState(want deviceState) (err error) { device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() - return + defer device.state.Unlock() + old := device.deviceState() + if old == deviceStateClosed { + // once closed, always closed + device.log.Verbosef("Interface closed, ignored requested state %s", want) + return nil } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.log.Error.Printf("Unable to update bind: %v\n", err) - device.isUp.Set(false) + switch want { + case old: + return nil + case deviceStateUp: + device.state.state.Store(uint32(deviceStateUp)) + err = device.upLocked() + if err == nil { break } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if peer.persistentKeepaliveInterval > 0 { - peer.SendKeepalive() - } + fallthrough // up failed; bring the device all the way back down + case deviceStateDown: + device.state.state.Store(uint32(deviceStateDown)) + errDown := device.downLocked() + if err == nil { + err = errDown } - device.peers.RUnlock() - - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() - } - device.peers.RUnlock() } + device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) + return +} - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() +// upLocked attempts to bring the device up and reports whether it succeeded. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) upLocked() error { + if err := device.BindUpdate(); err != nil { + device.log.Errorf("Unable to update bind: %v", err) + return err + } - // check for state change in the mean time + // The IPC set operation waits for peers to be created before calling Start() on them, + // so if there's a concurrent IPC set request happening, we should wait for it to complete. + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - deviceUpdateState(device) + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval.Load() > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + return nil } -func (device *Device) Up() { - - // closed device cannot be brought up +// downLocked attempts to bring the device down. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) downLocked() error { + err := device.BindClose() + if err != nil { + device.log.Errorf("Bind close failed: %v", err) + } - if device.isClosed.Get() { - return + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() } + device.peers.RUnlock() + return err +} - device.isUp.Set(true) - deviceUpdateState(device) +func (device *Device) Up() error { + return device.changeState(deviceStateUp) } -func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) +func (device *Device) Down() error { + return device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { - // check if currently under load - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 if underLoad { - device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) return true } - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) + return device.rate.underLoadUntil.Load() > now.UnixNano() } func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { @@ -224,7 +250,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { if peer.handshake.remoteStatic.Equals(publicKey) { - unsafeRemovePeer(device, peer, key) + peer.handshake.mutex.RUnlock() + removePeerLocked(device, peer, key) + peer.handshake.mutex.RLock() } } @@ -239,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -253,70 +281,63 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - + device.state.state.Store(uint32(deviceStateDown)) + device.closed = make(chan struct{}) device.log = logger - + device.net.bind = bind device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { - logger.Error.Println("Trouble determining MTU, assuming default:", err) + device.log.Errorf("Trouble determining MTU, assuming default: %v", err) mtu = DefaultMTU } - device.tun.mtu = int32(mtu) - + device.tun.mtu.Store(int32(mtu)) device.peers.keyMap = make(map[NoisePublicKey]*Peer) - device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - device.indexTable.Init() - device.allowedips.Reset() device.PopulatePools() // create queues - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signals.stop = make(chan struct{}) - - // prepare net - - device.net.port = 0 - device.net.bind = nil + device.queue.handshake = newHandshakeQueue() + device.queue.encryption = newOutboundQueue() + device.queue.decryption = newInboundQueue() // start workers cpus := runtime.NumCPU() - device.state.starting.Wait() device.state.stopping.Wait() - for i := 0; i < cpus; i += 1 { - device.state.starting.Add(3) - device.state.stopping.Add(3) - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() + device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake + for i := 0; i < cpus; i++ { + go device.RoutineEncryption(i + 1) + go device.RoutineDecryption(i + 1) + go device.RoutineHandshake(i + 1) } - device.state.starting.Add(2) - device.state.stopping.Add(2) + device.state.stopping.Add(1) // RoutineReadFromTUN + device.queue.encryption.wg.Add(1) // RoutineReadFromTUN go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() - device.state.starting.Wait() - return device } +// BatchSize returns the BatchSize for the device as a whole which is the max of +// the bind batch size and the tun batch size. The batch size reported by device +// is the size used to construct memory pools, and is the allowed batch size for +// the lifetime of the device. +func (device *Device) BatchSize() int { + size := device.net.bind.BatchSize() + dSize := device.tun.device.BatchSize() + if size < dSize { + size = dSize + } + return size +} + func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() @@ -331,7 +352,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) { peer, ok := device.peers.keyMap[key] if ok { - unsafeRemovePeer(device, peer, key) + removePeerLocked(device, peer, key) } } @@ -340,67 +361,50 @@ func (device *Device) RemoveAllPeers() { defer device.peers.Unlock() for key, peer := range device.peers.keyMap { - unsafeRemovePeer(device, peer, key) + removePeerLocked(device, peer, key) } device.peers.keyMap = make(map[NoisePublicKey]*Peer) } -func (device *Device) FlushPacketQueues() { - for { - select { - case elem, ok := <-device.queue.decryption: - if ok { - elem.Drop() - } - case elem, ok := <-device.queue.encryption: - if ok { - elem.Drop() - } - case <-device.queue.handshake: - default: - return - } - } - -} - func (device *Device) Close() { - if device.isClosed.Swap(true) { - return - } - - device.state.starting.Wait() - - device.log.Info.Println("Device closing") - device.state.changing.Set(true) device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() + if device.isClosed() { + return + } + device.state.state.Store(uint32(deviceStateClosed)) + device.log.Verbosef("Device closing") device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) - - close(device.signals.stop) + device.downLocked() + // Remove peers before closing queues, + // because peers assume that queues are active. device.RemoveAllPeers() + // We kept a reference to the encryption and decryption queues, + // in case we started any new peers that might write to them. + // No new peers are coming; we are done with these queues. + device.queue.encryption.wg.Done() + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.state.stopping.Wait() - device.FlushPacketQueues() device.rate.limiter.Close() - device.state.changing.Set(false) - device.log.Info.Println("Interface closed") + device.log.Verbosef("Device closed") + close(device.closed) } func (device *Device) Wait() chan struct{} { - return device.signals.stop + return device.closed } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { - if device.isClosed.Get() { + if !device.isUp() { return } @@ -416,7 +420,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { device.peers.RUnlock() } -func unsafeCloseBind(device *Device) error { +// closeBindLocked closes the device's net.bind. +// The caller must hold the net mutex. +func closeBindLocked(device *Device) error { var err error netc := &device.net if netc.netlinkCancel != nil { @@ -424,7 +430,6 @@ func unsafeCloseBind(device *Device) error { } if netc.bind != nil { err = netc.bind.Close() - netc.bind = nil } netc.stopping.Wait() return err @@ -437,34 +442,26 @@ func (device *Device) Bind() conn.Bind { } func (device *Device) BindSetMark(mark uint32) error { - device.net.Lock() defer device.net.Unlock() // check if modified - if device.net.fwmark == mark { return nil } // update fwmark on existing bind - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { + if device.isUp() && device.net.bind != nil { if err := device.net.bind.SetMark(mark); err != nil { return err } } // clear cached source addresses - device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -472,76 +469,68 @@ func (device *Device) BindSetMark(mark uint32) error { } func (device *Device) BindUpdate() error { - device.net.Lock() defer device.net.Unlock() // close existing sockets - - if err := unsafeCloseBind(device); err != nil { + if err := closeBindLocked(device); err != nil { return err } // open new sockets + if !device.isUp() { + return nil + } - if device.isUp.Get() { + // bind to new port + var err error + var recvFns []conn.ReceiveFunc + netc := &device.net - // bind to new port + recvFns, netc.port, err = netc.bind.Open(netc.port) + if err != nil { + netc.port = 0 + return err + } - var err error - netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - netc.netlinkCancel, err = device.startRouteListener(netc.bind) + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.port = 0 + return err + } + + // set fwmark + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) if err != nil { - netc.bind.Close() - netc.bind = nil - netc.port = 0 return err } + } - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(2) - device.net.stopping.Add(2) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.markEndpointSrcForClearing() + } + device.peers.RUnlock() - device.log.Debug.Println("UDP bind has been updated") + // start receiving routines + device.net.stopping.Add(len(recvFns)) + device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + batchSize := netc.bind.BatchSize() + for _, fn := range recvFns { + go device.RoutineReceiveIncoming(batchSize, fn) } + device.log.Verbosef("UDP bind has been updated") return nil } func (device *Device) BindClose() error { device.net.Lock() - err := unsafeCloseBind(device) + err := closeBindLocked(device) device.net.Unlock() return err } diff --git a/device/device_test.go b/device/device_test.go index 5ea5410..fff172b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,102 +1,476 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "bufio" "bytes" - "net" - "strings" + "encoding/hex" + "fmt" + "io" + "math/rand" + "net/netip" + "os" + "runtime" + "runtime/pprof" + "sync" + "sync/atomic" "testing" "time" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/conn/bindtest" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" ) -func TestTwoDevicePing(t *testing.T) { - // TODO(crawshaw): pick unused ports on localhost - cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58 -listen_port=53511 -replace_peers=true -public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.2/32 -endpoint=127.0.0.1:53512` - tun1 := tuntest.NewChannelTUN() - dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) - dev1.Up() - defer dev1.Close() - if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil { - t.Fatal(err) - } - - cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768 -listen_port=53512 -replace_peers=true -public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.1/32 -endpoint=127.0.0.1:53511` - tun2 := tuntest.NewChannelTUN() - dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) - dev2.Up() - defer dev2.Close() - if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil { - t.Fatal(err) +// uapiCfg returns a string that contains cfg formatted use with IpcSet. +// cfg is a series of alternating key/value strings. +// uapiCfg exists because editors and humans like to insert +// whitespace into configs, which can cause failures, some of which are silent. +// For example, a leading blank newline causes the remainder +// of the config to be silently ignored. +func uapiCfg(cfg ...string) string { + if len(cfg)%2 != 0 { + panic("odd number of args to uapiReader") + } + buf := new(bytes.Buffer) + for i, s := range cfg { + buf.WriteString(s) + sep := byte('\n') + if i%2 == 0 { + sep = '=' + } + buf.WriteByte(sep) } + return buf.String() +} - t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun2.Outbound <- msg2to1 +// genConfigs generates a pair of configs that connect to each other. +// The configs use distinct, probably-usable ports. +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { + var key1, key2 NoisePrivateKey + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + pub1, pub2 := key1.publicKey(), key2.publicKey() + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.2/32", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", "127.0.0.1:%d", + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.1/32", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", "127.0.0.1:%d", + ) + return +} + +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. +type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip netip.Addr +} + +type SendDirection bool + +const ( + Ping SendDirection = true + Pong SendDirection = false +) + +func (d SendDirection) String() string { + if d == Ping { + return "ping" + } + return "pong" +} + +func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { + tb.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 + } + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = fmt.Errorf("%s did not transit correctly", ping) + } + case <-timer.C: + err = fmt.Errorf("%s did not transit", ping) + case <-done: + } + if err != nil { + // The error may have occurred because the test is done. select { - case msgRecv := <-tun1.Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") + case <-done: + return + default: + } + // Real error. + tb.Error(err) + } +} + +// genTestPair creates a testPair. +func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { + cfg, endpointCfg := genConfigs(tb) + var binds [2]conn.Bind + if realSocket { + binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() + } else { + binds = bindtest.NewChannelBinds() + } + // Bring up a ChannelTun for each config. + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) + level := LogLevelVerbose + if _, ok := tb.(*testing.B); ok && !testing.Verbose() { + level = LogLevelError + } + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + if err := p.dev.IpcSet(cfg[i]); err != nil { + tb.Errorf("failed to configure device %d: %v", i, err) + p.dev.Close() + continue + } + if err := p.dev.Up(); err != nil { + tb.Errorf("failed to bring up device %d: %v", i, err) + p.dev.Close() + continue + } + endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) + } + for i := range pair { + p := &pair[i] + if err := p.dev.IpcSet(endpointCfg[i]); err != nil { + tb.Errorf("failed to configure device endpoint %d: %v", i, err) + p.dev.Close() + continue + } + // The device is ready. Close it when the test completes. + tb.Cleanup(p.dev.Close) + } + return +} + +func TestTwoDevicePing(t *testing.T) { + goroutineLeakCheck(t) + pair := genTestPair(t, true) + t.Run("ping 1.0.0.1", func(t *testing.T) { + pair.Send(t, Ping, nil) + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + pair.Send(t, Pong, nil) + }) +} + +func TestUpDown(t *testing.T) { + goroutineLeakCheck(t) + const itrials = 50 + const otrials = 10 + + for n := 0; n < otrials; n++ { + pair := genTestPair(t, false) + for i := range pair { + for k := range pair[i].dev.peers.keyMap { + pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) + } + } + var wg sync.WaitGroup + wg.Add(len(pair)) + for i := range pair { + go func(d *Device) { + defer wg.Done() + for i := 0; i < itrials; i++ { + if err := d.Up(); err != nil { + t.Errorf("failed up bring up device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + if err := d.Down(); err != nil { + t.Errorf("failed to bring down device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + } + }(pair[i].dev) + } + wg.Wait() + for i := range pair { + pair[i].dev.Up() + pair[i].dev.Close() + } + } +} + +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t, true) + done := make(chan struct{}) + + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. + i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- } - case <-time.After(300 * time.Millisecond): - t.Error("ping did not transit") + } + }() + warmup.Wait() + + applyCfg := func(cfg string) { + err := pair[0].dev.IpcSet(cfg) + if err != nil { + t.Fatal(err) + } + } + + // Change persistent_keepalive_interval concurrently with tunnel use. + t.Run("persistentKeepaliveInterval", func(t *testing.T) { + var pub NoisePublicKey + for key := range pair[0].dev.peers.keyMap { + pub = key + break + } + cfg := uapiCfg( + "public_key", hex.EncodeToString(pub[:]), + "persistent_keepalive_interval", "1", + ) + for i := 0; i < 1000; i++ { + applyCfg(cfg) } }) - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun1.Outbound <- msg1to2 - select { - case msgRecv := <-tun2.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") + // Change private keys concurrently with tunnel use. + t.Run("privateKey", func(t *testing.T) { + bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") + good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) + // Set iters to a large number like 1000 to flush out data races quickly. + // Don't leave it large. That can cause logical races + // in which the handshake is interleaved with key changes + // such that the private key appears to be unchanging but + // other state gets reset, which can cause handshake failures like + // "Received packet with invalid mac1". + const iters = 1 + for i := 0; i < iters; i++ { + applyCfg(bad) + applyCfg(good) + } + }) + + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() } - case <-time.After(300 * time.Millisecond): - t.Error("return ping did not transit") } }) + + close(done) } -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) +func BenchmarkLatency(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) } } -func assertEqual(t *testing.T, a, b []byte) { - if !bytes.Equal(a, b) { - t.Fatal(a, "!=", b) +func BenchmarkThroughput(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + // Measure how long it takes to receive b.N packets, + // starting when we receive the first packet. + var recv atomic.Uint64 + var elapsed time.Duration + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var start time.Time + for { + <-pair[0].tun.Inbound + new := recv.Add(1) + if new == 1 { + start = time.Now() + } + // Careful! Don't change this to else if; b.N can be equal to 1. + if new == uint64(b.N) { + elapsed = time.Since(start) + return + } + } + }() + + // Send packets as fast as we can until we've received enough. + ping := tuntest.Ping(pair[0].ip, pair[1].ip) + pingc := pair[1].tun.Outbound + var sent uint64 + for recv.Load() != uint64(b.N) { + sent++ + pingc <- ping } + wg.Wait() + + b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") + b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") } -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) +func BenchmarkUAPIGet(b *testing.B) { + pair := genTestPair(b, true) + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair[0].dev.IpcGetOperation(io.Discard) + } +} + +func goroutineLeakCheck(t *testing.T) { + goroutines := func() (int, []byte) { + p := pprof.Lookup("goroutine") + b := new(bytes.Buffer) + p.WriteTo(b, 1) + return p.Count(), b.Bytes() + } + + startGoroutines, startStacks := goroutines() + t.Cleanup(func() { + if t.Failed() { + return + } + // Give goroutines time to exit, if they need it. + for i := 0; i < 10000; i++ { + if runtime.NumGoroutine() <= startGoroutines { + return + } + time.Sleep(1 * time.Millisecond) + } + endGoroutines, endStacks := goroutines() + t.Logf("starting stacks:\n%s\n", startStacks) + t.Logf("ending stacks:\n%s\n", endStacks) + t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) + }) +} + +type fakeBindSized struct { + size int +} + +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil +} +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } + +type fakeTUNDeviceSized struct { + size int +} + +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) } - tun := newDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device } diff --git a/device/devicestate_string.go b/device/devicestate_string.go new file mode 100644 index 0000000..6577dd4 --- /dev/null +++ b/device/devicestate_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. + +package device + +import "strconv" + +const _deviceState_name = "DownUpClosed" + +var _deviceState_index = [...]uint8{0, 4, 6, 12} + +func (i deviceState) String() string { + if i >= deviceState(len(_deviceState_index)-1) { + return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] +} diff --git a/device/endpoint_test.go b/device/endpoint_test.go index e66d493..93a4998 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -1,53 +1,49 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" - "net" + "net/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/indextable.go b/device/indextable.go index 5e10eef..00ade7d 100644 --- a/device/indextable.go +++ b/device/indextable.go @@ -1,14 +1,14 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/rand" + "encoding/binary" "sync" - "unsafe" ) type IndexTableEntry struct { @@ -25,7 +25,8 @@ type IndexTable struct { func randUint32() (uint32, error) { var integer [4]byte _, err := rand.Read(integer[:]) - return *(*uint32)(unsafe.Pointer(&integer[0])), err + // Arbitrary endianness; both are intrinsified by the Go compiler. + return binary.LittleEndian.Uint32(integer[:]), err } func (table *IndexTable) Init() { diff --git a/device/ip.go b/device/ip.go index 3bc6929..eaf2363 100644 --- a/device/ip.go +++ b/device/ip.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/kdf_test.go b/device/kdf_test.go index 1a3bc87..f9c76d6 100644 --- a/device/kdf_test.go +++ b/device/kdf_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -20,7 +20,7 @@ type KDFTest struct { t2 string } -func assertEquals(t *testing.T, a string, b string) { +func assertEquals(t *testing.T, a, b string) { if a != b { t.Fatal("expected", a, "=", b) } diff --git a/device/keypair.go b/device/keypair.go index 2f2f222..e3540d7 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -10,7 +10,6 @@ import ( "sync" "sync/atomic" "time" - "unsafe" "golang.zx2c4.com/wireguard/replay" ) @@ -23,10 +22,10 @@ import ( */ type Keypair struct { - sendNonce uint64 + sendNonce atomic.Uint64 send cipher.AEAD receive cipher.AEAD - replayFilter replay.ReplayFilter + replayFilter replay.Filter isInitiator bool created time.Time localIndex uint32 @@ -37,15 +36,7 @@ type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair - next *Keypair -} - -func (kp *Keypairs) storeNext(next *Keypair) { - atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) -} - -func (kp *Keypairs) loadNext() *Keypair { - return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) + next atomic.Pointer[Keypair] } func (kp *Keypairs) Current() *Keypair { diff --git a/device/logger.go b/device/logger.go index 3c4d744..22b0df0 100644 --- a/device/logger.go +++ b/device/logger.go @@ -1,59 +1,48 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "io" - "io/ioutil" "log" "os" ) +// A Logger provides logging for a Device. +// The functions are Printf-style functions. +// They must be safe for concurrent use. +// They do not require a trailing newline in the format. +// If nil, that level of logging will be silent. +type Logger struct { + Verbosef func(format string, args ...any) + Errorf func(format string, args ...any) +} + +// Log levels for use with NewLogger. const ( LogLevelSilent = iota LogLevelError - LogLevelInfo - LogLevelDebug + LogLevelVerbose ) -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} +// Function for use in Logger for discarding logged lines. +func DiscardLogf(format string, args ...any) {} +// NewLogger constructs a Logger that writes to stdout. +// It logs at the specified log level and above. +// It decorates log lines with the log level, date, time, and prepend. func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - if level >= LogLevelError { - return output, ioutil.Discard, ioutil.Discard - } - return ioutil.Discard, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) + logger := &Logger{DiscardLogf, DiscardLogf} + logf := func(prefix string) func(string, ...any) { + return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf + } + if level >= LogLevelVerbose { + logger.Verbosef = logf("DEBUG") + } + if level >= LogLevelError { + logger.Errorf = logf("ERROR") + } return logger } diff --git a/device/misc.go b/device/misc.go deleted file mode 100644 index 30d1156..0000000 --- a/device/misc.go +++ /dev/null @@ -1,48 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.int32) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.int32, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.int32, flag) -} - -func min(a, b uint) uint { - if a > b { - return b - } - return a -} diff --git a/device/mobilequirks.go b/device/mobilequirks.go new file mode 100644 index 0000000..0a0080e --- /dev/null +++ b/device/mobilequirks.go @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, +// though it will try to deal with it, and race maybe, if called after. +func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { + device.net.brokenRoaming = true + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() + } + device.peers.RUnlock() +} diff --git a/device/noise-helpers.go b/device/noise-helpers.go index b3b5acf..c2f356b 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,6 +9,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/subtle" + "errors" "hash" "golang.org/x/crypto/blake2s" @@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { return } -func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { +var errInvalidPublicKey = errors.New("invalid public key") + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { apk := (*[NoisePublicKeySize]byte)(&pk) ask := (*[NoisePrivateKeySize]byte)(sk) curve25519.ScalarMult(&ss, ask, apk) - return ss + if isZero(ss[:]) { + return ss, errInvalidPublicKey + } + return ss, nil } diff --git a/device/noise-protocol.go b/device/noise-protocol.go index be92b4b..e8f6145 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -20,7 +20,6 @@ import ( type handshakeState int -// TODO(crawshaw): add commentary describing each state and the transitions const ( handshakeZeroed = handshakeState(iota) handshakeInitiationCreated @@ -121,7 +120,7 @@ type Handshake struct { mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key - presharedKey NoiseSymmetricKey // psk + presharedKey NoisePresharedKey // psk localEphemeral NoisePrivateKey // ephemeral secret key localIndex uint32 // used to clear hash-table remoteIndex uint32 // index for sending @@ -139,11 +138,11 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { +func mixKey(dst, c *[blake2s.Size]byte, data []byte) { KDF1(dst, c[:], data) } -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { +func mixHash(dst, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) @@ -176,8 +175,6 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - var errZeroECDHResult = errors.New("ECDH returned all zeros") - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -205,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - if isZero(ss[:]) { - return nil, errZeroECDHResult + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err } var key [chacha20poly1305.KeySize]byte KDF2( @@ -222,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e // encrypt timestamp if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errZeroECDHResult + return nil, errInvalidPublicKey } KDF2( &handshake.chainKey, @@ -265,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - var err error var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - if isZero(ss[:]) { + ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { return nil } KDF2(&chainKey, &key, chainKey[:], ss[:]) @@ -283,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer peer := device.LookupPeer(peerPK) - if peer == nil { + if peer == nil || !peer.isRunning.Load() { return nil } @@ -319,11 +315,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() if replay { - device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp) + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) return nil } if flood { - device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer) + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) return nil } @@ -385,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) + ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) // add preshared key @@ -407,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(tau[:]) - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) handshake.state = handshakeResponseCreated @@ -437,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) ok := func() bool { - // lock handshake state handshake.mutex.RLock() @@ -457,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) - func() { - ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) // add preshared key (psk) @@ -485,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } @@ -566,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(recvKey[:]) keypair.created = time.Now() - keypair.sendNonce = 0 - keypair.replayFilter.Init() + keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex keypair.remoteIndex = peer.handshake.remoteIndex @@ -584,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.loadNext() + next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { - keypairs.storeNext(nil) + keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -598,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.storeNext(keypair) + keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -610,18 +608,18 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.loadNext() != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.loadNext() != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.loadNext() - keypairs.storeNext(nil) + keypairs.current = keypairs.next.Load() + keypairs.next.Store(nil) return true } diff --git a/device/noise-types.go b/device/noise-types.go index f793ef5..e850359 100644 --- a/device/noise-types.go +++ b/device/noise-types.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,19 +9,18 @@ import ( "crypto/subtle" "encoding/hex" "errors" - - "golang.org/x/crypto/chacha20poly1305" ) const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 + NoisePresharedKeySize = 32 ) type ( NoisePublicKey [NoisePublicKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoisePresharedKey [NoisePresharedKeySize]byte NoiseNonce uint64 // padded to 12-bytes ) @@ -61,18 +60,10 @@ func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { return } -func (key NoisePrivateKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - func (key *NoisePublicKey) FromHex(src string) error { return loadExactHex(key[:], src) } -func (key NoisePublicKey) ToHex() string { - return hex.EncodeToString(key[:]) -} - func (key NoisePublicKey) IsZero() bool { var zero NoisePublicKey return key.Equals(zero) @@ -82,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 } -func (key *NoiseSymmetricKey) FromHex(src string) error { +func (key *NoisePresharedKey) FromHex(src string) error { return loadExactHex(key[:], src) } - -func (key NoiseSymmetricKey) ToHex() string { - return hex.EncodeToString(key[:]) -} diff --git a/device/noise_test.go b/device/noise_test.go index ce89851..2dd5324 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,6 +9,9 @@ import ( "bytes" "encoding/binary" "testing" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun/tuntest" ) func TestCurveWrappers(t *testing.T) { @@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) { pk1 := sk1.publicKey() pk2 := sk2.publicKey() - ss1 := sk1.sharedSecret(pk2) - ss2 := sk2.sharedSecret(pk1) + ss1, err1 := sk1.sharedSecret(pk2) + ss2, err2 := sk2.sharedSecret(pk1) - if ss1 != ss2 { + if ss1 != ss2 || err1 != nil || err2 != nil { t.Fatal("Failed to compute shared secet") } } +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun := tuntest.NewChannelTUN() + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) + device.SetPrivateKey(sk) + return device +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a, b []byte) { + if !bytes.Equal(a, b) { + t.Fatal(a, "!=", b) + } +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) - peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + if err != nil { + t.Fatal(err) + } + peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + if err != nil { + t.Fatal(err) + } + peer1.Start() + peer2.Start() assertEqual( t, @@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.loadNext() + key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index d13acd9..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,14 +1,13 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "encoding/base64" + "container/list" "errors" - "fmt" "sync" "sync/atomic" "time" @@ -16,28 +15,21 @@ import ( "golang.zx2c4.com/wireguard/conn" ) -const ( - PeerRoutineNumber = 3 -) - type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint conn.Endpoint - persistentKeepaliveInterval uint16 - - // These fields are accessed with atomic operations, which must be - // 64-bit aligned even on 32-bit platforms. Go guarantees that an - // allocated struct will be 64-bit aligned. So we place - // atomically-accessed fields up front, so that they can share in - // this alignment before smaller fields throw it off. - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch + isRunning atomic.Bool + keypairs Keypairs + handshake Handshake + device *Device + stopping sync.WaitGroup // routines pending stop + txBytes atomic.Uint64 // bytes send to peer (endpoint) + rxBytes atomic.Uint64 // bytes received from peer + lastHandshakeNano atomic.Int64 // nano seconds since epoch + + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool } timers struct { @@ -46,40 +38,32 @@ type Peer struct { newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool + handshakeAttempts atomic.Uint32 + needAnotherKeepalive atomic.Bool + sentLastMinuteHandshake atomic.Bool } - signals struct { - newKeypairArrived chan struct{} - flushNonceQueue chan struct{} + state struct { + sync.Mutex // protects against concurrent Start/Stop } queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - packetInNonceQueueIsAwaitingKey AtomicBool + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } - routines struct { - sync.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer - } - - cookieGenerator CookieGenerator + cookieGenerator CookieGenerator + trieEntries list.List + persistentKeepaliveInterval atomic.Uint32 } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - if device.isClosed.Get() { + if device.isClosed() { return nil, errors.New("device closed") } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -87,131 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device - peer.isRunning.Set(false) + peer.queue.outbound = newAutodrainingOutboundQueue(device) + peer.queue.inbound = newAutodrainingInboundQueue(device) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() // reset endpoint + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() - peer.endpoint = nil + // init timers + peer.timersInit() // add - device.peers.keyMap[pk] = peer - // start peer - - if peer.device.isUp.Get() { - peer.Start() - } - return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() - if peer.device.net.bind == nil { - return errors.New("no bind") + if peer.device.isClosed() { + return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { - atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } func (peer *Peer) String() string { - base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + // The awful goo that follows is identical to: + // + // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] + // return fmt.Sprintf("peer(%s)", abbreviatedKey) + // + // except that it is considerably more efficient. + src := peer.handshake.remoteStatic + b64 := func(input byte) byte { + return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) } - return fmt.Sprintf("peer(%s)", abbreviatedKey) + b := []byte("peer(____…____)") + const first = len("peer(") + const second = len("peer(____…") + b[first+0] = b64((src[0] >> 2) & 63) + b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) + b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) + b[first+3] = b64(src[2] & 63) + b[second+0] = b64(src[29] & 63) + b[second+1] = b64((src[30] >> 2) & 63) + b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) + b[second+3] = b64((src[31] << 2) & 63) + return string(b) } func (peer *Peer) Start() { - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return } // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() - peer.routines.Lock() - defer peer.routines.Unlock() - - if peer.isRunning.Get() { + if peer.isRunning.Load() { return } device := peer.device - device.log.Debug.Println(peer, "- Starting...") + device.log.Verbosef("%v - Starting", peer) // reset routine state + peer.stopping.Wait() + peer.stopping.Add(2) - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.handshake.mutex.Unlock() - // prepare queues + peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.timersStart() - peer.timersInit() - peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - peer.signals.newKeypairArrived = make(chan struct{}, 1) - peer.signals.flushNonceQueue = make(chan struct{}, 1) + device.flushInboundQueue(peer.queue.inbound) + device.flushOutboundQueue(peer.queue.outbound) - // wait for routines to start + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) - go peer.RoutineNonce() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - peer.routines.starting.Wait() - peer.isRunning.Set(true) + peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { @@ -223,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.loadNext()) + device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil - keypairs.storeNext(nil) + keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state @@ -237,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() { handshake.Clear() handshake.mutex.Unlock() - peer.FlushNonceQueue() + peer.FlushStagedPackets() } func (peer *Peer) ExpireCurrentKeypairs() { @@ -245,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() { handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() - handshake.mutex.Unlock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - keypairs.current.sendNonce = RejectAfterMessages + keypairs.current.sendNonce.Store(RejectAfterMessages) } - if keypairs.next != nil { - keypairs.loadNext().sendNonce = RejectAfterMessages + if next := keypairs.next.Load(); next != nil { + next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.starting.Wait() - - peer.routines.Lock() - defer peer.routines.Unlock() - - peer.device.log.Debug.Println(peer, "- Stopping...") + peer.device.log.Verbosef("%v - Stopping", peer) peer.timersStop() - - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) + // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. + peer.queue.inbound.c <- nil + peer.queue.outbound.c <- nil + peer.stopping.Wait() + peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.ZeroAndFlushAll() } -var RoamingDisabled bool - func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if RoamingDisabled { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/peer_test.go b/device/peer_test.go deleted file mode 100644 index 6aa238b..0000000 --- a/device/peer_test.go +++ /dev/null @@ -1,43 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "reflect" - "testing" - "unsafe" -) - -func checkAlignment(t *testing.T, name string, offset uintptr) { - t.Helper() - if offset%8 != 0 { - t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) - } -} - -// TestPeerAlignment checks that atomically-accessed fields are -// aligned to 64-bit boundaries, as required by the atomic package. -// -// Unfortunately, violating this rule on 32-bit platforms results in a -// hard segfault at runtime. -func TestPeerAlignment(t *testing.T) { - var p Peer - - typ := reflect.TypeOf(p) - t.Logf("Peer type size: %d, with fields:", typ.Size()) - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) - t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", - field.Name, - field.Offset, - field.Type.Size(), - field.Type.Align(), - ) - } - - checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) - checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) -} diff --git a/device/pools.go b/device/pools.go index e778d2e..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,89 +1,120 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -import "sync" +import ( + "sync" + "sync/atomic" +) -func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count atomic.Uint32 + max uint32 +} + +func NewWaitPool(max uint32, new func() any) *WaitPool { + p := &WaitPool{pool: sync.Pool{New: new}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +func (p *WaitPool) Get() any { + if p.max != 0 { + p.lock.Lock() + for p.count.Load() >= p.max { + p.cond.Wait() } + p.count.Add(1) + p.lock.Unlock() } + return p.pool.Get() } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan +func (p *WaitPool) Put(x any) { + p.pool.Put(x) + if p.max == 0 { + return } + p.count.Add(^uint32(0)) + p.cond.Signal() } -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } +func (device *Device) PopulatePools() { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} + }) + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new([MaxMessageSize]byte) + }) + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueInboundElement) + }) + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueOutboundElement) + }) } -func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) +} + +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElement(msg *QueueInboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(msg) - } else { - device.pool.inboundElementReuseChan <- msg +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + return device.pool.inboundElements.Get().(*QueueInboundElement) +} + +func (device *Device) PutInboundElement(elem *QueueInboundElement) { + elem.clearPointers() + device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } + return device.pool.outboundElements.Get().(*QueueOutboundElement) } -func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(msg) - } else { - device.pool.outboundElementReuseChan <- msg - } +func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + elem.clearPointers() + device.pool.outboundElements.Put(elem) } diff --git a/device/pools_test.go b/device/pools_test.go new file mode 100644 index 0000000..82d7493 --- /dev/null +++ b/device/pools_test.go @@ -0,0 +1,139 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestWaitPool(t *testing.T) { + t.Skip("Currently disabled") + var wg sync.WaitGroup + var trials atomic.Int32 + startTrials := int32(100000) + if raceEnabled { + // This test can be very slow with -race. + startTrials /= 10 + } + trials.Store(startTrials) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + t.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + var max atomic.Uint32 + updateMax := func() { + count := p.count.Load() + if count > p.max { + t.Errorf("count (%d) > max (%d)", count, p.max) + } + for { + old := max.Load() + if count <= old { + break + } + if max.CompareAndSwap(old, count) { + break + } + } + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + updateMax() + x := p.Get() + updateMax() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + updateMax() + p.Put(x) + updateMax() + } + }() + } + wg.Wait() + if max.Load() != p.max { + t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) + } +} + +func BenchmarkWaitPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkWaitPoolEmpty(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(0, func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkSyncPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := sync.Pool{New: func() any { return make([]byte, 16) }} + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index f19c7be..25f700a 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -1,16 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + /* Reduce memory consumption for Android */ const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index 18f0bea..ea763d0 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -1,13 +1,16 @@ -// +build !android,!ios +//go:build !android && !ios && !windows /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index 4c83015..acd3cec 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -1,18 +1,21 @@ -// +build ios +//go:build ios /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 1700 - PreallocatedBuffersPerPool = 1024 +// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. +// These are vars instead of consts, because heavier network extensions might want to reduce +// them further. +var ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + PreallocatedBuffersPerPool uint32 = 1024 ) + +const MaxSegmentSize = 1700 diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go new file mode 100644 index 0000000..1eee32b --- /dev/null +++ b/device/queueconstants_windows.go @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = 2048 - 32 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go new file mode 100644 index 0000000..bb5c450 --- /dev/null +++ b/device/race_disabled_test.go @@ -0,0 +1,10 @@ +//go:build !race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = false diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go new file mode 100644 index 0000000..4e9daea --- /dev/null +++ b/device/race_enabled_test.go @@ -0,0 +1,10 @@ +//go:build race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = true diff --git a/device/receive.go b/device/receive.go index b53c9c0..1ab3e29 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,10 +8,9 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" - "strconv" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" @@ -28,8 +27,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -37,38 +34,20 @@ type QueueInboundElement struct { endpoint conn.Endpoint } -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { - select { - case inboundQueue <- element: - select { - case decryptionQueue <- element: - return true - default: - element.Drop() - element.Unlock() - return false - } - default: - device.PutInboundElement(element) - return false - } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement } -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { - select { - case queue <- element: - return true - default: - return false - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueInboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.endpoint = nil } /* Called when a new authenticated message has been received @@ -76,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { + if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) + peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } @@ -91,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { - - logDebug := device.log.Debug +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { + recvName := recv.PrettyName() defer func() { - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.net.stopping.Done() }() - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") - device.net.starting.Done() + device.log.Verbosef("Routine: receive incoming %s - started", recvName) // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( - err error - size int - endpoint conn.Endpoint + bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + bufs = make([][]byte, maxBatchSize) + err error + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) + deathSpiral int + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for { - - // read next datagram + for i := range bufsArrs { + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + } - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") + defer func() { + for i := 0; i < maxBatchSize; i++ { + if bufsArrs[i] != nil { + device.PutMessageBuffer(bufsArrs[i]) + } } + }() + for { + count, err = recv(bufs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) + if errors.Is(err, net.ErrClosed) { + return + } + device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) + if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { + return + } + if deathSpiral < 10 { + deathSpiral++ + time.Sleep(time.Second / 3) + continue + } return } + deathSpiral = 0 - if size < MinMessageSize { - continue - } - - // check size of packet + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // check size of packet - var okay bool + packet := bufsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - switch msgType { + switch msgType { - // check if transport + // check if transport - case MessageTransportType: + case MessageTransportType: - // check size + // check size - if len(packet) < MessageTransportSize { - continue - } + if len(packet) < MessageTransportSize { + continue + } - // lookup key pair + // lookup key pair - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - // check keypair expiry + // check keypair expiry - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { - buffer = device.GetMessageBuffer() + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = bufsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() + elemsByPeer[peer] = elemsForPeer } - } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + continue - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + default: + device.log.Verbosef("Received message with unknown type") + continue + } - default: - logDebug.Println("Received message with unknown type") + select { + case device.queue.handshake.c <- QueueHandshakeElement{ + msgType: msgType, + buffer: bufsArrs[i], + packet: packet, + endpoint: endpoints[i], + }: + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + default: + } } - - if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { - buffer = device.GetMessageBuffer() + for peer, elemsContainer := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) } + delete(elemsByPeer, peer) } } } -func (device *Device) RoutineDecryption() { - +func (device *Device) RoutineDecryption(id int) { var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: decryption worker - stopped") - device.state.stopping.Done() - }() - logDebug.Println("Routine: decryption worker - started") - device.state.starting.Done() - - for { - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) + device.log.Verbosef("Routine: decryption worker %d - started", id) + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - // decrypt and release to consumer - var err error elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], @@ -280,51 +260,23 @@ func (device *Device) RoutineDecryption() { nil, ) if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) + elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } /* Handles incoming packets related to handshake */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem QueueHandshakeElement - var ok bool - +func (device *Device) RoutineHandshake(id int) { defer func() { - logDebug.Println("Routine: handshake worker - stopped") - device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } + device.log.Verbosef("Routine: handshake worker %d - stopped", id) + device.queue.encryption.wg.Done() }() + device.log.Verbosef("Routine: handshake worker %d - started", id) - logDebug.Println("Routine: handshake worker - started") - device.state.starting.Done() - - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } + for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -338,8 +290,8 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { - logDebug.Println("Failed to decode cookie reply") - return + device.log.Verbosef("Failed to decode cookie reply") + goto skip } // lookup peer from index @@ -347,27 +299,27 @@ func (device *Device) RoutineHandshake() { entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { - continue + goto skip } // consume reply - if peer := entry.peer; peer.isRunning.Get() { - logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if peer := entry.peer; peer.isRunning.Load() { + device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { - logDebug.Println("Could not decrypt invalid cookie response") + device.log.Verbosef("Could not decrypt invalid cookie response") } } - continue + goto skip case MessageInitiationType, MessageResponseType: // check mac fields and maybe ratelimit if !device.cookieChecker.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue + device.log.Verbosef("Received packet with invalid mac1") + goto skip } // endpoints destination address is the source of the datagram @@ -378,19 +330,19 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) - continue + goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue + goto skip } } default: - logError.Println("Invalid packet ended up in the handshake queue") - continue + device.log.Errorf("Invalid packet ended up in the handshake queue") + goto skip } // handle handshake initiation/response content @@ -404,19 +356,16 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode initiation message") - continue + device.log.Errorf("Failed to decode initiation message") + goto skip } // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) + goto skip } // update timers @@ -427,8 +376,8 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake initiation") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake initiation", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() @@ -440,26 +389,23 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode response message") - continue + device.log.Errorf("Failed to decode response message") + goto skip } // consume response peer := device.ConsumeMessageResponse(&msg) if peer == nil { - logInfo.Println( - "Received invalid response message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) + goto skip } // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake response") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake response", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) // update timers @@ -471,178 +417,124 @@ func (device *Device) RoutineHandshake() { err = peer.BeginSymmetricSession() if err != nil { - logError.Println(peer, "- Failed to derive keypair:", err) - continue + device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) + goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } } + skip: + device.PutMessageBuffer(elem.buffer) } } -func (peer *Peer) RoutineSequentialReceiver() { - +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem *QueueInboundElement - defer func() { - logDebug.Println(peer, "- Routine: sequential receiver - stopped") - peer.routines.stopping.Done() - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - } + device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - logDebug.Println(peer, "- Routine: sequential receiver - started") - - peer.routines.starting.Done() - - for { - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - elem = nil - } + bufs := make([][]byte, 0, maxBatchSize) - var elemOk bool - select { - case <-peer.routines.stop: + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption - - elem.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - - // check for keepalive - - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue } - peer.timersDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { + if elem.packet == nil { + // decryption failed continue } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue } - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue + validTailPacket = i + if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) + peer.timersHandshakeComplete() + peer.SendStagedPackets() } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } + dataPacketReceived = true - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - // verify IPv6 source + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer, - ) + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) continue } - default: - logInfo.Println("Packet with invalid IP version from", peer) - continue + bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if len(peer.queue.inbound) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Error.Printf("Unable to flush packets: %v", err) + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } + if len(bufs) > 0 { + _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } + bufs = bufs[:0] + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index c0bdba3..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,14 +8,17 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" + "os" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -43,8 +46,6 @@ import ( */ type QueueOutboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -52,80 +53,52 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() - elem.dropped = AtomicFalse elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 - elem.keypair = nil - elem.peer = nil + // keypair and peer were cleared (if necessary) by clearPointers. return elem } -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - device.PutMessageBuffer(old.buffer) - device.PutOutboundElement(old) - default: - } - } - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueOutboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.peer = nil } -func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { - select { - case outboundQueue <- element: +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() { + if len(peer.queue.staged) == 0 && peer.isRunning.Load() { + elem := peer.device.NewOutboundElement() + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case encryptionQueue <- element: - return + case peer.queue.staged <- elemsContainer: + peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - element.Drop() - element.peer.device.PutMessageBuffer(element.buffer) - element.Unlock() + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsContainer(elemsContainer) } - default: - element.peer.device.PutMessageBuffer(element.buffer) - element.peer.device.PutOutboundElement(element) - } -} - -/* Queues a keepalive if no packets are queued for peer - */ -func (peer *Peer) SendKeepalive() bool { - if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - peer.device.log.Debug.Println(peer, "- Sending keepalive packet") - return true - default: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) - return false } + peer.SendStagedPackets() } func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() @@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake initiation") + peer.device.log.Verbosef("%v - Sending handshake initiation", peer) msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) + peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) return err } - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageInitiationSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) @@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) + peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } peer.timersHandshakeInitiated() @@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake response") + peer.device.log.Verbosef("%v - Sending handshake response", peer) response, err := peer.device.CreateMessageResponse(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create response message:", err) + peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var buff [MessageResponseSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageResponseSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, response) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { - peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) + peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -198,28 +171,29 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) + peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } return err } func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { - - device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) + device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) if err != nil { - device.log.Error.Println("Failed to create cookie reply:", err) + device.log.Errorf("Failed to create cookie reply: %v", err) return err } - var buff [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -228,222 +202,221 @@ func (peer *Peer) keepKeyFreshSending() { if keypair == nil { return } - nonce := atomic.LoadUint64(&keypair.sendNonce) + nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - logDebug.Println("Routine: TUN reader - stopped") + device.log.Verbosef("Routine: TUN reader - stopped") device.state.stopping.Done() + device.queue.encryption.wg.Done() }() - logDebug.Println("Routine: TUN reader - started") - device.state.starting.Done() - - var elem *QueueOutboundElement + device.log.Verbosef("Routine: TUN reader - started") + + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + bufs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) + }() - if err != nil { - if !device.isClosed.Get() { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(bufs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = bufs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.LookupIPv4(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsContainer() + elemsByPeer[peer] = elemsForPeer + } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range elemsForPeer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsForPeer) + } + delete(elemsByPeer, peer) } - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { - peer.SendHandshakeInitiation(false) + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue + } + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() } - addToNonceQueue(peer.queue.nonce, elem, device) - elem = nil + return } } } -func (peer *Peer) FlushNonceQueue() { - select { - case peer.signals.flushNonceQueue <- struct{}{}: - default: - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keypair *Keypair - - device := peer.device - logDebug := device.log.Debug - - flush := func() { - for { - select { - case elem := <-peer.queue.nonce: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - default: - return +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { + for { + select { + case peer.queue.staged <- elems: + return + default: + } + select { + case tooOld := <-peer.queue.staged: + for _, elem := range tooOld.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(tooOld) + default: } } +} - defer func() { - flush() - logDebug.Println(peer, "- Routine: nonce worker - stopped") - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - peer.routines.stopping.Done() - }() +func (peer *Peer) SendStagedPackets() { +top: + if len(peer.queue.staged) == 0 || !peer.device.isUp() { + return + } - peer.routines.starting.Done() - logDebug.Println(peer, "- Routine: nonce worker - started") + keypair := peer.keypairs.Current() + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + peer.SendHandshakeInitiation(false) + return + } for { - NextPacket: - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - + var elemsContainerOOO *QueueOutboundElementsContainer select { - case <-peer.routines.stop: - return - - case <-peer.signals.flushNonceQueue: - flush() - goto NextPacket - - case elem, ok := <-peer.queue.nonce: - - if !ok { - return - } - - // make sure to always pick the newest key - - for { - - // check validity of newest key pair - - keypair = peer.keypairs.Current() - if keypair != nil && keypair.sendNonce < RejectAfterMessages { - if time.Since(keypair.created) < RejectAfterTime { - break + case elemsContainer := <-peer.queue.staged: + i := 0 + for _, elem := range elemsContainer.elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) + continue + } else { + elemsContainer.elems[i] = elem + i++ } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) - // no suitable key pair, request for new handshake - - select { - case <-peer.signals.newKeypairArrived: - default: - } - - peer.SendHandshakeInitiation(false) - - // wait for key to be established - - logDebug.Println(peer, "- Awaiting keypair") + elem.keypair = keypair + } + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - select { - case <-peer.signals.newKeypairArrived: - logDebug.Println(peer, "- Obtained awaited keypair") + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans + } - case <-peer.signals.flushNonceQueue: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - flush() - goto NextPacket + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) + goto top + } - case <-peer.routines.stop: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return + // add to parallel and sequential queue + if peer.isRunning.Load() { + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(elemsContainer) } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - - // populate work element - elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 - - // double check in case of race condition added by future code - - if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - goto NextPacket + if elemsContainerOOO != nil { + goto top } + default: + return + } + } +} - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.Lock() - - // add to parallel and sequential queue - addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) +func (peer *Peer) FlushStagedPackets() { + for { + select { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(elemsContainer) + default: + return } } } @@ -468,55 +441,16 @@ func calculatePaddingSize(packetSize, mtu int) int { * * Obs. One instance per core */ -func (device *Device) RoutineEncryption() { - +func (device *Device) RoutineEncryption(id int) { + var paddingZeros [PaddingMultiple]byte var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - - defer func() { - for { - select { - case elem, ok := <-device.queue.encryption: - if ok && !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - elem.Unlock() - } - default: - goto out - } - } - out: - logDebug.Println("Routine: encryption worker - stopped") - device.state.stopping.Done() - }() - - logDebug.Println("Routine: encryption worker - started") - device.state.starting.Done() - - for { - - // fetch next element - - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.encryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) + device.log.Verbosef("Routine: encryption worker %d - started", id) + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] @@ -528,11 +462,8 @@ func (device *Device) RoutineEncryption() { binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 - - paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) - for i := 0; i < paddingSize; i++ { - elem.packet = append(elem.packet, 0) - } + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer @@ -543,82 +474,73 @@ func (device *Device) RoutineEncryption() { elem.packet, nil, ) - elem.Unlock() } + elemsContainer.Unlock() } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - for { - select { - case elem, ok := <-peer.queue.outbound: - if ok { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - elem.Drop() - } - device.PutOutboundElement(elem) - } - default: - goto out - } - } - out: - logDebug.Println(peer, "- Routine: sequential sender - stopped") - peer.routines.stopping.Done() + defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential sender - started", peer) - logDebug.Println(peer, "- Routine: sequential sender - started") - - peer.routines.starting.Done() - - for { - select { + bufs := make([][]byte, 0, maxBatchSize) - case <-peer.routines.stop: + for elemsContainer := range peer.queue.outbound.c { + bufs = bufs[:0] + if elemsContainer == nil { return - - case elem, ok := <-peer.queue.outbound: - - if !ok { - return - } - - elem.Lock() - if elem.IsDropped() { + } + if !peer.isRunning.Load() { + // peer has been stopped; return re-usable elems to the shared pool. + // This is an optimization only. It is possible for the peer to be stopped + // immediately after this check, in which case, elem will get processed. + // The timers and SendBuffers code are resilient to a few stragglers. + // TODO: rework peer shutdown order to ensure + // that we never accidentally keep timers alive longer than necessary. + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - continue } - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) + continue + } + dataSent := false + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { - peer.timersDataSent() + dataSent = true } + bufs = append(bufs, elem.packet) + } + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err := peer.SendBuffers(bufs) + if dataSent { + peer.timersDataSent() + } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - if err != nil { - logError.Println(peer, "- Failed to send data packet", err) - continue + } + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr } - - peer.keepKeyFreshSending() } + if err != nil { + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) + continue + } + + peer.keepKeyFreshSending() } } diff --git a/device/sticky_default.go b/device/sticky_default.go index 56da4eb..1038256 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -1,4 +1,4 @@ -// +build !linux android +//go:build !linux package device diff --git a/device/sticky_linux.go b/device/sticky_linux.go index e3efc86..6057ff1 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,8 +1,6 @@ -// +build !android - /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port @@ -21,11 +19,19 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + if !conn.StdNetSupportsStickySockets { + return nil, nil + } + if _, ok := bind.(*conn.StdNetBind); !ok { + return nil, nil + } + netlinkSock, err := createNetlinkRouteSocket() if err != nil { return nil, err @@ -49,6 +55,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl var reqPeer map[uint32]peerEndpointPtr var reqPeerLock sync.Mutex + defer netlinkCancel.Close() defer unix.Close(netlinkSock) for msg := make([]byte, 1<<16); ; { @@ -103,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -127,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } - if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { - peer.RUnlock() + if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { + peer.endpoint.Unlock() break } nlmsg := struct { @@ -165,26 +172,26 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl Len: 8, Type: unix.RTA_DST, }, - nativeEP.Dst4().Addr, + nativeEP.DstIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_SRC, }, - nativeEP.Src4().Src, + nativeEP.SrcIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_MARK, }, - uint32(bind.LastMark()), + device.net.fwmark, } nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { @@ -200,13 +207,13 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl } func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + Groups: unix.RTMGRP_IPV4_ROUTE, } err = unix.Bind(sock, saddr) if err != nil { diff --git a/device/timers.go b/device/timers.go index 0232eef..d4a4ed4 100644 --- a/device/timers.go +++ b/device/timers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ @@ -8,16 +8,16 @@ package device import ( - "math/rand" "sync" - "sync/atomic" "time" + _ "unsafe" ) -/* This Timer structure and related functions should roughly copy the interface of - * the Linux kernel's struct timer_list. - */ +//go:linkname fastrandn runtime.fastrandn +func fastrandn(n uint32) uint32 +// A Timer manages time-based aspects of the WireGuard protocol. +// Timer roughly copies the interface of the Linux kernel's struct timer_list. type Timer struct { *time.Timer modifyingLock sync.RWMutex @@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { timer := &Timer{} timer.Timer = time.AfterFunc(time.Hour, func() { timer.runningLock.Lock() + defer timer.runningLock.Unlock() timer.modifyingLock.Lock() if !timer.isPending { timer.modifyingLock.Unlock() - timer.runningLock.Unlock() return } timer.isPending = false timer.modifyingLock.Unlock() expirationFunction(peer) - timer.runningLock.Unlock() }) timer.Stop() return timer @@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool { } func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 + return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() } func expiredRetransmitHandshake(peer *Peer) { - if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { + peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) if peer.timersActive() { peer.timers.sendKeepalive.Del() @@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) { /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ - peer.FlushNonceQueue() + peer.FlushStagedPackets() /* We set a timer for destroying any residue that might be left * of a partial exchange. @@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } else { - atomic.AddUint32(&peer.timers.handshakeAttempts, 1) - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + peer.timers.handshakeAttempts.Add(1) + peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) { func expiredSendKeepalive(peer *Peer) { peer.SendKeepalive() - if peer.timers.needAnotherKeepalive.Get() { - peer.timers.needAnotherKeepalive.Set(false) + if peer.timers.needAnotherKeepalive.Load() { + peer.timers.needAnotherKeepalive.Store(false) if peer.timersActive() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } @@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) { } func expiredNewHandshake(peer *Peer) { - peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) - } func expiredZeroKeyMaterial(peer *Peer) { - peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.ZeroAndFlushAll() } func expiredPersistentKeepalive(peer *Peer) { - if peer.persistentKeepaliveInterval > 0 { + if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } @@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) { /* Should be called after an authenticated data packet is sent. */ func (peer *Peer) timersDataSent() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() { - peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() { if !peer.timers.sendKeepalive.IsPending() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } else { - peer.timers.needAnotherKeepalive.Set(true) + peer.timers.needAnotherKeepalive.Store(true) } } } @@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() { /* Should be called after a handshake initiation message is sent. */ func (peer *Peer) timersHandshakeInitiated() { if peer.timersActive() { - peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() { if peer.timersActive() { peer.timers.retransmitHandshake.Del() } - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.lastHandshakeNano.Store(time.Now().UnixNano()) } /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ @@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() { /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { - peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + keepalive := peer.persistentKeepaliveInterval.Load() + if keepalive > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } } @@ -213,9 +204,12 @@ func (peer *Peer) timersInit() { peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStart() { + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.timers.needAnotherKeepalive.Store(false) } func (peer *Peer) timersStop() { diff --git a/device/tun.go b/device/tun.go index 1f88f33..2a2ace9 100644 --- a/device/tun.go +++ b/device/tun.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "sync/atomic" + "fmt" "golang.zx2c4.com/wireguard/tun" ) @@ -14,43 +14,40 @@ import ( const DefaultMTU = 1420 func (device *Device) RoutineTUNEventReader() { - setUp := false - logDebug := device.log.Debug - logInfo := device.log.Info - logError := device.log.Error - - logDebug.Println("Routine: event worker - started") - device.state.starting.Done() + device.log.Verbosef("Routine: event worker - started") for event := range device.tun.device.Events() { if event&tun.EventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + device.log.Errorf("Failed to load updated MTU of device: %v", err) + continue + } + if mtu < 0 { + device.log.Errorf("MTU not updated to negative value: %v", mtu) + continue + } + var tooLarge string + if mtu > MaxContentSize { + tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) + mtu = MaxContentSize + } + old := device.tun.mtu.Swap(int32(mtu)) + if int(old) != mtu { + device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) } } - if event&tun.EventUp != 0 && !setUp { - logInfo.Println("Interface set up") - setUp = true + if event&tun.EventUp != 0 { + device.log.Verbosef("Interface up requested") device.Up() } - if event&tun.EventDown != 0 && setUp { - logInfo.Println("Interface set down") - setUp = false + if event&tun.EventDown != 0 { + device.log.Verbosef("Interface down requested") device.Down() } } - logDebug.Println("Routine: event worker - stopped") - device.state.stopping.Done() + device.log.Verbosef("Routine: event worker - stopped") } diff --git a/device/tun_test.go b/device/tun_test.go deleted file mode 100644 index a2db2a5..0000000 --- a/device/tun_test.go +++ /dev/null @@ -1,56 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "os" - - "golang.zx2c4.com/wireguard/tun" -) - -// newDummyTUN creates a dummy TUN device with the specified name. -func newDummyTUN(name string) tun.Device { - return &dummyTUN{ - name: name, - packets: make(chan []byte, 100), - events: make(chan tun.Event, 10), - } -} - -// A dummyTUN is a tun.Device which is used in unit tests. -type dummyTUN struct { - name string - mtu int - packets chan []byte - events chan tun.Event -} - -func (d *dummyTUN) Events() chan tun.Event { return d.events } -func (*dummyTUN) File() *os.File { return nil } -func (*dummyTUN) Flush() error { return nil } -func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil } -func (d *dummyTUN) Name() (string, error) { return d.name, nil } - -func (d *dummyTUN) Close() error { - close(d.events) - close(d.packets) - return nil -} - -func (d *dummyTUN) Read(b []byte, offset int) (int, error) { - buf, ok := <-d.packets - if !ok { - return 0, errors.New("device closed") - } - copy(b[offset:], buf) - return len(buf), nil -} - -func (d *dummyTUN) Write(b []byte, offset int) (int, error) { - d.packets <- b[offset:] - return len(b), nil -} diff --git a/device/uapi.go b/device/uapi.go index 9f9c9bd..d81dae3 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,45 +1,77 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bufio" + "bytes" "errors" "fmt" "io" "net" + "net/netip" "strconv" "strings" - "sync/atomic" + "sync" "time" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) type IPCError struct { - int64 + code int64 // error code + err error // underlying/wrapped error } func (s IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) + return fmt.Sprintf("IPC error %d: %v", s.code, s.err) +} + +func (s IPCError) Unwrap() error { + return s.err } func (s IPCError) ErrorCode() int64 { - return s.int64 + return s.code +} + +func ipcErrorf(code int64, msg string, args ...any) *IPCError { + return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } -func (device *Device) IpcGetOperation(socket *bufio.Writer) error { - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) +var byteBufferPool = &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcGetOperation(w io.Writer) error { + device.ipcMutex.RLock() + defer device.ipcMutex.RUnlock() + + buf := byteBufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer byteBufferPool.Put(buf) + sendf := func(format string, args ...any) { + fmt.Fprintf(buf, format, args...) + buf.WriteByte('\n') + } + keyf := func(prefix string, key *[32]byte) { + buf.Grow(len(key)*2 + 2 + len(prefix)) + buf.WriteString(prefix) + buf.WriteByte('=') + const hex = "0123456789abcdef" + for i := 0; i < len(key); i++ { + buf.WriteByte(hex[key[i]>>4]) + buf.WriteByte(hex[key[i]&0xf]) + } + buf.WriteByte('\n') } func() { - // lock required resources device.net.RLock() @@ -54,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.ToHex()) + keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) } if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) + sendf("listen_port=%d", device.net.port) } if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + sendf("fwmark=%d", device.net.fwmark) } - // serialize each peer state - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.ToHex()) - send("preshared_key=" + peer.handshake.presharedKey.ToHex()) - send("protocol_version=1") - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) + // Serialize peer state. + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) } + peer.endpoint.Unlock() - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) - send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - - for _, ip := range device.allowedips.EntriesForPeer(peer) { - send("allowed_ip=" + ip.String()) - } + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ipc.IpcErrorIO} - } + if _, err := w.Write(buf.Bytes()); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } return nil } -func (device *Device) IpcSetOperation(socket *bufio.Reader) error { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcSetOperation(r io.Reader) (err error) { + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - var peer *Peer + defer func() { + if err != nil { + device.log.Errorf("%v", err) + } + }() - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true + scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. + peer.handlePostConfig() return nil } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{ipc.IpcErrorProtocol} + key, value, ok := strings.Cut(line, "=") + if !ok { + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromMaybeZeroHex(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipc.IpcErrorInvalid} } + peer.handlePostConfig() + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err + } + } + peer.handlePostConfig() - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - if peer == nil { - dummy = true - peer = &Peer{} - } else { - logDebug.Println(peer, "- UAPI: Created") - } - } + if err := scanner.Err(); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) + } + return nil +} - case "update_only": +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Verbosef("UAPI: Updating private key") + device.SetPrivateKey(sk) - // allow disabling of creation + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set update only, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } + // update port and rebind + device.log.Verbosef("UAPI: Updating listen port") - case "remove": + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() - // remove currently selected peer from device + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true + case "fwmark": + mark, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } - case "preshared_key": + device.log.Verbosef("UAPI: Updating fwmark") + if err := device.BindSetMark(uint32(mark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } - // update PSK + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Verbosef("UAPI: Removing all peers") + device.RemoveAllPeers() - logDebug.Println(peer, "- UAPI: Updating preshared key") + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() + return nil +} - if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer + pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on +} - case "endpoint": +func (peer *ipcSetPeer) handlePostConfig() { + if peer.Peer == nil || peer.dummy { + return + } + if peer.created { + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil + } + if peer.device.isUp() { + peer.Start() + if peer.pkaOn { + peer.SendKeepalive() + } + peer.SendStagedPackets() + } +} - // set endpoint destination +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } - logDebug.Println(peer, "- UAPI: Updating endpoint") + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := conn.CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } - if err != nil { - logError.Println("Failed to set endpoint:", err, ":", value) - return &IPCError{ipc.IpcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipc.IpcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Verbosef("%v - UAPI: Created", peer.Peer) + } + return nil +} - case "replace_allowed_ips": +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } - logDebug.Println(peer, "- UAPI: Removing all allowedips") + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Verbosef("%v - UAPI: Removing", peer.Peer) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true - if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "preshared_key": + device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) - if dummy { - continue - } + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() - device.allowedips.RemoveByPeer(peer) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } - case "allowed_ip": + case "endpoint": + device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) + endpoint, err := device.net.bind.ParseEndpoint(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint - logDebug.Println(peer, "- UAPI: Adding allowedip") + case "persistent_keepalive_interval": + device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipc.IpcErrorInvalid} - } + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } - if dummy { - continue - } + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) + // Send immediate keepalive if we're turning it on and before it wasn't on. + peer.pkaOn = old == 0 && secs != 0 - case "protocol_version": + case "replace_allowed_ips": + device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) - if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "allowed_ip": + device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + prefix, err := netip.ParsePrefix(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + device.allowedips.Insert(prefix, peer.Peer) - default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipc.IpcErrorInvalid} - } + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } return nil } -func (device *Device) IpcHandle(socket net.Conn) { +func (device *Device) IpcGet() (string, error) { + buf := new(strings.Builder) + if err := device.IpcGetOperation(buf); err != nil { + return "", err + } + return buf.String(), nil +} - // create buffered read/writer +func (device *Device) IpcSet(uapiConf string) error { + return device.IpcSetOperation(strings.NewReader(uapiConf)) +} +func (device *Device) IpcHandle(socket net.Conn) { defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -409,45 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) { return bufio.NewReadWriter(reader, writer) }(socket) - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError + for { + op, err := buffered.ReadString('\n') + if err != nil { + return + } - switch op { - case "set=1\n": - err = device.IpcSetOperation(buffered.Reader) - if err != nil && !errors.As(err, &status) { - // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + // handle operation + switch op { + case "set=1\n": + err = device.IpcSetOperation(buffered.Reader) + case "get=1\n": + var nextByte byte + nextByte, err = buffered.ReadByte() + if err != nil { + return + } + if nextByte != '\n' { + err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + break + } + err = device.IpcGetOperation(buffered.Writer) + default: + device.log.Errorf("invalid UAPI operation: %v", op) + return } - case "get=1\n": - err = device.IpcGetOperation(buffered.Writer) + // write status + var status *IPCError if err != nil && !errors.As(err, &status) { - // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + // shouldn't happen + status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) } - - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } - - // write status - - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") + if status != nil { + device.log.Errorf("%v", status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() } } diff --git a/device/version.go b/device/version.go deleted file mode 100644 index 0877595..0000000 --- a/device/version.go +++ /dev/null @@ -1,3 +0,0 @@ -package device - -const WireGuardGoVersion = "0.0.20200320" |