Diffstat (limited to 'device')
-rw-r--r-- device/allowedips.go | 359
-rw-r--r-- device/allowedips_rand_test.go | 120
-rw-r--r-- device/allowedips_test.go | 75
-rw-r--r-- device/bind_test.go | 15
-rw-r--r-- device/bindsocketshim.go | 36
-rw-r--r-- device/channels.go | 137
-rw-r--r-- device/constants.go | 3
-rw-r--r-- device/cookie.go | 7
-rw-r--r-- device/cookie_test.go | 7
-rw-r--r-- device/device.go | 477
-rw-r--r-- device/device_test.go | 510
-rw-r--r-- device/devicestate_string.go | 16
-rw-r--r-- device/endpoint_test.go | 40
-rw-r--r-- device/indextable.go | 7
-rw-r--r-- device/ip.go | 2
-rw-r--r-- device/kdf_test.go | 4
-rw-r--r-- device/keypair.go | 17
-rw-r--r-- device/logger.go | 67
-rw-r--r-- device/misc.go | 48
-rw-r--r-- device/mobilequirks.go | 19
-rw-r--r-- device/noise-helpers.go | 12
-rw-r--r-- device/noise-protocol.go | 96
-rw-r--r-- device/noise-types.go | 25
-rw-r--r-- device/noise_test.go | 49
-rw-r--r-- device/peer.go | 262
-rw-r--r-- device/peer_test.go | 43
-rw-r--r-- device/pools.go | 157
-rw-r--r-- device/pools_test.go | 139
-rw-r--r-- device/queueconstants_android.go | 7
-rw-r--r-- device/queueconstants_default.go | 7
-rw-r--r-- device/queueconstants_ios.go | 23
-rw-r--r-- device/queueconstants_windows.go | 15
-rw-r--r-- device/race_disabled_test.go | 10
-rw-r--r-- device/race_enabled_test.go | 10
-rw-r--r-- device/receive.go | 630
-rw-r--r-- device/send.go | 632
-rw-r--r-- device/sticky_default.go | 2
-rw-r--r-- device/sticky_linux.go | 55
-rw-r--r-- device/timers.go | 76
-rw-r--r-- device/tun.go | 49
-rw-r--r-- device/tun_test.go | 56
-rw-r--r-- device/uapi.go | 636
-rw-r--r-- device/version.go | 3
43 files changed, 2684 insertions, 2276 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
index 143bda3..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,173 +1,201 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
+ "container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
-type trieEntry struct {
- cidr uint
- child [2]*trieEntry
- bits net.IP
- peer *Peer
-
- // index of "branching" bit
-
- bit_at_byte uint
- bit_at_shift uint
-}
-
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
+type parentIndirection struct {
+ parentBit **trieEntry
+ parentBitType uint8
}
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
+type trieEntry struct {
+ peer *Peer
+ child [2]*trieEntry
+ parent parentIndirection
+ cidr uint8
+ bitAtByte uint8
+ bitAtShift uint8
+ bits []byte
+ perPeerElem *list.Element
}
-func commonBits(ip1 net.IP, ip2 net.IP) uint {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
}
-func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
- if node == nil {
- return node
- }
-
- // walk recursively
-
- node.child[0] = node.child[0].removeByPeer(p)
- node.child[1] = node.child[1].removeByPeer(p)
+func (node *trieEntry) addToPeerEntries() {
+ node.perPeerElem = node.peer.trieEntries.PushBack(node)
+}
- if node.peer != p {
- return node
+func (node *trieEntry) removeFromPeerEntries() {
+ if node.perPeerElem != nil {
+ node.peer.trieEntries.Remove(node.perPeerElem)
+ node.perPeerElem = nil
}
+}
- // remove peer & merge
+func (node *trieEntry) choose(ip []byte) byte {
+ return (ip[node.bitAtByte] >> node.bitAtShift) & 1
+}
- node.peer = nil
- if node.child[0] == nil {
- return node.child[1]
+func (node *trieEntry) maskSelf() {
+ mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ for i := 0; i < len(mask); i++ {
+ node.bits[i] &= mask[i]
}
- return node.child[0]
}
-func (node *trieEntry) choose(ip net.IP) byte {
- return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+func (node *trieEntry) zeroizePointers() {
+ // Make the garbage collector's life slightly easier
+ node.peer = nil
+ node.child[0] = nil
+ node.child[1] = nil
+ node.parent.parentBit = nil
}
-func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
-
- // at leaf
-
- if node == nil {
- return &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
+ for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
+ parent = node
+ if parent.cidr == cidr {
+ exact = true
+ return
}
+ bit := node.choose(ip)
+ node = node.child[bit]
}
+ return
+}
- // traverse deeper
-
- common := commonBits(node.bits, ip)
- if node.cidr <= cidr && common >= node.cidr {
- if node.cidr == cidr {
- node.peer = peer
- return node
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
+ if *trie.parentBit == nil {
+ node := &trieEntry{
+ peer: peer,
+ parent: trie,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
- bit := node.choose(ip)
- node.child[bit] = node.child[bit].insert(ip, cidr, peer)
- return node
+ node.maskSelf()
+ node.addToPeerEntries()
+ *trie.parentBit = node
+ return
+ }
+ node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
+ if exact {
+ node.removeFromPeerEntries()
+ node.peer = peer
+ node.addToPeerEntries()
+ return
}
-
- // split node
newNode := &trieEntry{
- bits: ip,
- peer: peer,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ peer: peer,
+ bits: ip,
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
}
+ newNode.maskSelf()
+ newNode.addToPeerEntries()
- cidr = min(cidr, common)
-
- // check for shorter prefix
+ var down *trieEntry
+ if node == nil {
+ down = *trie.parentBit
+ } else {
+ bit := node.choose(ip)
+ down = node.child[bit]
+ if down == nil {
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ return
+ }
+ }
+ common := commonBits(down.bits, ip)
+ if common < cidr {
+ cidr = common
+ }
+ parent := node
if newNode.cidr == cidr {
- bit := newNode.choose(node.bits)
- newNode.child[bit] = node
- return newNode
+ bit := newNode.choose(down.bits)
+ down.parent = parentIndirection{&newNode.child[bit], bit}
+ newNode.child[bit] = down
+ if parent == nil {
+ newNode.parent = trie
+ *trie.parentBit = newNode
+ } else {
+ bit := parent.choose(newNode.bits)
+ newNode.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = newNode
+ }
+ return
}
- // create new parent for node & newNode
-
- parent := &trieEntry{
- bits: ip,
- peer: nil,
- cidr: cidr,
- bit_at_byte: cidr / 8,
- bit_at_shift: 7 - (cidr % 8),
+ node = &trieEntry{
+ bits: append([]byte{}, newNode.bits...),
+ cidr: cidr,
+ bitAtByte: cidr / 8,
+ bitAtShift: 7 - (cidr % 8),
+ }
+ node.maskSelf()
+
+ bit := node.choose(down.bits)
+ down.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = down
+ bit = node.choose(newNode.bits)
+ newNode.parent = parentIndirection{&node.child[bit], bit}
+ node.child[bit] = newNode
+ if parent == nil {
+ node.parent = trie
+ *trie.parentBit = node
+ } else {
+ bit := parent.choose(node.bits)
+ node.parent = parentIndirection{&parent.child[bit], bit}
+ parent.child[bit] = node
}
-
- bit := parent.choose(ip)
- parent.child[bit] = newNode
- parent.child[bit^1] = node
-
- return parent
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
- size := uint(len(ip))
+ size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
- if node.bit_at_byte == size {
+ if node.bitAtByte == size {
break
}
bit := node.choose(ip)
@@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer {
return found
}
-func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
- if node == nil {
- return results
- }
- if node.peer == p {
- mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
- results = append(results, net.IPNet{
- Mask: mask,
- IP: node.bits.Mask(mask),
- })
- }
- results = node.child[0].entriesForPeer(p, results)
- results = node.child[1].entriesForPeer(p, results)
- return results
-}
-
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
- allowed := make([]net.IPNet, 0, 10)
- allowed = table.IPv4.entriesForPeer(peer, allowed)
- allowed = table.IPv6.entriesForPeer(peer, allowed)
- return allowed
-}
-
-func (table *AllowedIPs) Reset() {
- table.mutex.Lock()
- defer table.mutex.Unlock()
-
- table.IPv4 = nil
- table.IPv6 = nil
+ for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
+ node := elem.Value.(*trieEntry)
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
+ return
+ }
+ }
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- table.IPv4 = table.IPv4.removeByPeer(peer)
- table.IPv6 = table.IPv6.removeByPeer(peer)
+ var next *list.Element
+ for elem := peer.trieEntries.Front(); elem != nil; elem = next {
+ next = elem.Next()
+ node := elem.Value.(*trieEntry)
+
+ node.removeFromPeerEntries()
+ node.peer = nil
+ if node.child[0] != nil && node.child[1] != nil {
+ continue
+ }
+ bit := 0
+ if node.child[0] == nil {
+ bit = 1
+ }
+ child := node.child[bit]
+ if child != nil {
+ child.parent = node.parent
+ }
+ *node.parent.parentBit = child
+ if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
+ node.zeroizePointers()
+ continue
+ }
+ parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
+ if parent.peer != nil {
+ node.zeroizePointers()
+ continue
+ }
+ child = parent.child[node.parent.parentBitType^1]
+ if child != nil {
+ child.parent = parent.parent
+ }
+ *parent.parent.parentBit = child
+ node.zeroizePointers()
+ parent.zeroizePointers()
+ }
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- table.IPv6 = table.IPv6.insert(ip, cidr, peer)
- case net.IPv4len:
- table.IPv4 = table.IPv4.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- return table.IPv4.lookup(address)
-}
-
-func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
- table.mutex.RLock()
- defer table.mutex.RUnlock()
- return table.IPv6.lookup(address)
+ switch len(ip) {
+ case net.IPv6len:
+ return table.IPv6.lookup(ip)
+ case net.IPv4len:
+ return table.IPv4.lookup(ip)
+ default:
+ panic(errors.New("looking up unknown address type"))
+ }
}
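
The allowedips.go rework above replaces the recursive net.IP/uint trie with a netip.Prefix-based Insert, a single Lookup over raw 4- or 16-byte addresses, and a callback-driven EntriesForPeer backed by a per-peer list. The following is a minimal sketch of how a caller inside the device package might exercise the new surface; the test name and the bare &Peer{} values are illustrative only, mirroring the package's own trie tests.

    package device

    import (
        "net/netip"
        "testing"
    )

    // TestAllowedIPsSketch is a hypothetical test (not part of this change)
    // showing the reworked API: Insert takes a netip.Prefix, Lookup takes raw
    // 4- or 16-byte addresses, and EntriesForPeer walks a peer's prefixes via
    // a callback instead of returning a slice.
    func TestAllowedIPsSketch(t *testing.T) {
        var table AllowedIPs
        peer := &Peer{} // bare peer, as in the package's own trie tests

        table.Insert(netip.MustParsePrefix("192.168.4.0/24"), peer)
        table.Insert(netip.MustParsePrefix("2001:db8::/32"), peer)

        // Lookup dispatches on the slice length (net.IPv4len or net.IPv6len).
        addr := netip.MustParseAddr("192.168.4.28").As4()
        if table.Lookup(addr[:]) != peer {
            t.Fatal("expected 192.168.4.28 to route to peer")
        }

        // Returning false from the callback stops the walk early.
        table.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
            t.Logf("peer owns %s", prefix)
            return true
        })

        // Removing the peer prunes its nodes and merges the trie back together.
        table.RemoveByPeer(peer)
        if table.Lookup(addr[:]) != nil {
            t.Fatal("expected no match after RemoveByPeer")
        }
    }

The callback form lets callers stream a peer's prefixes without first materializing a slice, in contrast to the removed entriesForPeer, which appended every match into a []net.IPNet.
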
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 3947830..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,25 +1,28 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
+ "net"
+ "net/netip"
"sort"
"testing"
)
const (
- NumberOfPeers = 100
- NumberOfAddresses = 250
- NumberOfTests = 10000
+ NumberOfPeers = 100
+ NumberOfPeerRemovals = 4
+ NumberOfAddresses = 250
+ NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
- cidr uint
+ cidr uint8
bits []byte
}
@@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
-func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
@@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer {
return nil
}
-func TestTrieRandomIPv4(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
- var peers []*Peer
-
- rand.Seed(1)
-
- const AddressLength = 4
-
- for n := 0; n < NumberOfPeers; n += 1 {
- peers = append(peers, &Peer{})
- }
-
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
- }
-
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
+ n := 0
+ for _, x := range r {
+ if x.peer != peer {
+ r[n] = x
+ n++
}
}
+ return r[:n]
}
-func TestTrieRandomIPv6(t *testing.T) {
- var trie *trieEntry
- var slow SlowRouter
+func TestTrieRandom(t *testing.T) {
+ var slow4, slow6 SlowRouter
var peers []*Peer
+ var allowedIPs AllowedIPs
rand.Seed(1)
- const AddressLength = 16
-
- for n := 0; n < NumberOfPeers; n += 1 {
+ for n := 0; n < NumberOfPeers; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < NumberOfAddresses; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
- index := rand.Int() % NumberOfPeers
- trie = trie.insert(addr[:], cidr, peers[index])
- slow = slow.Insert(addr[:], cidr, peers[index])
+ for n := 0; n < NumberOfAddresses; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ cidr := uint8(rand.Intn(32) + 1)
+ index := rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
+ slow4 = slow4.Insert(addr4[:], cidr, peers[index])
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ cidr = uint8(rand.Intn(128) + 1)
+ index = rand.Intn(NumberOfPeers)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
+ slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
- for n := 0; n < NumberOfTests; n += 1 {
- var addr [AddressLength]byte
- rand.Read(addr[:])
- peer1 := slow.Lookup(addr[:])
- peer2 := trie.lookup(addr[:])
- if peer1 != peer2 {
- t.Error("Trie did not match naive implementation, for:", addr)
+ var p int
+ for p = 0; ; p++ {
+ for n := 0; n < NumberOfTests; n++ {
+ var addr4 [4]byte
+ rand.Read(addr4[:])
+ peer1 := slow4.Lookup(addr4[:])
+ peer2 := allowedIPs.Lookup(addr4[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
+ }
+
+ var addr6 [16]byte
+ rand.Read(addr6[:])
+ peer1 = slow6.Lookup(addr6[:])
+ peer2 = allowedIPs.Lookup(addr6[:])
+ if peer1 != peer2 {
+ t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
+ }
+ }
+ if p >= len(peers) || p >= NumberOfPeerRemovals {
+ break
}
+ allowedIPs.RemoveByPeer(peers[p])
+ slow4 = slow4.RemoveByPeer(peers[p])
+ slow6 = slow6.RemoveByPeer(peers[p])
+ }
+ for ; p < len(peers); p++ {
+ allowedIPs.RemoveByPeer(peers[p])
+ }
+
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Failed to remove all nodes from trie by peer")
}
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 005df48..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,40 +8,17 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
-/* Todo: More comprehensive
- */
-
type testPairCommonBits struct {
s1 []byte
s2 []byte
- match uint
-}
-
-type testPairTrieInsert struct {
- key []byte
- cidr uint
- peer *Peer
-}
-
-type testPairTrieLookup struct {
- key []byte
- peer *Peer
-}
-
-func printTrie(t *testing.T, p *trieEntry) {
- if p == nil {
- return
- }
- t.Log(p)
- printTrie(t, p.child[0])
- printTrie(t, p.child[1])
+ match uint8
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
+ root := parentIndirection{&trie, 2}
rand.Seed(1)
const AddressLength = 4
- for n := 0; n < peerNumber; n += 1 {
+ for n := 0; n < peerNumber; n++ {
peers = append(peers, &Peer{})
}
- for n := 0; n < addressNumber; n += 1 {
+ for n := 0; n < addressNumber; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
- cidr := uint(rand.Uint32() % (AddressLength * 8))
+ cidr := uint8(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
- trie = trie.insert(addr[:], cidr, peers[index])
+ root.insert(addr[:], cidr, peers[index])
}
- for n := 0; n < b.N; n += 1 {
+ for n := 0; n < b.N; n++ {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.lookup(addr[:])
@@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
- insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
- trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
- p := trie.lookup([]byte{a, b, c, d})
+ p := allowedIPs.Lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
@@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) {
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
@@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) {
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
- trie = nil
+ allowedIPs.RemoveByPeer(a)
+ allowedIPs.RemoveByPeer(b)
+ allowedIPs.RemoveByPeer(c)
+ allowedIPs.RemoveByPeer(d)
+ allowedIPs.RemoveByPeer(e)
+ allowedIPs.RemoveByPeer(g)
+ allowedIPs.RemoveByPeer(h)
+ if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
+ t.Error("Expected removing all the peers to empty trie, but it did not")
+ }
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
- trie = trie.removeByPeer(a)
+ allowedIPs.RemoveByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
@@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) {
g := &Peer{}
h := &Peer{}
- var trie *trieEntry
+ var allowedIPs AllowedIPs
expand := func(a uint32) []byte {
var out [4]byte
@@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) {
return out[:]
}
- insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- trie = trie.insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
@@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- p := trie.lookup(addr)
+ p := allowedIPs.Lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
diff --git a/device/bind_test.go b/device/bind_test.go
index 339adbe..302a521 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -14,14 +14,11 @@ import (
type DummyDatagram struct {
msg []byte
endpoint conn.Endpoint
- world bool // better type
}
type DummyBind struct {
in6 chan DummyDatagram
- ou6 chan DummyDatagram
in4 chan DummyDatagram
- ou4 chan DummyDatagram
closed bool
}
@@ -29,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@@ -54,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
+func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go
deleted file mode 100644
index 896c7d2..0000000
--- a/device/bindsocketshim.go
+++ /dev/null
@@ -1,36 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "errors"
-
- "golang.zx2c4.com/wireguard/conn"
-)
-
-// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
-func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
- if device.net.bind == nil {
- return errors.New("Bind is not yet initialized")
- }
-
- if iface, ok := device.net.bind.(conn.BindSocketToInterface); ok {
- return iface.BindSocketToInterface4(interfaceIndex, blackhole)
- }
- return nil
-}
-
-// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
-func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
- if device.net.bind == nil {
- return errors.New("Bind is not yet initialized")
- }
-
- if iface, ok := device.net.bind.(conn.BindSocketToInterface); ok {
- return iface.BindSocketToInterface6(interfaceIndex, blackhole)
- }
- return nil
-}
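
With bindsocketshim.go deleted, callers that previously went through Device.BindSocketToInterface4/6 are expected to use conn directly, as the removed TODO suggested. Below is a hypothetical helper sketching that direct use; it assumes conn still exports the BindSocketToInterface interface the shim type-asserted, and the package and function names are illustrative, not part of this change.

    package wgutil

    import (
        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/device"
    )

    // bindToInterface4 is a hypothetical replacement for the removed shim:
    // the caller type-asserts the device's bind itself instead of asking
    // the Device to do it.
    func bindToInterface4(dev *device.Device, ifindex uint32, blackhole bool) error {
        if iface, ok := dev.Bind().(conn.BindSocketToInterface); ok {
            return iface.BindSocketToInterface4(ifindex, blackhole)
        }
        return nil
    }
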
diff --git a/device/channels.go b/device/channels.go
new file mode 100644
index 0000000..e526f6b
--- /dev/null
+++ b/device/channels.go
@@ -0,0 +1,137 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "runtime"
+ "sync"
+)
+
+// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
+// An outboundQueue is ref-counted using its wg field.
+// An outboundQueue created with newOutboundQueue has one reference.
+// Every additional writer must call wg.Add(1).
+// Every completed writer must call wg.Done().
+// When no further writers will be added,
+// call wg.Done to remove the initial reference.
+// When the refcount hits 0, the queue's channel is closed.
+type outboundQueue struct {
+ c chan *QueueOutboundElementsContainer
+ wg sync.WaitGroup
+}
+
+func newOutboundQueue() *outboundQueue {
+ q := &outboundQueue{
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+// An inboundQueue is similar to an outboundQueue; see those docs.
+type inboundQueue struct {
+ c chan *QueueInboundElementsContainer
+ wg sync.WaitGroup
+}
+
+func newInboundQueue() *inboundQueue {
+ q := &inboundQueue{
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+// A handshakeQueue is similar to an outboundQueue; see those docs.
+type handshakeQueue struct {
+ c chan QueueHandshakeElement
+ wg sync.WaitGroup
+}
+
+func newHandshakeQueue() *handshakeQueue {
+ q := &handshakeQueue{
+ c: make(chan QueueHandshakeElement, QueueHandshakeSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
+type autodrainingInboundQueue struct {
+ c chan *QueueInboundElementsContainer
+}
+
+// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
+// It is useful in cases in which it is hard to manage the lifetime of the channel.
+// The returned channel must not be closed. Senders should signal shutdown using
+// some other means, such as sending a sentinel nil value.
+func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
+ q := &autodrainingInboundQueue{
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
+ }
+ runtime.SetFinalizer(q, device.flushInboundQueue)
+ return q
+}
+
+func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
+ for {
+ select {
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
+ default:
+ return
+ }
+ }
+}
+
+type autodrainingOutboundQueue struct {
+ c chan *QueueOutboundElementsContainer
+}
+
+// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
+// It is useful in cases in which it is hard to manage the lifetime of the channel.
+// The returned channel must not be closed. Senders should signal shutdown using
+// some other means, such as sending a sentinel nil value.
+// All sends to the channel must be best-effort, because there may be no receivers.
+func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
+ q := &autodrainingOutboundQueue{
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
+ }
+ runtime.SetFinalizer(q, device.flushOutboundQueue)
+ return q
+}
+
+func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
+ for {
+ select {
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ default:
+ return
+ }
+ }
+}
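
The new channels.go centralizes one pattern shared by the encryption, decryption, and handshake queues above: the WaitGroup counts writers holding a reference to the channel, and a helper goroutine closes the channel once the count reaches zero, so consumers can simply range until close without coordinating with senders. A standalone sketch of that pattern, using a plain int channel instead of the device's element containers (all names are illustrative):

    package main

    import (
        "fmt"
        "sync"
    )

    // refCountedQueue mirrors the outboundQueue/inboundQueue idea: wg counts
    // writers, and the channel is closed once the last writer calls wg.Done().
    type refCountedQueue struct {
        c  chan int
        wg sync.WaitGroup
    }

    func newRefCountedQueue() *refCountedQueue {
        q := &refCountedQueue{c: make(chan int, 16)}
        q.wg.Add(1) // initial reference, dropped once no further writers will be added
        go func() {
            q.wg.Wait()
            close(q.c)
        }()
        return q
    }

    func main() {
        q := newRefCountedQueue()
        for w := 0; w < 3; w++ {
            q.wg.Add(1) // every additional writer takes a reference
            go func(w int) {
                defer q.wg.Done()
                q.c <- w
            }(w)
        }
        q.wg.Done() // drop the initial reference; the channel closes when the writers finish
        for v := range q.c {
            fmt.Println("got", v)
        }
    }

The autodraining queues take the opposite approach: their channels are never closed, and a finalizer set with runtime.SetFinalizer drains whatever is left once the queue becomes unreachable.
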
diff --git a/device/constants.go b/device/constants.go
index 1a4b8ea..59854a1 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -35,7 +35,6 @@ const (
/* Implementation constants */
const (
- UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers
)
diff --git a/device/cookie.go b/device/cookie.go
index c658ca3..876f05d 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
-func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
+func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock()
defer st.RUnlock()
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32,
src []byte,
) (*MessageCookieReply, error) {
-
st.RLock()
// refresh cookie secret
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
-
if err != nil {
return false
}
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
}
func (st *CookieGenerator) AddMacs(msg []byte) {
-
size := len(msg)
smac2 := size - blake2s.Size128
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 7e4c362..4f1e50a 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
)
func TestCookieMAC1(t *testing.T) {
-
// setup generator / checker
var (
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1 := []byte{192, 168, 13, 37, 40, 01}
+ srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2 := []byte{192, 168, 13, 38, 40, 01}
+ srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/device/device.go b/device/device.go
index c64432e..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,8 +11,6 @@ import (
"sync/atomic"
"time"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
@@ -20,28 +18,33 @@ import (
)
type Device struct {
- isUp AtomicBool // device is (going) up
- isClosed AtomicBool // device is closed? (acting as guard)
- log *Logger
-
- // synchronized resources (locks acquired in order)
-
state struct {
- starting sync.WaitGroup
+ // state holds the device's state. It is accessed atomically.
+ // Use the device.deviceState method to read it.
+ // device.deviceState does not acquire the mutex, so it captures only a snapshot.
+ // During state transitions, the state variable is updated before the device itself.
+ // The state is thus either the current state of the device or
+ // the intended future state of the device.
+ // For example, while executing a call to Up, state will be deviceStateUp.
+ // There is no guarantee that that intended future state of the device
+ // will become the actual state; Up can fail.
+ // The device can also change state multiple times between time of check and time of use.
+ // Unsynchronized uses of state must therefore be advisory/best-effort only.
+ state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
+ // stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
+ // mu protects state changes.
sync.Mutex
- changing AtomicBool
- current bool
}
net struct {
- starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
bind conn.Bind // bind interface
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
+ brokenRoaming bool
}
staticIdentity struct {
@@ -51,153 +54,176 @@ type Device struct {
}
peers struct {
- sync.RWMutex
- keyMap map[NoisePublicKey]*Peer
+ sync.RWMutex // protects keyMap
+ keyMap map[NoisePublicKey]*Peer
}
- // unprotected / "self-synchronising resources"
+ rate struct {
+ underLoadUntil atomic.Int64
+ limiter ratelimiter.Ratelimiter
+ }
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
- rate struct {
- underLoadUntil atomic.Value
- limiter ratelimiter.Ratelimiter
- }
-
pool struct {
- messageBufferPool *sync.Pool
- messageBufferReuseChan chan *[MaxMessageSize]byte
- inboundElementPool *sync.Pool
- inboundElementReuseChan chan *QueueInboundElement
- outboundElementPool *sync.Pool
- outboundElementReuseChan chan *QueueOutboundElement
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
- encryption chan *QueueOutboundElement
- decryption chan *QueueInboundElement
- handshake chan QueueHandshakeElement
- }
-
- signals struct {
- stop chan struct{}
+ encryption *outboundQueue
+ decryption *inboundQueue
+ handshake *handshakeQueue
}
tun struct {
device tun.Device
- mtu int32
+ mtu atomic.Int32
}
+
+ ipcMutex sync.RWMutex
+ closed chan struct{}
+ log *Logger
}
-/* Converts the peer into a "zombie", which remains in the peer map,
- * but processes no packets and does not exists in the routing table.
- *
- * Must hold device.peers.Mutex
- */
-func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+// deviceState represents the state of a Device.
+// There are three states: down, up, closed.
+// Transitions:
+//
+// down -----+
+// ↑↓ ↓
+// up -> closed
+type deviceState uint32
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
+const (
+ deviceStateDown deviceState = iota
+ deviceStateUp
+ deviceStateClosed
+)
- // stop routing and processing of packets
+// deviceState returns device.state.state as a deviceState
+// See those docs for how to interpret this value.
+func (device *Device) deviceState() deviceState {
+ return deviceState(device.state.state.Load())
+}
+
+// isClosed reports whether the device is closed (or is closing).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isClosed() bool {
+ return device.deviceState() == deviceStateClosed
+}
+
+// isUp reports whether the device is up (or is attempting to come up).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isUp() bool {
+ return device.deviceState() == deviceStateUp
+}
+// Must hold device.peers.Lock()
+func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
+ // stop routing and processing of packets
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
-
delete(device.peers.keyMap, key)
}
-func deviceUpdateState(device *Device) {
-
- // check if state already being updated (guard)
-
- if device.state.changing.Swap(true) {
- return
- }
-
- // compare to current state of device
-
+// changeState attempts to change the device state to match want.
+func (device *Device) changeState(want deviceState) (err error) {
device.state.Lock()
-
- newIsUp := device.isUp.Get()
-
- if newIsUp == device.state.current {
- device.state.changing.Set(false)
- device.state.Unlock()
- return
+ defer device.state.Unlock()
+ old := device.deviceState()
+ if old == deviceStateClosed {
+ // once closed, always closed
+ device.log.Verbosef("Interface closed, ignored requested state %s", want)
+ return nil
}
-
- // change state of device
-
- switch newIsUp {
- case true:
- if err := device.BindUpdate(); err != nil {
- device.log.Error.Printf("Unable to update bind: %v\n", err)
- device.isUp.Set(false)
+ switch want {
+ case old:
+ return nil
+ case deviceStateUp:
+ device.state.state.Store(uint32(deviceStateUp))
+ err = device.upLocked()
+ if err == nil {
break
}
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Start()
- if peer.persistentKeepaliveInterval > 0 {
- peer.SendKeepalive()
- }
+ fallthrough // up failed; bring the device all the way back down
+ case deviceStateDown:
+ device.state.state.Store(uint32(deviceStateDown))
+ errDown := device.downLocked()
+ if err == nil {
+ err = errDown
}
- device.peers.RUnlock()
-
- case false:
- device.BindClose()
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Stop()
- }
- device.peers.RUnlock()
}
+ device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+ return
+}
- // update state variables
-
- device.state.current = newIsUp
- device.state.changing.Set(false)
- device.state.Unlock()
+// upLocked attempts to bring the device up and reports whether it succeeded.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) upLocked() error {
+ if err := device.BindUpdate(); err != nil {
+ device.log.Errorf("Unable to update bind: %v", err)
+ return err
+ }
- // check for state change in the mean time
+ // The IPC set operation waits for peers to be created before calling Start() on them,
+ // so if there's a concurrent IPC set request happening, we should wait for it to complete.
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
- deviceUpdateState(device)
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ if peer.persistentKeepaliveInterval.Load() > 0 {
+ peer.SendKeepalive()
+ }
+ }
+ device.peers.RUnlock()
+ return nil
}
-func (device *Device) Up() {
-
- // closed device cannot be brought up
+// downLocked attempts to bring the device down.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) downLocked() error {
+ err := device.BindClose()
+ if err != nil {
+ device.log.Errorf("Bind close failed: %v", err)
+ }
- if device.isClosed.Get() {
- return
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Stop()
}
+ device.peers.RUnlock()
+ return err
+}
- device.isUp.Set(true)
- deviceUpdateState(device)
+func (device *Device) Up() error {
+ return device.changeState(deviceStateUp)
}
-func (device *Device) Down() {
- device.isUp.Set(false)
- deviceUpdateState(device)
+func (device *Device) Down() error {
+ return device.changeState(deviceStateDown)
}
func (device *Device) IsUnderLoad() bool {
-
// check if currently under load
-
now := time.Now()
- underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+ underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
- device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
-
// check if recently under load
-
- until := device.rate.underLoadUntil.Load().(time.Time)
- return until.After(now)
+ return device.rate.underLoadUntil.Load() > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -224,7 +250,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
- unsafeRemovePeer(device, peer, key)
+ peer.handshake.mutex.RUnlock()
+ removePeerLocked(device, peer, key)
+ peer.handshake.mutex.RLock()
}
}
@@ -239,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -253,70 +281,63 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil
}
-func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
+func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
-
- device.isUp.Set(false)
- device.isClosed.Set(false)
-
+ device.state.state.Store(uint32(deviceStateDown))
+ device.closed = make(chan struct{})
device.log = logger
-
+ device.net.bind = bind
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
- logger.Error.Println("Trouble determining MTU, assuming default:", err)
+ device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
- device.tun.mtu = int32(mtu)
-
+ device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
-
device.rate.limiter.Init()
- device.rate.underLoadUntil.Store(time.Time{})
-
device.indexTable.Init()
- device.allowedips.Reset()
device.PopulatePools()
// create queues
- device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
- device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
- device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
-
- // prepare signals
-
- device.signals.stop = make(chan struct{})
-
- // prepare net
-
- device.net.port = 0
- device.net.bind = nil
+ device.queue.handshake = newHandshakeQueue()
+ device.queue.encryption = newOutboundQueue()
+ device.queue.decryption = newInboundQueue()
// start workers
cpus := runtime.NumCPU()
- device.state.starting.Wait()
device.state.stopping.Wait()
- for i := 0; i < cpus; i += 1 {
- device.state.starting.Add(3)
- device.state.stopping.Add(3)
- go device.RoutineEncryption()
- go device.RoutineDecryption()
- go device.RoutineHandshake()
+ device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
+ for i := 0; i < cpus; i++ {
+ go device.RoutineEncryption(i + 1)
+ go device.RoutineDecryption(i + 1)
+ go device.RoutineHandshake(i + 1)
}
- device.state.starting.Add(2)
- device.state.stopping.Add(2)
+ device.state.stopping.Add(1) // RoutineReadFromTUN
+ device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
- device.state.starting.Wait()
-
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -331,7 +352,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
peer, ok := device.peers.keyMap[key]
if ok {
- unsafeRemovePeer(device, peer, key)
+ removePeerLocked(device, peer, key)
}
}
@@ -340,67 +361,50 @@ func (device *Device) RemoveAllPeers() {
defer device.peers.Unlock()
for key, peer := range device.peers.keyMap {
- unsafeRemovePeer(device, peer, key)
+ removePeerLocked(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
-func (device *Device) FlushPacketQueues() {
- for {
- select {
- case elem, ok := <-device.queue.decryption:
- if ok {
- elem.Drop()
- }
- case elem, ok := <-device.queue.encryption:
- if ok {
- elem.Drop()
- }
- case <-device.queue.handshake:
- default:
- return
- }
- }
-
-}
-
func (device *Device) Close() {
- if device.isClosed.Swap(true) {
- return
- }
-
- device.state.starting.Wait()
-
- device.log.Info.Println("Device closing")
- device.state.changing.Set(true)
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
+ if device.isClosed() {
+ return
+ }
+ device.state.state.Store(uint32(deviceStateClosed))
+ device.log.Verbosef("Device closing")
device.tun.device.Close()
- device.BindClose()
-
- device.isUp.Set(false)
-
- close(device.signals.stop)
+ device.downLocked()
+ // Remove peers before closing queues,
+ // because peers assume that queues are active.
device.RemoveAllPeers()
+ // We kept a reference to the encryption and decryption queues,
+ // in case we started any new peers that might write to them.
+ // No new peers are coming; we are done with these queues.
+ device.queue.encryption.wg.Done()
+ device.queue.decryption.wg.Done()
+ device.queue.handshake.wg.Done()
device.state.stopping.Wait()
- device.FlushPacketQueues()
device.rate.limiter.Close()
- device.state.changing.Set(false)
- device.log.Info.Println("Interface closed")
+ device.log.Verbosef("Device closed")
+ close(device.closed)
}
func (device *Device) Wait() chan struct{} {
- return device.signals.stop
+ return device.closed
}
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
- if device.isClosed.Get() {
+ if !device.isUp() {
return
}
@@ -416,7 +420,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RUnlock()
}
-func unsafeCloseBind(device *Device) error {
+// closeBindLocked closes the device's net.bind.
+// The caller must hold the net mutex.
+func closeBindLocked(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
@@ -424,7 +430,6 @@ func unsafeCloseBind(device *Device) error {
}
if netc.bind != nil {
err = netc.bind.Close()
- netc.bind = nil
}
netc.stopping.Wait()
return err
@@ -437,34 +442,26 @@ func (device *Device) Bind() conn.Bind {
}
func (device *Device) BindSetMark(mark uint32) error {
-
device.net.Lock()
defer device.net.Unlock()
// check if modified
-
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
-
device.net.fwmark = mark
- if device.isUp.Get() && device.net.bind != nil {
+ if device.isUp() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
-
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -472,76 +469,68 @@ func (device *Device) BindSetMark(mark uint32) error {
}
func (device *Device) BindUpdate() error {
-
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
-
- if err := unsafeCloseBind(device); err != nil {
+ if err := closeBindLocked(device); err != nil {
return err
}
// open new sockets
+ if !device.isUp() {
+ return nil
+ }
- if device.isUp.Get() {
+ // bind to new port
+ var err error
+ var recvFns []conn.ReceiveFunc
+ netc := &device.net
- // bind to new port
+ recvFns, netc.port, err = netc.bind.Open(netc.port)
+ if err != nil {
+ netc.port = 0
+ return err
+ }
- var err error
- netc := &device.net
- netc.bind, netc.port, err = conn.CreateBind(netc.port)
- if err != nil {
- netc.bind = nil
- netc.port = 0
- return err
- }
- netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ if err != nil {
+ netc.bind.Close()
+ netc.port = 0
+ return err
+ }
+
+ // set fwmark
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
if err != nil {
- netc.bind.Close()
- netc.bind = nil
- netc.port = 0
return err
}
+ }
- // set fwmark
-
- if netc.fwmark != 0 {
- err = netc.bind.SetMark(netc.fwmark)
- if err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- }
- device.peers.RUnlock()
-
- // start receiving routines
-
- device.net.starting.Add(2)
- device.net.stopping.Add(2)
- go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
- go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
- device.net.starting.Wait()
+ // clear cached source addresses
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.markEndpointSrcForClearing()
+ }
+ device.peers.RUnlock()
- device.log.Debug.Println("UDP bind has been updated")
+ // start receiving routines
+ device.net.stopping.Add(len(recvFns))
+ device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
+ for _, fn := range recvFns {
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
+ device.log.Verbosef("UDP bind has been updated")
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
- err := unsafeCloseBind(device)
+ err := closeBindLocked(device)
device.net.Unlock()
return err
}
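
Taken together, the device.go changes replace the isUp/isClosed flags with an explicit state machine (down ↔ up → closed), make Up and Down report errors, and move bind construction out of BindUpdate and into the NewDevice caller. A rough sketch of a caller against the new signatures follows; the TUN setup via tun.CreateTUN and the interface name are assumptions for illustration, not part of this change.

    package main

    import (
        "log"

        "golang.zx2c4.com/wireguard/conn"
        "golang.zx2c4.com/wireguard/device"
        "golang.zx2c4.com/wireguard/tun"
    )

    func main() {
        // Hypothetical setup: create a kernel TUN interface named "wg0".
        tunDev, err := tun.CreateTUN("wg0", device.DefaultMTU)
        if err != nil {
            log.Fatalf("failed to create TUN device: %v", err)
        }

        // NewDevice now takes the bind explicitly; the logger is the third argument.
        dev := device.NewDevice(tunDev, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "wg0: "))

        // Up and Down now report errors instead of silently toggling a flag.
        if err := dev.Up(); err != nil {
            log.Fatalf("failed to bring device up: %v", err)
        }

        // ... configure peers via dev.IpcSet(...), run, then tear down ...

        if err := dev.Down(); err != nil {
            log.Printf("failed to bring device down: %v", err)
        }
        dev.Close()
        <-dev.Wait() // Wait returns a channel that is closed once Close completes
    }

Once the device reaches deviceStateClosed, changeState ignores further requests, so Close is terminal: a closed device cannot be brought back up.
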
diff --git a/device/device_test.go b/device/device_test.go
index 5ea5410..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,102 +1,476 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "bufio"
"bytes"
- "net"
- "strings"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "math/rand"
+ "net/netip"
+ "os"
+ "runtime"
+ "runtime/pprof"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
-func TestTwoDevicePing(t *testing.T) {
- // TODO(crawshaw): pick unused ports on localhost
- cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
-listen_port=53511
-replace_peers=true
-public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.2/32
-endpoint=127.0.0.1:53512`
- tun1 := tuntest.NewChannelTUN()
- dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
- dev1.Up()
- defer dev1.Close()
- if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
- t.Fatal(err)
- }
-
- cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
-listen_port=53512
-replace_peers=true
-public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
-protocol_version=1
-replace_allowed_ips=true
-allowed_ip=1.0.0.1/32
-endpoint=127.0.0.1:53511`
- tun2 := tuntest.NewChannelTUN()
- dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
- dev2.Up()
- defer dev2.Close()
- if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
- t.Fatal(err)
+// uapiCfg returns a string that contains cfg formatted for use with IpcSet.
+// cfg is a series of alternating key/value strings.
+// uapiCfg exists because editors and humans like to insert
+// whitespace into configs, which can cause failures, some of which are silent.
+// For example, a leading blank newline causes the remainder
+// of the config to be silently ignored.
+func uapiCfg(cfg ...string) string {
+ if len(cfg)%2 != 0 {
+ panic("odd number of args to uapiReader")
+ }
+ buf := new(bytes.Buffer)
+ for i, s := range cfg {
+ buf.WriteString(s)
+ sep := byte('\n')
+ if i%2 == 0 {
+ sep = '='
+ }
+ buf.WriteByte(sep)
}
+ return buf.String()
+}
- t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
- tun2.Outbound <- msg2to1
+// genConfigs generates a pair of configs that connect to each other.
+// The configs use distinct, probably-usable ports.
+func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
+ var key1, key2 NoisePrivateKey
+ _, err := rand.Read(key1[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ _, err = rand.Read(key2[:])
+ if err != nil {
+ tb.Errorf("unable to generate private key random bytes: %v", err)
+ }
+ pub1, pub2 := key1.publicKey(), key2.publicKey()
+
+ cfgs[0] = uapiCfg(
+ "private_key", hex.EncodeToString(key1[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub2[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.2/32",
+ )
+ endpointCfgs[0] = uapiCfg(
+ "public_key", hex.EncodeToString(pub2[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ cfgs[1] = uapiCfg(
+ "private_key", hex.EncodeToString(key2[:]),
+ "listen_port", "0",
+ "replace_peers", "true",
+ "public_key", hex.EncodeToString(pub1[:]),
+ "protocol_version", "1",
+ "replace_allowed_ips", "true",
+ "allowed_ip", "1.0.0.1/32",
+ )
+ endpointCfgs[1] = uapiCfg(
+ "public_key", hex.EncodeToString(pub1[:]),
+ "endpoint", "127.0.0.1:%d",
+ )
+ return
+}
+
+// A testPair is a pair of testPeers.
+type testPair [2]testPeer
+
+// A testPeer is a peer used for testing.
+type testPeer struct {
+ tun *tuntest.ChannelTUN
+ dev *Device
+ ip netip.Addr
+}
+
+type SendDirection bool
+
+const (
+ Ping SendDirection = true
+ Pong SendDirection = false
+)
+
+func (d SendDirection) String() string {
+ if d == Ping {
+ return "ping"
+ }
+ return "pong"
+}
+
+func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) {
+ tb.Helper()
+ p0, p1 := pair[0], pair[1]
+ if !ping {
+ // pong is the new ping
+ p0, p1 = p1, p0
+ }
+ msg := tuntest.Ping(p0.ip, p1.ip)
+ p1.tun.Outbound <- msg
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ var err error
+ select {
+ case msgRecv := <-p0.tun.Inbound:
+ if !bytes.Equal(msg, msgRecv) {
+ err = fmt.Errorf("%s did not transit correctly", ping)
+ }
+ case <-timer.C:
+ err = fmt.Errorf("%s did not transit", ping)
+ case <-done:
+ }
+ if err != nil {
+ // The error may have occurred because the test is done.
select {
- case msgRecv := <-tun1.Inbound:
- if !bytes.Equal(msg2to1, msgRecv) {
- t.Error("ping did not transit correctly")
+ case <-done:
+ return
+ default:
+ }
+ // Real error.
+ tb.Error(err)
+ }
+}
+
+// genTestPair creates a testPair.
+func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
+ cfg, endpointCfg := genConfigs(tb)
+ var binds [2]conn.Bind
+ if realSocket {
+ binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
+ } else {
+ binds = bindtest.NewChannelBinds()
+ }
+ // Bring up a ChannelTun for each config.
+ for i := range pair {
+ p := &pair[i]
+ p.tun = tuntest.NewChannelTUN()
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
+ level := LogLevelVerbose
+ if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
+ level = LogLevelError
+ }
+ p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ if err := p.dev.IpcSet(cfg[i]); err != nil {
+ tb.Errorf("failed to configure device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ if err := p.dev.Up(); err != nil {
+ tb.Errorf("failed to bring up device %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
+ }
+ for i := range pair {
+ p := &pair[i]
+ if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
+ tb.Errorf("failed to configure device endpoint %d: %v", i, err)
+ p.dev.Close()
+ continue
+ }
+ // The device is ready. Close it when the test completes.
+ tb.Cleanup(p.dev.Close)
+ }
+ return
+}
+
+func TestTwoDevicePing(t *testing.T) {
+ goroutineLeakCheck(t)
+ pair := genTestPair(t, true)
+ t.Run("ping 1.0.0.1", func(t *testing.T) {
+ pair.Send(t, Ping, nil)
+ })
+ t.Run("ping 1.0.0.2", func(t *testing.T) {
+ pair.Send(t, Pong, nil)
+ })
+}
+
+func TestUpDown(t *testing.T) {
+ goroutineLeakCheck(t)
+ const itrials = 50
+ const otrials = 10
+
+ for n := 0; n < otrials; n++ {
+ pair := genTestPair(t, false)
+ for i := range pair {
+ for k := range pair[i].dev.peers.keyMap {
+ pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
+ }
+ }
+ var wg sync.WaitGroup
+ wg.Add(len(pair))
+ for i := range pair {
+ go func(d *Device) {
+ defer wg.Done()
+ for i := 0; i < itrials; i++ {
+ if err := d.Up(); err != nil {
+ t.Errorf("failed up bring up device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ if err := d.Down(); err != nil {
+ t.Errorf("failed to bring down device: %v", err)
+ }
+ time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
+ }
+ }(pair[i].dev)
+ }
+ wg.Wait()
+ for i := range pair {
+ pair[i].dev.Up()
+ pair[i].dev.Close()
+ }
+ }
+}
+
+// TestConcurrencySafety does other things concurrently with tunnel use.
+// It is intended to be used with the race detector to catch data races.
+func TestConcurrencySafety(t *testing.T) {
+ pair := genTestPair(t, true)
+ done := make(chan struct{})
+
+ const warmupIters = 10
+ var warmup sync.WaitGroup
+ warmup.Add(warmupIters)
+ go func() {
+ // Send data continuously back and forth until we're done.
+ // Note that we may continue to attempt to send data
+ // even after done is closed.
+ i := warmupIters
+ for ping := Ping; ; ping = !ping {
+ pair.Send(t, ping, done)
+ select {
+ case <-done:
+ return
+ default:
+ }
+ if i > 0 {
+ warmup.Done()
+ i--
}
- case <-time.After(300 * time.Millisecond):
- t.Error("ping did not transit")
+ }
+ }()
+ warmup.Wait()
+
+ applyCfg := func(cfg string) {
+ err := pair[0].dev.IpcSet(cfg)
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ // Change persistent_keepalive_interval concurrently with tunnel use.
+ t.Run("persistentKeepaliveInterval", func(t *testing.T) {
+ var pub NoisePublicKey
+ for key := range pair[0].dev.peers.keyMap {
+ pub = key
+ break
+ }
+ cfg := uapiCfg(
+ "public_key", hex.EncodeToString(pub[:]),
+ "persistent_keepalive_interval", "1",
+ )
+ for i := 0; i < 1000; i++ {
+ applyCfg(cfg)
}
})
- t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
- tun1.Outbound <- msg1to2
- select {
- case msgRecv := <-tun2.Inbound:
- if !bytes.Equal(msg1to2, msgRecv) {
- t.Error("return ping did not transit correctly")
+ // Change private keys concurrently with tunnel use.
+ t.Run("privateKey", func(t *testing.T) {
+ bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
+ good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
+ // Set iters to a large number like 1000 to flush out data races quickly.
+ // Don't leave it large. That can cause logical races
+ // in which the handshake is interleaved with key changes
+ // such that the private key appears to be unchanging but
+ // other state gets reset, which can cause handshake failures like
+ // "Received packet with invalid mac1".
+ const iters = 1
+ for i := 0; i < iters; i++ {
+ applyCfg(bad)
+ applyCfg(good)
+ }
+ })
+
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
}
- case <-time.After(300 * time.Millisecond):
- t.Error("return ping did not transit")
}
})
+
+ close(done)
}
-func assertNil(t *testing.T, err error) {
- if err != nil {
- t.Fatal(err)
+func BenchmarkLatency(b *testing.B) {
+ pair := genTestPair(b, true)
+
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
}
}
-func assertEqual(t *testing.T, a, b []byte) {
- if !bytes.Equal(a, b) {
- t.Fatal(a, "!=", b)
+func BenchmarkThroughput(b *testing.B) {
+ pair := genTestPair(b, true)
+
+ // Establish a connection.
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+
+ // Measure how long it takes to receive b.N packets,
+ // starting when we receive the first packet.
+ var recv atomic.Uint64
+ var elapsed time.Duration
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var start time.Time
+ for {
+ <-pair[0].tun.Inbound
+ new := recv.Add(1)
+ if new == 1 {
+ start = time.Now()
+ }
+ // Careful! Don't change this to else if; b.N can be equal to 1.
+ if new == uint64(b.N) {
+ elapsed = time.Since(start)
+ return
+ }
+ }
+ }()
+
+ // Send packets as fast as we can until we've received enough.
+ ping := tuntest.Ping(pair[0].ip, pair[1].ip)
+ pingc := pair[1].tun.Outbound
+ var sent uint64
+ for recv.Load() != uint64(b.N) {
+ sent++
+ pingc <- ping
}
+ wg.Wait()
+
+ b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
+ b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
}
-func randDevice(t *testing.T) *Device {
- sk, err := newPrivateKey()
- if err != nil {
- t.Fatal(err)
+func BenchmarkUAPIGet(b *testing.B) {
+ pair := genTestPair(b, true)
+ pair.Send(b, Ping, nil)
+ pair.Send(b, Pong, nil)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ pair[0].dev.IpcGetOperation(io.Discard)
+ }
+}
+
+func goroutineLeakCheck(t *testing.T) {
+ goroutines := func() (int, []byte) {
+ p := pprof.Lookup("goroutine")
+ b := new(bytes.Buffer)
+ p.WriteTo(b, 1)
+ return p.Count(), b.Bytes()
+ }
+
+ startGoroutines, startStacks := goroutines()
+ t.Cleanup(func() {
+ if t.Failed() {
+ return
+ }
+ // Give goroutines time to exit, if they need it.
+ for i := 0; i < 10000; i++ {
+ if runtime.NumGoroutine() <= startGoroutines {
+ return
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ endGoroutines, endStacks := goroutines()
+ t.Logf("starting stacks:\n%s\n", startStacks)
+ t.Logf("ending stacks:\n%s\n", endStacks)
+ t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
+ })
+}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
}
- tun := newDummyTUN("dummy")
- logger := NewLogger(LogLevelError, "")
- device := NewDevice(tun, logger)
- device.SetPrivateKey(sk)
- return device
}
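The TestBatchSize cases above only pass if Device.BatchSize returns the larger of the bind's and the TUN device's batch sizes. That method lives in device.go and is not part of this hunk; a minimal sketch consistent with the test's expectations (not necessarily the exact upstream implementation) would be:

    func (device *Device) BatchSize() int {
    	// Size pools and queues for the most demanding component:
    	// whichever of the bind or the TUN device batches more packets.
    	size := device.net.bind.BatchSize()
    	if tunSize := device.tun.device.BatchSize(); tunSize > size {
    		size = tunSize
    	}
    	return size
    }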
diff --git a/device/devicestate_string.go b/device/devicestate_string.go
new file mode 100644
index 0000000..6577dd4
--- /dev/null
+++ b/device/devicestate_string.go
@@ -0,0 +1,16 @@
+// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
+
+package device
+
+import "strconv"
+
+const _deviceState_name = "DownUpClosed"
+
+var _deviceState_index = [...]uint8{0, 4, 6, 12}
+
+func (i deviceState) String() string {
+ if i >= deviceState(len(_deviceState_index)-1) {
+ return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
+}
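For illustration only (not part of the generated file): with _deviceState_name = "DownUpClosed" and indices {0, 4, 6, 12}, and assuming the usual stringer constant order deviceStateDown, deviceStateUp, deviceStateClosed, the lookup slices out the values below. The snippet assumes a test file in package device that imports fmt:

    func Example_deviceStateString() {
    	fmt.Println(deviceState(0)) // Down    -> name[0:4]
    	fmt.Println(deviceState(1)) // Up      -> name[4:6]
    	fmt.Println(deviceState(2)) // Closed  -> name[6:12]
    	fmt.Println(deviceState(7)) // deviceState(7), out of range
    }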
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index e66d493..93a4998 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
- "net"
+ "net/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/indextable.go b/device/indextable.go
index 5e10eef..00ade7d 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,14 +1,14 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/rand"
+ "encoding/binary"
"sync"
- "unsafe"
)
type IndexTableEntry struct {
@@ -25,7 +25,8 @@ type IndexTable struct {
func randUint32() (uint32, error) {
var integer [4]byte
_, err := rand.Read(integer[:])
- return *(*uint32)(unsafe.Pointer(&integer[0])), err
+ // Arbitrary endianness; both are intrinsified by the Go compiler.
+ return binary.LittleEndian.Uint32(integer[:]), err
}
func (table *IndexTable) Init() {
diff --git a/device/ip.go b/device/ip.go
index 3bc6929..eaf2363 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index 1a3bc87..f9c76d6 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -20,7 +20,7 @@ type KDFTest struct {
t2 string
}
-func assertEquals(t *testing.T, a string, b string) {
+func assertEquals(t *testing.T, a, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
diff --git a/device/keypair.go b/device/keypair.go
index 2f2f222..e3540d7 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
"sync"
"sync/atomic"
"time"
- "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -23,10 +22,10 @@ import (
*/
type Keypair struct {
- sendNonce uint64
+ sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
- replayFilter replay.ReplayFilter
+ replayFilter replay.Filter
isInitiator bool
created time.Time
localIndex uint32
@@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
- atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
- return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+ next atomic.Pointer[Keypair]
}
func (kp *Keypairs) Current() *Keypair {
diff --git a/device/logger.go b/device/logger.go
index 3c4d744..22b0df0 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,59 +1,48 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "io"
- "io/ioutil"
"log"
"os"
)
+// A Logger provides logging for a Device.
+// The functions are Printf-style functions.
+// They must be safe for concurrent use.
+// They do not require a trailing newline in the format.
+// If nil, that level of logging will be silent.
+type Logger struct {
+ Verbosef func(format string, args ...any)
+ Errorf func(format string, args ...any)
+}
+
+// Log levels for use with NewLogger.
const (
LogLevelSilent = iota
LogLevelError
- LogLevelInfo
- LogLevelDebug
+ LogLevelVerbose
)
-type Logger struct {
- Debug *log.Logger
- Info *log.Logger
- Error *log.Logger
-}
+// DiscardLogf is a Logger function that discards logged lines.
+func DiscardLogf(format string, args ...any) {}
+// NewLogger constructs a Logger that writes to stdout.
+// It logs at the specified log level and above.
+// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
- output := os.Stdout
- logger := new(Logger)
-
- logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
- if level >= LogLevelDebug {
- return output, output, output
- }
- if level >= LogLevelInfo {
- return output, output, ioutil.Discard
- }
- if level >= LogLevelError {
- return output, ioutil.Discard, ioutil.Discard
- }
- return ioutil.Discard, ioutil.Discard, ioutil.Discard
- }()
-
- logger.Debug = log.New(logDebug,
- "DEBUG: "+prepend,
- log.Ldate|log.Ltime,
- )
-
- logger.Info = log.New(logInfo,
- "INFO: "+prepend,
- log.Ldate|log.Ltime,
- )
- logger.Error = log.New(logErr,
- "ERROR: "+prepend,
- log.Ldate|log.Ltime,
- )
+ logger := &Logger{DiscardLogf, DiscardLogf}
+ logf := func(prefix string) func(string, ...any) {
+ return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
+ }
+ if level >= LogLevelVerbose {
+ logger.Verbosef = logf("DEBUG")
+ }
+ if level >= LogLevelError {
+ logger.Errorf = logf("ERROR")
+ }
return logger
}
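Because the new Logger is just a pair of Printf-style function values, callers are not limited to NewLogger: any backend can be plugged in by filling Verbosef and Errorf directly. A hypothetical adapter that routes device logs into the standard library's log/slog package (Go 1.21+; the helper name and wiring are illustrative, not part of this change):

    package main

    import (
    	"fmt"
    	"log/slog"

    	"golang.zx2c4.com/wireguard/device"
    )

    // newSlogLogger adapts slog to the Logger's function fields.
    func newSlogLogger() *device.Logger {
    	return &device.Logger{
    		Verbosef: func(format string, args ...any) { slog.Debug(fmt.Sprintf(format, args...)) },
    		Errorf:   func(format string, args ...any) { slog.Error(fmt.Sprintf(format, args...)) },
    	}
    }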
diff --git a/device/misc.go b/device/misc.go
deleted file mode 100644
index 30d1156..0000000
--- a/device/misc.go
+++ /dev/null
@@ -1,48 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
- AtomicFalse = int32(iota)
- AtomicTrue
-)
-
-type AtomicBool struct {
- int32
-}
-
-func (a *AtomicBool) Get() bool {
- return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- atomic.StoreInt32(&a.int32, flag)
-}
-
-func min(a, b uint) uint {
- if a > b {
- return b
- }
- return a
-}
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
new file mode 100644
index 0000000..0a0080e
--- /dev/null
+++ b/device/mobilequirks.go
@@ -0,0 +1,19 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
+// though it will make a best-effort, and possibly racy, attempt to handle peers that already exist if called after.
+func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
+ device.net.brokenRoaming = true
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
+ }
+ device.peers.RUnlock()
+}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index b3b5acf..c2f356b 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
+ "errors"
"hash"
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
- return ss
+ if isZero(ss[:]) {
+ return ss, errInvalidPublicKey
+ }
+ return ss, nil
}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index be92b4b..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -20,7 +20,6 @@ import (
type handshakeState int
-// TODO(crawshaw): add commentary describing each state and the transitions
const (
handshakeZeroed = handshakeState(iota)
handshakeInitiationCreated
@@ -121,7 +120,7 @@ type Handshake struct {
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
- presharedKey NoiseSymmetricKey // psk
+ presharedKey NoisePresharedKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
@@ -139,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -176,8 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- var errZeroECDHResult = errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -205,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -222,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -265,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -283,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
@@ -319,11 +315,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
handshake.mutex.RUnlock()
if replay {
- device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
return nil
}
if flood {
- device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
+ device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
return nil
}
@@ -385,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -407,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -437,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
@@ -457,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -485,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -566,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error {
setZero(recvKey[:])
keypair.created = time.Now()
- keypair.sendNonce = 0
- keypair.replayFilter.Init()
+ keypair.replayFilter.Reset()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
@@ -584,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.loadNext()
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -598,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.storeNext(keypair)
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -610,18 +608,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.loadNext()
- keypairs.storeNext(nil)
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}
diff --git a/device/noise-types.go b/device/noise-types.go
index f793ef5..e850359 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,19 +9,18 @@ import (
"crypto/subtle"
"encoding/hex"
"errors"
-
- "golang.org/x/crypto/chacha20poly1305"
)
const (
- NoisePublicKeySize = 32
- NoisePrivateKeySize = 32
+ NoisePublicKeySize = 32
+ NoisePrivateKeySize = 32
+ NoisePresharedKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
- NoiseSymmetricKey [chacha20poly1305.KeySize]byte
+ NoisePresharedKey [NoisePresharedKeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
@@ -61,18 +60,10 @@ func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
return
}
-func (key NoisePrivateKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
-
func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
-func (key NoisePublicKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
-
func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey
return key.Equals(zero)
@@ -82,10 +73,6 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
-func (key *NoiseSymmetricKey) FromHex(src string) error {
+func (key *NoisePresharedKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
-
-func (key NoiseSymmetricKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
diff --git a/device/noise_test.go b/device/noise_test.go
index ce89851..2dd5324 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,9 @@ import (
"bytes"
"encoding/binary"
"testing"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
func TestCurveWrappers(t *testing.T) {
@@ -21,14 +24,38 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
+ ss1, err1 := sk1.sharedSecret(pk2)
+ ss2, err2 := sk2.sharedSecret(pk1)
- if ss1 != ss2 {
+ if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
+func randDevice(t *testing.T) *Device {
+ sk, err := newPrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tun := tuntest.NewChannelTUN()
+ logger := NewLogger(LogLevelError, "")
+ device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
+ device.SetPrivateKey(sk)
+ return device
+}
+
+func assertNil(t *testing.T, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func assertEqual(t *testing.T, a, b []byte) {
+ if !bytes.Equal(a, b) {
+ t.Fatal(a, "!=", b)
+ }
+}
+
func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
@@ -36,8 +63,16 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close()
defer dev2.Close()
- peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
- peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+ peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+ peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+ if err != nil {
+ t.Fatal(err)
+ }
+ peer1.Start()
+ peer2.Start()
assertEqual(
t,
@@ -113,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.loadNext()
+ key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index d13acd9..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,14 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "encoding/base64"
+ "container/list"
"errors"
- "fmt"
"sync"
"sync/atomic"
"time"
@@ -16,28 +15,21 @@ import (
"golang.zx2c4.com/wireguard/conn"
)
-const (
- PeerRoutineNumber = 3
-)
-
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint conn.Endpoint
- persistentKeepaliveInterval uint16
-
- // These fields are accessed with atomic operations, which must be
- // 64-bit aligned even on 32-bit platforms. Go guarantees that an
- // allocated struct will be 64-bit aligned. So we place
- // atomically-accessed fields up front, so that they can share in
- // this alignment before smaller fields throw it off.
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes send to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nano seconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
timers struct {
@@ -46,40 +38,32 @@ type Peer struct {
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
- signals struct {
- newKeypairArrived chan struct{}
- flushNonceQueue chan struct{}
+ state struct {
+ sync.Mutex // protects against concurrent Start/Stop
}
queue struct {
- nonce chan *QueueOutboundElement // nonce / pre-handshake queue
- outbound chan *QueueOutboundElement // sequential ordering of work
- inbound chan *QueueInboundElement // sequential ordering of work
- packetInNonceQueueIsAwaitingKey AtomicBool
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
- routines struct {
- sync.Mutex // held when stopping / starting routines
- starting sync.WaitGroup // routines pending start
- stopping sync.WaitGroup // routines pending stop
- stop chan struct{} // size 0, stop all go routines in peer
- }
-
- cookieGenerator CookieGenerator
+ cookieGenerator CookieGenerator
+ trieEntries list.List
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
- if device.isClosed.Get() {
+ if device.isClosed() {
return nil, errors.New("device closed")
}
// lock resources
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -87,131 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
defer device.peers.Unlock()
// check if over limit
-
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
-
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
- peer.isRunning.Set(false)
+ peer.queue.outbound = newAutodrainingOutboundQueue(device)
+ peer.queue.inbound = newAutodrainingInboundQueue(device)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
-
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
// pre-compute DH
-
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
- peer.endpoint = nil
+ // init timers
+ peer.timersInit()
// add
-
device.peers.keyMap[pk] = peer
- // start peer
-
- if peer.device.isUp.Get() {
- peer.Start()
- }
-
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
- if peer.device.net.bind == nil {
- return errors.New("no bind")
+ if peer.device.isClosed() {
+ return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
func (peer *Peer) String() string {
- base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
- abbreviatedKey := "invalid"
- if len(base64Key) == 44 {
- abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
+ // The awful goo that follows is identical to:
+ //
+ // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
+ // return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ //
+ // except that it is considerably more efficient.
+ src := peer.handshake.remoteStatic
+ b64 := func(input byte) byte {
+ return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
}
- return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ b := []byte("peer(____…____)")
+ const first = len("peer(")
+ const second = len("peer(____…")
+ b[first+0] = b64((src[0] >> 2) & 63)
+ b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
+ b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
+ b[first+3] = b64(src[2] & 63)
+ b[second+0] = b64(src[29] & 63)
+ b[second+1] = b64((src[30] >> 2) & 63)
+ b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
+ b[second+3] = b64((src[31] << 2) & 63)
+ return string(b)
}
func (peer *Peer) Start() {
-
// should never start a peer on a closed device
-
- if peer.device.isClosed.Get() {
+ if peer.device.isClosed() {
return
}
// prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
device := peer.device
- device.log.Debug.Println(peer, "- Starting...")
+ device.log.Verbosef("%v - Starting", peer)
// reset routine state
+ peer.stopping.Wait()
+ peer.stopping.Add(2)
- peer.routines.starting.Wait()
- peer.routines.stopping.Wait()
- peer.routines.stop = make(chan struct{})
- peer.routines.starting.Add(PeerRoutineNumber)
- peer.routines.stopping.Add(PeerRoutineNumber)
+ peer.handshake.mutex.Lock()
+ peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ peer.handshake.mutex.Unlock()
- // prepare queues
+ peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
- peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+ peer.timersStart()
- peer.timersInit()
- peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
- peer.signals.newKeypairArrived = make(chan struct{}, 1)
- peer.signals.flushNonceQueue = make(chan struct{}, 1)
+ device.flushInboundQueue(peer.queue.inbound)
+ device.flushOutboundQueue(peer.queue.outbound)
- // wait for routines to start
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
- go peer.RoutineNonce()
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
-
- peer.routines.starting.Wait()
- peer.isRunning.Set(true)
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@@ -223,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.loadNext())
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -237,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() {
handshake.Clear()
handshake.mutex.Unlock()
- peer.FlushNonceQueue()
+ peer.FlushStagedPackets()
}
func (peer *Peer) ExpireCurrentKeypairs() {
@@ -245,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() {
handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
- handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ handshake.mutex.Unlock()
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- keypairs.current.sendNonce = RejectAfterMessages
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- keypairs.loadNext().sendNonce = RejectAfterMessages
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
func (peer *Peer) Stop() {
-
- // prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
if !peer.isRunning.Swap(false) {
return
}
- peer.routines.starting.Wait()
-
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- peer.device.log.Debug.Println(peer, "- Stopping...")
+ peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop()
-
- // stop & wait for ongoing peer routines
-
- close(peer.routines.stop)
- peer.routines.stopping.Wait()
-
- // close queues
-
- close(peer.queue.nonce)
- close(peer.queue.outbound)
- close(peer.queue.inbound)
+ // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
+ peer.queue.inbound.c <- nil
+ peer.queue.outbound.c <- nil
+ peer.stopping.Wait()
+ peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
peer.ZeroAndFlushAll()
}
-var RoamingDisabled bool
-
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if RoamingDisabled {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
diff --git a/device/peer_test.go b/device/peer_test.go
deleted file mode 100644
index 6aa238b..0000000
--- a/device/peer_test.go
+++ /dev/null
@@ -1,43 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "reflect"
- "testing"
- "unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
- t.Helper()
- if offset%8 != 0 {
- t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
- }
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
- var p Peer
-
- typ := reflect.TypeOf(p)
- t.Logf("Peer type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
-
- checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
- checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
diff --git a/device/pools.go b/device/pools.go
index e778d2e..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,89 +1,120 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-import "sync"
+import (
+ "sync"
+ "sync/atomic"
+)
-func (device *Device) PopulatePools() {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool = &sync.Pool{
- New: func() interface{} {
- return new([MaxMessageSize]byte)
- },
- }
- device.pool.inboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueInboundElement)
- },
- }
- device.pool.outboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueOutboundElement)
- },
- }
- } else {
- device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
- }
- device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.inboundElementReuseChan <- new(QueueInboundElement)
- }
- device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+type WaitPool struct {
+ pool sync.Pool
+ cond sync.Cond
+ lock sync.Mutex
+ count atomic.Uint32
+ max uint32
+}
+
+func NewWaitPool(max uint32, new func() any) *WaitPool {
+ p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+ p.cond = sync.Cond{L: &p.lock}
+ return p
+}
+
+func (p *WaitPool) Get() any {
+ if p.max != 0 {
+ p.lock.Lock()
+ for p.count.Load() >= p.max {
+ p.cond.Wait()
}
+ p.count.Add(1)
+ p.lock.Unlock()
}
+ return p.pool.Get()
}
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
- } else {
- return <-device.pool.messageBufferReuseChan
+func (p *WaitPool) Put(x any) {
+ p.pool.Put(x)
+ if p.max == 0 {
+ return
}
+ p.count.Add(^uint32(0))
+ p.cond.Signal()
}
-func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool.Put(msg)
- } else {
- device.pool.messageBufferReuseChan <- msg
- }
+func (device *Device) PopulatePools() {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new([MaxMessageSize]byte)
+ })
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueInboundElement)
+ })
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueOutboundElement)
+ })
}
-func (device *Device) GetInboundElement() *QueueInboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.inboundElementPool.Get().(*QueueInboundElement)
- } else {
- return <-device.pool.inboundElementReuseChan
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
}
-func (device *Device) PutInboundElement(msg *QueueInboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.inboundElementPool.Put(msg)
- } else {
- device.pool.inboundElementReuseChan <- msg
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+ device.pool.messageBuffers.Put(msg)
+}
+
+func (device *Device) GetInboundElement() *QueueInboundElement {
+ return device.pool.inboundElements.Get().(*QueueInboundElement)
+}
+
+func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ elem.clearPointers()
+ device.pool.inboundElements.Put(elem)
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
- } else {
- return <-device.pool.outboundElementReuseChan
- }
+ return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
-func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.outboundElementPool.Put(msg)
- } else {
- device.pool.outboundElementReuseChan <- msg
- }
+func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
+ elem.clearPointers()
+ device.pool.outboundElements.Put(elem)
}
diff --git a/device/pools_test.go b/device/pools_test.go
new file mode 100644
index 0000000..82d7493
--- /dev/null
+++ b/device/pools_test.go
@@ -0,0 +1,139 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestWaitPool(t *testing.T) {
+ t.Skip("Currently disabled")
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ startTrials := int32(100000)
+ if raceEnabled {
+ // This test can be very slow with -race.
+ startTrials /= 10
+ }
+ trials.Store(startTrials)
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ t.Skip("Not enough cores")
+ }
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ var max atomic.Uint32
+ updateMax := func() {
+ count := p.count.Load()
+ if count > p.max {
+ t.Errorf("count (%d) > max (%d)", count, p.max)
+ }
+ for {
+ old := max.Load()
+ if count <= old {
+ break
+ }
+ if max.CompareAndSwap(old, count) {
+ break
+ }
+ }
+ }
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ updateMax()
+ x := p.Get()
+ updateMax()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ updateMax()
+ p.Put(x)
+ updateMax()
+ }
+ }()
+ }
+ wg.Wait()
+ if max.Load() != p.max {
+ t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+ }
+}
+
+func BenchmarkWaitPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index f19c7be..25f700a 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
- MaxSegmentSize = 2200
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index 18f0bea..ea763d0 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -1,13 +1,16 @@
-// +build !android,!ios
+//go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 4c83015..acd3cec 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -1,18 +1,21 @@
-// +build ios
+//go:build ios
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */
-
-const (
- QueueOutboundSize = 1024
- QueueInboundSize = 1024
- QueueHandshakeSize = 1024
- MaxSegmentSize = 1700
- PreallocatedBuffersPerPool = 1024
+// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
+// These are vars instead of consts, because heavier network extensions might want to reduce
+// them further.
+var (
+ QueueStagedSize = 128
+ QueueOutboundSize = 1024
+ QueueInboundSize = 1024
+ QueueHandshakeSize = 1024
+ PreallocatedBuffersPerPool uint32 = 1024
)
+
+const MaxSegmentSize = 1700
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
new file mode 100644
index 0000000..1eee32b
--- /dev/null
+++ b/device/queueconstants_windows.go
@@ -0,0 +1,15 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const (
+ QueueStagedSize = 128
+ QueueOutboundSize = 1024
+ QueueInboundSize = 1024
+ QueueHandshakeSize = 1024
+ MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
+ PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
+)
diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go
new file mode 100644
index 0000000..bb5c450
--- /dev/null
+++ b/device/race_disabled_test.go
@@ -0,0 +1,10 @@
+//go:build !race
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const raceEnabled = false
diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go
new file mode 100644
index 0000000..4e9daea
--- /dev/null
+++ b/device/race_enabled_test.go
@@ -0,0 +1,10 @@
+//go:build race
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+const raceEnabled = true
diff --git a/device/receive.go b/device/receive.go
index b53c9c0..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,10 +8,9 @@ package device
import (
"bytes"
"encoding/binary"
+ "errors"
"net"
- "strconv"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
@@ -28,8 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- dropped int32
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -37,38 +34,20 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
-func (elem *QueueInboundElement) Drop() {
- atomic.StoreInt32(&elem.dropped, AtomicTrue)
-}
-
-func (elem *QueueInboundElement) IsDropped() bool {
- return atomic.LoadInt32(&elem.dropped) == AtomicTrue
-}
-
-func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
- select {
- case inboundQueue <- element:
- select {
- case decryptionQueue <- element:
- return true
- default:
- element.Drop()
- element.Unlock()
- return false
- }
- default:
- device.PutInboundElement(element)
- return false
- }
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
}
-func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
- select {
- case queue <- element:
- return true
- default:
- return false
- }
+// clearPointers clears elem fields that contain pointers.
+// This makes the garbage collector's life easier and
+// avoids accidentally keeping other objects around unnecessarily.
+// It also reduces the possible collateral damage from use-after-free bugs.
+func (elem *QueueInboundElement) clearPointers() {
+ elem.buffer = nil
+ elem.packet = nil
+ elem.keypair = nil
+ elem.endpoint = nil
}
/* Called when a new authenticated message has been received
@@ -76,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
- if peer.timers.sentLastMinuteHandshake.Get() {
+ if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
- peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
@@ -91,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
-
- logDebug := device.log.Debug
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
+ recvName := recv.PrettyName()
defer func() {
- logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+ device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
+ device.queue.decryption.wg.Done()
+ device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
- logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
- device.net.starting.Done()
+ device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
- err error
- size int
- endpoint conn.Endpoint
+ bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ bufs = make([][]byte, maxBatchSize)
+ err error
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
+ deathSpiral int
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
- for {
-
- // read next datagram
+ for i := range bufsArrs {
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ }
- switch IP {
- case ipv4.Version:
- size, endpoint, err = bind.ReceiveIPv4(buffer[:])
- case ipv6.Version:
- size, endpoint, err = bind.ReceiveIPv6(buffer[:])
- default:
- panic("invalid IP version")
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if bufsArrs[i] != nil {
+ device.PutMessageBuffer(bufsArrs[i])
+ }
}
+ }()
+ for {
+ count, err = recv(bufs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
+ if errors.Is(err, net.ErrClosed) {
+ return
+ }
+ device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
+ if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
+ return
+ }
+ if deathSpiral < 10 {
+ deathSpiral++
+ time.Sleep(time.Second / 3)
+ continue
+ }
return
}
+ deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
-
- // check size of packet
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ // check size of packet
- var okay bool
+ packet := bufsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- switch msgType {
+ switch msgType {
- // check if transport
+ // check if transport
- case MessageTransportType:
+ case MessageTransportType:
- // check size
+ // check size
- if len(packet) < MessageTransportSize {
- continue
- }
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- // lookup key pair
+ // lookup key pair
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- // check keypair expiry
+ // check keypair expiry
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
- continue
- }
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.dropped = AtomicFalse
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
-
- // add to decryption queues
-
- if peer.isRunning.Get() {
- if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
- buffer = device.GetMessageBuffer()
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = bufsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
+ elemsByPeer[peer] = elemsForPeer
}
- }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ continue
- continue
+ // otherwise it is a fixed size & handshake related packet
- // otherwise it is a fixed size & handshake related packet
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- default:
- logDebug.Println("Received message with unknown type")
+ select {
+ case device.queue.handshake.c <- QueueHandshakeElement{
+ msgType: msgType,
+ buffer: bufsArrs[i],
+ packet: packet,
+ endpoint: endpoints[i],
+ }:
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ default:
+ }
}
-
- if okay {
- if (device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- endpoint: endpoint,
- },
- )) {
- buffer = device.GetMessageBuffer()
+ for peer, elemsContainer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
}
+ delete(elemsByPeer, peer)
}
}
}
-func (device *Device) RoutineDecryption() {
-
+func (device *Device) RoutineDecryption(id int) {
var nonce [chacha20poly1305.NonceSize]byte
- logDebug := device.log.Debug
- defer func() {
- logDebug.Println("Routine: decryption worker - stopped")
- device.state.stopping.Done()
- }()
- logDebug.Println("Routine: decryption worker - started")
- device.state.starting.Done()
-
- for {
- select {
- case <-device.signals.stop:
- return
-
- case elem, ok := <-device.queue.decryption:
-
- if !ok {
- return
- }
-
- // check if dropped
-
- if elem.IsDropped() {
- continue
- }
+ defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
+ device.log.Verbosef("Routine: decryption worker %d - started", id)
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
// split message into fields
-
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
- // expand nonce
-
- nonce[0x4] = counter[0x0]
- nonce[0x5] = counter[0x1]
- nonce[0x6] = counter[0x2]
- nonce[0x7] = counter[0x3]
-
- nonce[0x8] = counter[0x4]
- nonce[0x9] = counter[0x5]
- nonce[0xa] = counter[0x6]
- nonce[0xb] = counter[0x7]
-
// decrypt and release to consumer
-
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
@@ -280,51 +260,23 @@ func (device *Device) RoutineDecryption() {
nil,
)
if err != nil {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
+ elem.packet = nil
}
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
/* Handles incoming packets related to handshake
*/
-func (device *Device) RoutineHandshake() {
-
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
-
- var elem QueueHandshakeElement
- var ok bool
-
+func (device *Device) RoutineHandshake(id int) {
defer func() {
- logDebug.Println("Routine: handshake worker - stopped")
- device.state.stopping.Done()
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- }
+ device.log.Verbosef("Routine: handshake worker %d - stopped", id)
+ device.queue.encryption.wg.Done()
}()
+ device.log.Verbosef("Routine: handshake worker %d - started", id)
- logDebug.Println("Routine: handshake worker - started")
- device.state.starting.Done()
-
- for {
- if elem.buffer != nil {
- device.PutMessageBuffer(elem.buffer)
- elem.buffer = nil
- }
-
- select {
- case elem, ok = <-device.queue.handshake:
- case <-device.signals.stop:
- return
- }
-
- if !ok {
- return
- }
+ for elem := range device.queue.handshake.c {
// handle cookie fields and ratelimiting
@@ -338,8 +290,8 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
- logDebug.Println("Failed to decode cookie reply")
- return
+ device.log.Verbosef("Failed to decode cookie reply")
+ goto skip
}
// lookup peer from index
@@ -347,27 +299,27 @@ func (device *Device) RoutineHandshake() {
entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
- continue
+ goto skip
}
// consume reply
- if peer := entry.peer; peer.isRunning.Get() {
- logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
+ if peer := entry.peer; peer.isRunning.Load() {
+ device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
- logDebug.Println("Could not decrypt invalid cookie response")
+ device.log.Verbosef("Could not decrypt invalid cookie response")
}
}
- continue
+ goto skip
case MessageInitiationType, MessageResponseType:
// check mac fields and maybe ratelimit
if !device.cookieChecker.CheckMAC1(elem.packet) {
- logDebug.Println("Received packet with invalid mac1")
- continue
+ device.log.Verbosef("Received packet with invalid mac1")
+ goto skip
}
// endpoints destination address is the source of the datagram
@@ -378,19 +330,19 @@ func (device *Device) RoutineHandshake() {
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
- continue
+ goto skip
}
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
- continue
+ goto skip
}
}
default:
- logError.Println("Invalid packet ended up in the handshake queue")
- continue
+ device.log.Errorf("Invalid packet ended up in the handshake queue")
+ goto skip
}
// handle handshake initiation/response content
@@ -404,19 +356,16 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
- logError.Println("Failed to decode initiation message")
- continue
+ device.log.Errorf("Failed to decode initiation message")
+ goto skip
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
- logInfo.Println(
- "Received invalid initiation message from",
- elem.endpoint.DstToString(),
- )
- continue
+ device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
+ goto skip
}
// update timers
@@ -427,8 +376,8 @@ func (device *Device) RoutineHandshake() {
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
- logDebug.Println(peer, "- Received handshake initiation")
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ device.log.Verbosef("%v - Received handshake initiation", peer)
+ peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
@@ -440,26 +389,23 @@ func (device *Device) RoutineHandshake() {
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
- logError.Println("Failed to decode response message")
- continue
+ device.log.Errorf("Failed to decode response message")
+ goto skip
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
- logInfo.Println(
- "Received invalid response message from",
- elem.endpoint.DstToString(),
- )
- continue
+ device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
+ goto skip
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
- logDebug.Println(peer, "- Received handshake response")
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ device.log.Verbosef("%v - Received handshake response", peer)
+ peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
@@ -471,178 +417,124 @@ func (device *Device) RoutineHandshake() {
err = peer.BeginSymmetricSession()
if err != nil {
- logError.Println(peer, "- Failed to derive keypair:", err)
- continue
+ device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
+ goto skip
}
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
}
+ skip:
+ device.PutMessageBuffer(elem.buffer)
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
-
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
-
- var elem *QueueInboundElement
-
defer func() {
- logDebug.Println(peer, "- Routine: sequential receiver - stopped")
- peer.routines.stopping.Done()
- if elem != nil {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- }
- device.PutInboundElement(elem)
- }
+ device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
+ peer.stopping.Done()
}()
+ device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- logDebug.Println(peer, "- Routine: sequential receiver - started")
-
- peer.routines.starting.Done()
-
- for {
- if elem != nil {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- }
- device.PutInboundElement(elem)
- elem = nil
- }
+ bufs := make([][]byte, 0, maxBatchSize)
- var elemOk bool
- select {
- case <-peer.routines.stop:
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
- case elem, elemOk = <-peer.queue.inbound:
- if !elemOk {
- return
- }
- }
-
- // wait for decryption
-
- elem.Lock()
-
- if elem.IsDropped() {
- continue
- }
-
- // check for replay
-
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- continue
- }
-
- // update endpoint
- peer.SetEndpointFromPacket(elem.endpoint)
-
- // check if using new keypair
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
- }
-
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
-
- // check for keepalive
-
- if len(elem.packet) == 0 {
- logDebug.Println(peer, "- Receiving keepalive packet")
- continue
}
- peer.timersDataReceived()
-
- // verify source and strip padding
-
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
-
- // strip padding
-
- if len(elem.packet) < ipv4.HeaderLen {
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
+ if elem.packet == nil {
+ // decryption failed
continue
}
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
- elem.packet = elem.packet[:length]
-
- // verify IPv4 source
-
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.LookupIPv4(src) != peer {
- logInfo.Println(
- "IPv4 packet with disallowed source address from",
- peer,
- )
- continue
+ validTailPacket = i
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
}
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
- case ipv6.Version:
-
- // strip padding
-
- if len(elem.packet) < ipv6.HeaderLen {
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
+ dataPacketReceived = true
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- continue
- }
-
- elem.packet = elem.packet[:length]
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- // verify IPv6 source
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.LookupIPv6(src) != peer {
- logInfo.Println(
- "IPv6 packet with disallowed source address from",
- peer,
- )
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
continue
}
- default:
- logInfo.Println("Packet with invalid IP version from", peer)
- continue
+ bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
- // write to tun device
-
- offset := MessageTransportOffsetContent
- _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
- if len(peer.queue.inbound) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Error.Printf("Unable to flush packets: %v", err)
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ }
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
+ if len(bufs) > 0 {
+ _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- if err != nil && !device.isClosed.Get() {
- logError.Println("Failed to write packet to TUN device:", err)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
}
+ bufs = bufs[:0]
+ device.PutInboundElementsContainer(elemsContainer)
}
}
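
The rewritten receive path above groups decrypted-in-parallel packets per peer into a QueueInboundElementsContainer whose embedded mutex doubles as a completion signal: the producer locks the container before publishing it to both the decryption queue and the peer's sequential inbound queue, a decryption worker unlocks it once every element is processed, and the sequential receiver's Lock() blocks until that happens. The following standalone sketch shows only that handoff pattern; the names are illustrative, not the device package's.

// Minimal sketch of the lock-as-completion-signal pattern used above.
package main

import (
	"fmt"
	"sync"
)

type batch struct {
	sync.Mutex
	items []int
}

func main() {
	parallel := make(chan *batch, 4)   // like device.queue.decryption
	sequential := make(chan *batch, 4) // like peer.queue.inbound

	// Parallel worker: processes the items, then releases the lock.
	go func() {
		for b := range parallel {
			for i := range b.items {
				b.items[i] *= 2 // stand-in for decryption
			}
			b.Unlock()
		}
	}()

	// Producer: lock first, then publish to both queues.
	b := &batch{items: []int{1, 2, 3}}
	b.Lock()
	sequential <- b
	parallel <- b

	// Sequential consumer: Lock() blocks until the worker has finished,
	// so per-peer ordering is preserved while work runs in parallel.
	got := <-sequential
	got.Lock()
	fmt.Println(got.items) // [2 4 6]
}
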
diff --git a/device/send.go b/device/send.go
index c0bdba3..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,14 +8,17 @@ package device
import (
"bytes"
"encoding/binary"
+ "errors"
"net"
+ "os"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -43,8 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- dropped int32
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -52,80 +53,52 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
- elem.dropped = AtomicFalse
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
- elem.keypair = nil
- elem.peer = nil
+ // keypair and peer were cleared (if necessary) by clearPointers.
return elem
}
-func (elem *QueueOutboundElement) Drop() {
- atomic.StoreInt32(&elem.dropped, AtomicTrue)
-}
-
-func (elem *QueueOutboundElement) IsDropped() bool {
- return atomic.LoadInt32(&elem.dropped) == AtomicTrue
-}
-
-func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
- for {
- select {
- case queue <- element:
- return
- default:
- select {
- case old := <-queue:
- device.PutMessageBuffer(old.buffer)
- device.PutOutboundElement(old)
- default:
- }
- }
- }
+// clearPointers clears elem fields that contain pointers.
+// This makes the garbage collector's life easier and
+// avoids accidentally keeping other objects around unnecessarily.
+// It also reduces the possible collateral damage from use-after-free bugs.
+func (elem *QueueOutboundElement) clearPointers() {
+ elem.buffer = nil
+ elem.packet = nil
+ elem.keypair = nil
+ elem.peer = nil
}
-func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
- select {
- case outboundQueue <- element:
+/* Queues a keepalive if no packets are queued for peer
+ */
+func (peer *Peer) SendKeepalive() {
+ if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
+ elem := peer.device.NewOutboundElement()
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case encryptionQueue <- element:
- return
+ case peer.queue.staged <- elemsContainer:
+ peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
- element.Drop()
- element.peer.device.PutMessageBuffer(element.buffer)
- element.Unlock()
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
- default:
- element.peer.device.PutMessageBuffer(element.buffer)
- element.peer.device.PutOutboundElement(element)
- }
-}
-
-/* Queues a keepalive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepalive() bool {
- if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
- return false
- }
- elem := peer.device.NewOutboundElement()
- elem.packet = nil
- select {
- case peer.queue.nonce <- elem:
- peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
- return true
- default:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
- return false
}
+ peer.SendStagedPackets()
}
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
@@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
- peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
+ peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
+ peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
return err
}
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
+ peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
peer.timersHandshakeInitiated()
@@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
- peer.device.log.Debug.Println(peer, "- Sending handshake response")
+ peer.device.log.Verbosef("%v - Sending handshake response", peer)
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
+ peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
return err
}
- var buff [MessageResponseSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
err = peer.BeginSymmetricSession()
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
+ peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
return err
}
@@ -198,28 +171,29 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
- peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
+ peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
return err
}
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
-
- device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
+ device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
if err != nil {
- device.log.Error.Println("Failed to create cookie reply:", err)
+ device.log.Errorf("Failed to create cookie reply: %v", err)
return err
}
- var buff [MessageCookieReplySize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -228,222 +202,221 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil {
return
}
- nonce := atomic.LoadUint64(&keypair.sendNonce)
+ nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
-/* Reads packets from the TUN and inserts
- * into nonce queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
-
- logDebug := device.log.Debug
- logError := device.log.Error
-
defer func() {
- logDebug.Println("Routine: TUN reader - stopped")
+ device.log.Verbosef("Routine: TUN reader - stopped")
device.state.stopping.Done()
+ device.queue.encryption.wg.Done()
}()
- logDebug.Println("Routine: TUN reader - started")
- device.state.starting.Done()
-
- var elem *QueueOutboundElement
+ device.log.Verbosef("Routine: TUN reader - started")
+
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ bufs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
-
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
+ }()
- if err != nil {
- if !device.isClosed.Get() {
- logError.Println("Failed to read packet from TUN device:", err)
- device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(bufs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = bufs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.LookupIPv4(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.LookupIPv6(dst)
-
- default:
- logDebug.Println("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsContainer()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range elemsForPeer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- // insert into nonce/pre-handshake queue
-
- if peer.isRunning.Get() {
- if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
- peer.SendHandshakeInitiation(false)
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
}
- addToNonceQueue(peer.queue.nonce, elem, device)
- elem = nil
+ return
}
}
}
-func (peer *Peer) FlushNonceQueue() {
- select {
- case peer.signals.flushNonceQueue <- struct{}{}:
- default:
- }
-}
-
-/* Queues packets when there is no handshake.
- * Then assigns nonces to packets sequentially
- * and creates "work" structs for workers
- *
- * Obs. A single instance per peer
- */
-func (peer *Peer) RoutineNonce() {
- var keypair *Keypair
-
- device := peer.device
- logDebug := device.log.Debug
-
- flush := func() {
- for {
- select {
- case elem := <-peer.queue.nonce:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- default:
- return
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
+ for {
+ select {
+ case peer.queue.staged <- elems:
+ return
+ default:
+ }
+ select {
+ case tooOld := <-peer.queue.staged:
+ for _, elem := range tooOld.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
}
+ peer.device.PutOutboundElementsContainer(tooOld)
+ default:
}
}
+}
- defer func() {
- flush()
- logDebug.Println(peer, "- Routine: nonce worker - stopped")
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
- peer.routines.stopping.Done()
- }()
+func (peer *Peer) SendStagedPackets() {
+top:
+ if len(peer.queue.staged) == 0 || !peer.device.isUp() {
+ return
+ }
- peer.routines.starting.Done()
- logDebug.Println(peer, "- Routine: nonce worker - started")
+ keypair := peer.keypairs.Current()
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ peer.SendHandshakeInitiation(false)
+ return
+ }
for {
- NextPacket:
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
-
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case <-peer.routines.stop:
- return
-
- case <-peer.signals.flushNonceQueue:
- flush()
- goto NextPacket
-
- case elem, ok := <-peer.queue.nonce:
-
- if !ok {
- return
- }
-
- // make sure to always pick the newest key
-
- for {
-
- // check validity of newest key pair
-
- keypair = peer.keypairs.Current()
- if keypair != nil && keypair.sendNonce < RejectAfterMessages {
- if time.Since(keypair.created) < RejectAfterTime {
- break
+ case elemsContainer := <-peer.queue.staged:
+ i := 0
+ for _, elem := range elemsContainer.elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ continue
+ } else {
+ elemsContainer.elems[i] = elem
+ i++
}
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
- // no suitable key pair, request for new handshake
-
- select {
- case <-peer.signals.newKeypairArrived:
- default:
- }
-
- peer.SendHandshakeInitiation(false)
-
- // wait for key to be established
-
- logDebug.Println(peer, "- Awaiting keypair")
+ elem.keypair = keypair
+ }
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- select {
- case <-peer.signals.newKeypairArrived:
- logDebug.Println(peer, "- Obtained awaited keypair")
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ }
- case <-peer.signals.flushNonceQueue:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- flush()
- goto NextPacket
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ goto top
+ }
- case <-peer.routines.stop:
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
+ // add to parallel and sequential queue
+ if peer.isRunning.Load() {
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
}
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
- peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
-
- // populate work element
- elem.peer = peer
- elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
-
- // double check in case of race condition added by future code
-
- if elem.nonce >= RejectAfterMessages {
- atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- goto NextPacket
+ if elemsContainerOOO != nil {
+ goto top
}
+ default:
+ return
+ }
+ }
+}
- elem.keypair = keypair
- elem.dropped = AtomicFalse
- elem.Lock()
-
- // add to parallel and sequential queue
- addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
+func (peer *Peer) FlushStagedPackets() {
+ for {
+ select {
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ default:
+ return
}
}
}
@@ -468,55 +441,16 @@ func calculatePaddingSize(packetSize, mtu int) int {
*
* Obs. One instance per core
*/
-func (device *Device) RoutineEncryption() {
-
+func (device *Device) RoutineEncryption(id int) {
+ var paddingZeros [PaddingMultiple]byte
var nonce [chacha20poly1305.NonceSize]byte
- logDebug := device.log.Debug
-
- defer func() {
- for {
- select {
- case elem, ok := <-device.queue.encryption:
- if ok && !elem.IsDropped() {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- elem.Unlock()
- }
- default:
- goto out
- }
- }
- out:
- logDebug.Println("Routine: encryption worker - stopped")
- device.state.stopping.Done()
- }()
-
- logDebug.Println("Routine: encryption worker - started")
- device.state.starting.Done()
-
- for {
-
- // fetch next element
-
- select {
- case <-device.signals.stop:
- return
-
- case elem, ok := <-device.queue.encryption:
-
- if !ok {
- return
- }
-
- // check if dropped
-
- if elem.IsDropped() {
- continue
- }
+ defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
+ device.log.Verbosef("Routine: encryption worker %d - started", id)
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
// populate header fields
-
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
@@ -528,11 +462,8 @@ func (device *Device) RoutineEncryption() {
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
-
- paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
- for i := 0; i < paddingSize; i++ {
- elem.packet = append(elem.packet, 0)
- }
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
// encrypt content and release to consumer
@@ -543,82 +474,73 @@ func (device *Device) RoutineEncryption() {
elem.packet,
nil,
)
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
-
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
-
- logDebug := device.log.Debug
- logError := device.log.Error
-
defer func() {
- for {
- select {
- case elem, ok := <-peer.queue.outbound:
- if ok {
- if !elem.IsDropped() {
- device.PutMessageBuffer(elem.buffer)
- elem.Drop()
- }
- device.PutOutboundElement(elem)
- }
- default:
- goto out
- }
- }
- out:
- logDebug.Println(peer, "- Routine: sequential sender - stopped")
- peer.routines.stopping.Done()
+ defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
+ peer.stopping.Done()
}()
+ device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- logDebug.Println(peer, "- Routine: sequential sender - started")
-
- peer.routines.starting.Done()
-
- for {
- select {
+ bufs := make([][]byte, 0, maxBatchSize)
- case <-peer.routines.stop:
+ for elemsContainer := range peer.queue.outbound.c {
+ bufs = bufs[:0]
+ if elemsContainer == nil {
return
-
- case elem, ok := <-peer.queue.outbound:
-
- if !ok {
- return
- }
-
- elem.Lock()
- if elem.IsDropped() {
+ }
+ if !peer.isRunning.Load() {
+ // peer has been stopped; return re-usable elems to the shared pool.
+ // This is an optimization only. It is possible for the peer to be stopped
+ // immediately after this check, in which case, elem will get processed.
+ // The timers and SendBuffers code are resilient to a few stragglers.
+ // TODO: rework peer shutdown order to ensure
+ // that we never accidentally keep timers alive longer than necessary.
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
- continue
}
-
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketSent()
-
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
+ continue
+ }
+ dataSent := false
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
- peer.timersDataSent()
+ dataSent = true
}
+ bufs = append(bufs, elem.packet)
+ }
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketSent()
+
+ err := peer.SendBuffers(bufs)
+ if dataSent {
+ peer.timersDataSent()
+ }
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
- if err != nil {
- logError.Println(peer, "- Failed to send data packet", err)
- continue
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
}
-
- peer.keepKeyFreshSending()
}
+ if err != nil {
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
+ continue
+ }
+
+ peer.keepKeyFreshSending()
}
}
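
StagePackets above replaces the old nonce queue with a bounded staging channel that sheds load from the front: when the channel is full, one old batch is evicted (its buffers returned to the pools) and the send is retried, so the freshest traffic survives while a handshake is still pending. A reduced sketch of that overflow policy, using plain string slices in place of element containers:

// Sketch of the "drop oldest on overflow" behavior of StagePackets.
package main

import "fmt"

func stage(staged chan []string, batch []string) {
	for {
		select {
		case staged <- batch:
			return
		default:
		}
		select {
		case old := <-staged:
			fmt.Println("evicted:", old) // the real code returns buffers to the pool here
		default:
		}
	}
}

func main() {
	staged := make(chan []string, 2) // like QueueStagedSize, but tiny
	stage(staged, []string{"pkt1"})
	stage(staged, []string{"pkt2"})
	stage(staged, []string{"pkt3"}) // evicts pkt1, keeps pkt2 and pkt3
	fmt.Println(len(staged), "batches staged")
}
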
diff --git a/device/sticky_default.go b/device/sticky_default.go
index 56da4eb..1038256 100644
--- a/device/sticky_default.go
+++ b/device/sticky_default.go
@@ -1,4 +1,4 @@
-// +build !linux android
+//go:build !linux
package device
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index e3efc86..6057ff1 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,8 +1,6 @@
-// +build !android
-
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -21,11 +19,19 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
+ return nil, nil
+ }
+
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
@@ -49,6 +55,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
+ defer netlinkCancel.Close()
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
@@ -103,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -127,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
- peer.RUnlock()
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -165,26 +172,26 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
- uint32(bind.LastMark()),
+ device.net.fwmark,
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
@@ -200,13 +207,13 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
}
func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+ Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
diff --git a/device/timers.go b/device/timers.go
index 0232eef..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -8,16 +8,16 @@
package device
import (
- "math/rand"
"sync"
- "sync/atomic"
"time"
+ _ "unsafe"
)
-/* This Timer structure and related functions should roughly copy the interface of
- * the Linux kernel's struct timer_list.
- */
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+// A Timer manages time-based aspects of the WireGuard protocol.
+// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
*time.Timer
modifyingLock sync.RWMutex
@@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{}
timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock()
+ defer timer.runningLock.Unlock()
timer.modifyingLock.Lock()
if !timer.isPending {
timer.modifyingLock.Unlock()
- timer.runningLock.Unlock()
return
}
timer.isPending = false
timer.modifyingLock.Unlock()
expirationFunction(peer)
- timer.runningLock.Unlock()
})
timer.Stop()
return timer
@@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
+ return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
- if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
- peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
+ if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
@@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) {
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
- peer.FlushNonceQueue()
+ peer.FlushStagedPackets()
/* We set a timer for destroying any residue that might be left
* of a partial exchange.
@@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
- atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
- peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+ peer.timers.handshakeAttempts.Add(1)
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
- if peer.timers.needAnotherKeepalive.Get() {
- peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timers.needAnotherKeepalive.Load() {
+ peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
@@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) {
}
func expiredNewHandshake(peer *Peer) {
- peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
+ peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
-
}
func expiredZeroKeyMaterial(peer *Peer) {
- peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
+ peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll()
}
func expiredPersistentKeepalive(peer *Peer) {
- if peer.persistentKeepaliveInterval > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
- peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
- peer.timers.needAnotherKeepalive.Set(true)
+ peer.timers.needAnotherKeepalive.Store(true)
}
}
}
@@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
- peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
- peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+ keepalive := peer.persistentKeepaliveInterval.Load()
+ if keepalive > 0 && peer.timersActive() {
+ peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
}
@@ -213,9 +204,12 @@ func (peer *Peer) timersInit() {
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- peer.timers.needAnotherKeepalive.Set(false)
+}
+
+func (peer *Peer) timersStart() {
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.timers.needAnotherKeepalive.Store(false)
}
func (peer *Peer) timersStop() {
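
The timer hunks above swap math/rand for runtime.fastrandn when jittering handshake retransmission, but the computed delay is unchanged: the base timeout plus a uniform jitter of up to RekeyTimeoutJitterMaxMs milliseconds. A standalone sketch, using math/rand/v2 instead of the go:linkname trick and assuming the values from constants.go (RekeyTimeout 5s, jitter cap 334 ms):

package main

import (
	"fmt"
	"math/rand/v2"
	"time"
)

const (
	RekeyTimeout            = time.Second * 5
	RekeyTimeoutJitterMaxMs = 334
)

// jitteredRetransmit mirrors the delay passed to retransmitHandshake.Mod above.
func jitteredRetransmit() time.Duration {
	return RekeyTimeout + time.Millisecond*time.Duration(rand.N(uint32(RekeyTimeoutJitterMaxMs)))
}

func main() {
	for i := 0; i < 3; i++ {
		fmt.Println(jitteredRetransmit()) // somewhere in [5s, 5.334s)
	}
}
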
diff --git a/device/tun.go b/device/tun.go
index 1f88f33..2a2ace9 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,12 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "sync/atomic"
+ "fmt"
"golang.zx2c4.com/wireguard/tun"
)
@@ -14,43 +14,40 @@ import (
const DefaultMTU = 1420
func (device *Device) RoutineTUNEventReader() {
- setUp := false
- logDebug := device.log.Debug
- logInfo := device.log.Info
- logError := device.log.Error
-
- logDebug.Println("Routine: event worker - started")
- device.state.starting.Done()
+ device.log.Verbosef("Routine: event worker - started")
for event := range device.tun.device.Events() {
if event&tun.EventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU()
- old := atomic.LoadInt32(&device.tun.mtu)
if err != nil {
- logError.Println("Failed to load updated MTU of device:", err)
- } else if int(old) != mtu {
- if mtu+MessageTransportSize > MaxMessageSize {
- logInfo.Println("MTU updated:", mtu, "(too large)")
- } else {
- logInfo.Println("MTU updated:", mtu)
- }
- atomic.StoreInt32(&device.tun.mtu, int32(mtu))
+ device.log.Errorf("Failed to load updated MTU of device: %v", err)
+ continue
+ }
+ if mtu < 0 {
+ device.log.Errorf("MTU not updated to negative value: %v", mtu)
+ continue
+ }
+ var tooLarge string
+ if mtu > MaxContentSize {
+ tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
+ mtu = MaxContentSize
+ }
+ old := device.tun.mtu.Swap(int32(mtu))
+ if int(old) != mtu {
+ device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
}
- if event&tun.EventUp != 0 && !setUp {
- logInfo.Println("Interface set up")
- setUp = true
+ if event&tun.EventUp != 0 {
+ device.log.Verbosef("Interface up requested")
device.Up()
}
- if event&tun.EventDown != 0 && setUp {
- logInfo.Println("Interface set down")
- setUp = false
+ if event&tun.EventDown != 0 {
+ device.log.Verbosef("Interface down requested")
device.Down()
}
}
- logDebug.Println("Routine: event worker - stopped")
- device.state.stopping.Done()
+ device.log.Verbosef("Routine: event worker - stopped")
}
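
The rewritten MTU handling above rejects negative values and caps anything larger than MaxContentSize before atomically swapping it into device.tun.mtu. A minimal sketch of that clamp-then-swap step follows; maxContent is a placeholder standing in for the package's MaxContentSize constant, not its real value.

package main

import (
	"fmt"
	"sync/atomic"
)

const maxContent = 65455 // placeholder only; not the real MaxContentSize

// clampMTU caps mtu at maxContent and swaps it into the stored value,
// returning the previous value and whether capping occurred.
func clampMTU(stored *atomic.Int32, mtu int) (old int32, capped bool) {
	if mtu > maxContent {
		mtu = maxContent
		capped = true
	}
	return stored.Swap(int32(mtu)), capped
}

func main() {
	var mtu atomic.Int32
	mtu.Store(1420)
	old, capped := clampMTU(&mtu, 9000)
	fmt.Println(old, mtu.Load(), capped)
}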
diff --git a/device/tun_test.go b/device/tun_test.go
deleted file mode 100644
index a2db2a5..0000000
--- a/device/tun_test.go
+++ /dev/null
@@ -1,56 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "errors"
- "os"
-
- "golang.zx2c4.com/wireguard/tun"
-)
-
-// newDummyTUN creates a dummy TUN device with the specified name.
-func newDummyTUN(name string) tun.Device {
- return &dummyTUN{
- name: name,
- packets: make(chan []byte, 100),
- events: make(chan tun.Event, 10),
- }
-}
-
-// A dummyTUN is a tun.Device which is used in unit tests.
-type dummyTUN struct {
- name string
- mtu int
- packets chan []byte
- events chan tun.Event
-}
-
-func (d *dummyTUN) Events() chan tun.Event { return d.events }
-func (*dummyTUN) File() *os.File { return nil }
-func (*dummyTUN) Flush() error { return nil }
-func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
-func (d *dummyTUN) Name() (string, error) { return d.name, nil }
-
-func (d *dummyTUN) Close() error {
- close(d.events)
- close(d.packets)
- return nil
-}
-
-func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
- buf, ok := <-d.packets
- if !ok {
- return 0, errors.New("device closed")
- }
- copy(b[offset:], buf)
- return len(buf), nil
-}
-
-func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
- d.packets <- b[offset:]
- return len(b), nil
-}
diff --git a/device/uapi.go b/device/uapi.go
index 9f9c9bd..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,45 +1,77 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
+ "bytes"
"errors"
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
- "sync/atomic"
+ "sync"
"time"
- "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc"
)
type IPCError struct {
- int64
+ code int64 // error code
+ err error // underlying/wrapped error
}
func (s IPCError) Error() string {
- return fmt.Sprintf("IPC error: %d", s.int64)
+ return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
+}
+
+func (s IPCError) Unwrap() error {
+ return s.err
}
func (s IPCError) ErrorCode() int64 {
- return s.int64
+ return s.code
+}
+
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
+ return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
-func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
- lines := make([]string, 0, 100)
- send := func(line string) {
- lines = append(lines, line)
+var byteBufferPool = &sync.Pool{
+ New: func() any { return new(bytes.Buffer) },
+}
+
+// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcGetOperation(w io.Writer) error {
+ device.ipcMutex.RLock()
+ defer device.ipcMutex.RUnlock()
+
+ buf := byteBufferPool.Get().(*bytes.Buffer)
+ buf.Reset()
+ defer byteBufferPool.Put(buf)
+ sendf := func(format string, args ...any) {
+ fmt.Fprintf(buf, format, args...)
+ buf.WriteByte('\n')
+ }
+ keyf := func(prefix string, key *[32]byte) {
+ buf.Grow(len(key)*2 + 2 + len(prefix))
+ buf.WriteString(prefix)
+ buf.WriteByte('=')
+ const hex = "0123456789abcdef"
+ for i := 0; i < len(key); i++ {
+ buf.WriteByte(hex[key[i]>>4])
+ buf.WriteByte(hex[key[i]&0xf])
+ }
+ buf.WriteByte('\n')
}
func() {
-
// lock required resources
device.net.RLock()
@@ -54,353 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
- send("private_key=" + device.staticIdentity.privateKey.ToHex())
+ keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
}
if device.net.port != 0 {
- send(fmt.Sprintf("listen_port=%d", device.net.port))
+ sendf("listen_port=%d", device.net.port)
}
if device.net.fwmark != 0 {
- send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
- send("public_key=" + peer.handshake.remoteStatic.ToHex())
- send("preshared_key=" + peer.handshake.presharedKey.ToHex())
- send("protocol_version=1")
- if peer.endpoint != nil {
- send("endpoint=" + peer.endpoint.DstToString())
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
+ keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
+ keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
+ sendf("protocol_version=1")
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
- send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
- send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
- send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
- send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
- send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
-
- for _, ip := range device.allowedips.EntriesForPeer(peer) {
- send("allowed_ip=" + ip.String())
- }
+ sendf("last_handshake_time_sec=%d", secs)
+ sendf("last_handshake_time_nsec=%d", nano)
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
+ return true
+ })
}
}()
// send lines (does not require resource locks)
-
- for _, line := range lines {
- _, err := socket.WriteString(line + "\n")
- if err != nil {
- return &IPCError{ipc.IpcErrorIO}
- }
+ if _, err := w.Write(buf.Bytes()); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
}
return nil
}
-func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
- scanner := bufio.NewScanner(socket)
- logError := device.log.Error
- logDebug := device.log.Debug
+// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
+// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
+func (device *Device) IpcSetOperation(r io.Reader) (err error) {
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
- var peer *Peer
+ defer func() {
+ if err != nil {
+ device.log.Errorf("%v", err)
+ }
+ }()
- dummy := false
- createdNewPeer := false
+ peer := new(ipcSetPeer)
deviceConfig := true
+ scanner := bufio.NewScanner(r)
for scanner.Scan() {
-
- // parse line
-
line := scanner.Text()
if line == "" {
+ // Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return &IPCError{ipc.IpcErrorProtocol}
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
-
- /* device configuration */
-
- if deviceConfig {
-
- switch key {
- case "private_key":
- var sk NoisePrivateKey
- err := sk.FromMaybeZeroHex(value)
- if err != nil {
- logError.Println("Failed to set private_key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Updating private key")
- device.SetPrivateKey(sk)
-
- case "listen_port":
-
- // parse port number
-
- port, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to parse listen_port:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // update port and rebind
-
- logDebug.Println("UAPI: Updating listen port")
-
- device.net.Lock()
- device.net.port = uint16(port)
- device.net.Unlock()
-
- if err := device.BindUpdate(); err != nil {
- logError.Println("Failed to set listen_port:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
- case "fwmark":
-
- // parse fwmark field
-
- fwmark, err := func() (uint32, error) {
- if value == "" {
- return 0, nil
- }
- mark, err := strconv.ParseUint(value, 10, 32)
- return uint32(mark), err
- }()
-
- if err != nil {
- logError.Println("Invalid fwmark", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- logDebug.Println("UAPI: Updating fwmark")
-
- if err := device.BindSetMark(uint32(fwmark)); err != nil {
- logError.Println("Failed to update fwmark:", err)
- return &IPCError{ipc.IpcErrorPortInUse}
- }
-
- case "public_key":
- // switch to peer configuration
- logDebug.Println("UAPI: Transition to peer configuration")
+ if key == "public_key" {
+ if deviceConfig {
deviceConfig = false
-
- case "replace_peers":
- if value != "true" {
- logError.Println("Failed to set replace_peers, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- logDebug.Println("UAPI: Removing all peers")
- device.RemoveAllPeers()
-
- default:
- logError.Println("Invalid UAPI device key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
}
+ peer.handlePostConfig()
+ // Load/create the peer we are now configuring.
+ err := device.handlePublicKeyLine(peer, value)
+ if err != nil {
+ return err
+ }
+ continue
}
- /* peer configuration */
-
- if !deviceConfig {
-
- switch key {
-
- case "public_key":
- var publicKey NoisePublicKey
- err := publicKey.FromHex(value)
- if err != nil {
- logError.Println("Failed to get peer by public key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- // ignore peer with public key of device
-
- device.staticIdentity.RLock()
- dummy = device.staticIdentity.publicKey.Equals(publicKey)
- device.staticIdentity.RUnlock()
-
- if dummy {
- peer = &Peer{}
- } else {
- peer = device.LookupPeer(publicKey)
- }
+ var err error
+ if deviceConfig {
+ err = device.handleDeviceLine(key, value)
+ } else {
+ err = device.handlePeerLine(peer, key, value)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ peer.handlePostConfig()
- createdNewPeer = peer == nil
- if createdNewPeer {
- peer, err = device.NewPeer(publicKey)
- if err != nil {
- logError.Println("Failed to create new peer:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if peer == nil {
- dummy = true
- peer = &Peer{}
- } else {
- logDebug.Println(peer, "- UAPI: Created")
- }
- }
+ if err := scanner.Err(); err != nil {
+ return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
+ }
+ return nil
+}
- case "update_only":
+func (device *Device) handleDeviceLine(key, value string) error {
+ switch key {
+ case "private_key":
+ var sk NoisePrivateKey
+ err := sk.FromMaybeZeroHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
+ }
+ device.log.Verbosef("UAPI: Updating private key")
+ device.SetPrivateKey(sk)
- // allow disabling of creation
+ case "listen_port":
+ port, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set update only, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if createdNewPeer && !dummy {
- device.RemovePeer(peer.handshake.remoteStatic)
- peer = &Peer{}
- dummy = true
- }
+ // update port and rebind
+ device.log.Verbosef("UAPI: Updating listen port")
- case "remove":
+ device.net.Lock()
+ device.net.port = uint16(port)
+ device.net.Unlock()
- // remove currently selected peer from device
+ if err := device.BindUpdate(); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
+ }
- if value != "true" {
- logError.Println("Failed to set remove, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
- if !dummy {
- logDebug.Println(peer, "- UAPI: Removing")
- device.RemovePeer(peer.handshake.remoteStatic)
- }
- peer = &Peer{}
- dummy = true
+ case "fwmark":
+ mark, err := strconv.ParseUint(value, 10, 32)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
+ }
- case "preshared_key":
+ device.log.Verbosef("UAPI: Updating fwmark")
+ if err := device.BindSetMark(uint32(mark)); err != nil {
+ return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
+ }
- // update PSK
+ case "replace_peers":
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
+ }
+ device.log.Verbosef("UAPI: Removing all peers")
+ device.RemoveAllPeers()
- logDebug.Println(peer, "- UAPI: Updating preshared key")
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
+ }
- peer.handshake.mutex.Lock()
- err := peer.handshake.presharedKey.FromHex(value)
- peer.handshake.mutex.Unlock()
+ return nil
+}
- if err != nil {
- logError.Println("Failed to set preshared key:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+// An ipcSetPeer is the current state of an IPC set operation on a peer.
+type ipcSetPeer struct {
+ *Peer // Peer is the current peer being operated on
+ dummy bool // dummy reports whether this peer is a temporary, placeholder peer
+ created bool // created reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had persistent keepalive turned on
+}
- case "endpoint":
+func (peer *ipcSetPeer) handlePostConfig() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
+ peer.SendStagedPackets()
+ }
+}
- // set endpoint destination
+func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error {
+ // Load/create the peer we are configuring.
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
+ }
- logDebug.Println(peer, "- UAPI: Updating endpoint")
+ // Ignore peer with the same public key as this device.
+ device.staticIdentity.RLock()
+ peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
+ device.staticIdentity.RUnlock()
- err := func() error {
- peer.Lock()
- defer peer.Unlock()
- endpoint, err := conn.CreateEndpoint(value)
- if err != nil {
- return err
- }
- peer.endpoint = endpoint
- return nil
- }()
+ if peer.dummy {
+ peer.Peer = &Peer{}
+ } else {
+ peer.Peer = device.LookupPeer(publicKey)
+ }
- if err != nil {
- logError.Println("Failed to set endpoint:", err, ":", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- case "persistent_keepalive_interval":
-
- // update persistent keepalive interval
-
- logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
-
- secs, err := strconv.ParseUint(value, 10, 16)
- if err != nil {
- logError.Println("Failed to set persistent keepalive interval:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
-
- old := peer.persistentKeepaliveInterval
- peer.persistentKeepaliveInterval = uint16(secs)
-
- // send immediate keepalive if we're turning it on and before it wasn't on
-
- if old == 0 && secs != 0 {
- if err != nil {
- logError.Println("Failed to get tun device status:", err)
- return &IPCError{ipc.IpcErrorIO}
- }
- if device.isUp.Get() && !dummy {
- peer.SendKeepalive()
- }
- }
+ peer.created = peer.Peer == nil
+ if peer.created {
+ peer.Peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
+ }
+ device.log.Verbosef("%v - UAPI: Created", peer.Peer)
+ }
+ return nil
+}
- case "replace_allowed_ips":
+func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error {
+ switch key {
+ case "update_only":
+ // allow disabling of creation
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
+ }
+ if peer.created && !peer.dummy {
+ device.RemovePeer(peer.handshake.remoteStatic)
+ peer.Peer = &Peer{}
+ peer.dummy = true
+ }
- logDebug.Println(peer, "- UAPI: Removing all allowedips")
+ case "remove":
+ // remove currently selected peer from device
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
+ }
+ if !peer.dummy {
+ device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
+ device.RemovePeer(peer.handshake.remoteStatic)
+ }
+ peer.Peer = &Peer{}
+ peer.dummy = true
- if value != "true" {
- logError.Println("Failed to replace allowedips, invalid value:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "preshared_key":
+ device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
- if dummy {
- continue
- }
+ peer.handshake.mutex.Lock()
+ err := peer.handshake.presharedKey.FromHex(value)
+ peer.handshake.mutex.Unlock()
- device.allowedips.RemoveByPeer(peer)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
+ }
- case "allowed_ip":
+ case "endpoint":
+ device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
+ endpoint, err := device.net.bind.ParseEndpoint(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
+ }
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
- logDebug.Println(peer, "- UAPI: Adding allowedip")
+ case "persistent_keepalive_interval":
+ device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
- _, network, err := net.ParseCIDR(value)
- if err != nil {
- logError.Println("Failed to set allowed ip:", err)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ secs, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
+ }
- if dummy {
- continue
- }
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint(ones), peer)
+ // Send an immediate keepalive if we're turning it on and it wasn't on before.
+ peer.pkaOn = old == 0 && secs != 0
- case "protocol_version":
+ case "replace_allowed_ips":
+ device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
+ if value != "true" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.RemoveByPeer(peer.Peer)
- if value != "1" {
- logError.Println("Invalid protocol version:", value)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "allowed_ip":
+ device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
+ prefix, err := netip.ParsePrefix(value)
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
+ }
+ if peer.dummy {
+ return nil
+ }
+ device.allowedips.Insert(prefix, peer.Peer)
- default:
- logError.Println("Invalid UAPI peer key:", key)
- return &IPCError{ipc.IpcErrorInvalid}
- }
+ case "protocol_version":
+ if value != "1" {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
}
+
+ default:
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
}
return nil
}
-func (device *Device) IpcHandle(socket net.Conn) {
+func (device *Device) IpcGet() (string, error) {
+ buf := new(strings.Builder)
+ if err := device.IpcGetOperation(buf); err != nil {
+ return "", err
+ }
+ return buf.String(), nil
+}
- // create buffered read/writer
+func (device *Device) IpcSet(uapiConf string) error {
+ return device.IpcSetOperation(strings.NewReader(uapiConf))
+}
+func (device *Device) IpcHandle(socket net.Conn) {
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -409,45 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer)
}(socket)
- defer buffered.Flush()
-
- op, err := buffered.ReadString('\n')
- if err != nil {
- return
- }
-
- // handle operation
-
- var status *IPCError
+ for {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
- switch op {
- case "set=1\n":
- err = device.IpcSetOperation(buffered.Reader)
- if err != nil && !errors.As(err, &status) {
- // should never happen
- device.log.Error.Println("Invalid UAPI error:", err)
- status = &IPCError{1}
+ // handle operation
+ switch op {
+ case "set=1\n":
+ err = device.IpcSetOperation(buffered.Reader)
+ case "get=1\n":
+ var nextByte byte
+ nextByte, err = buffered.ReadByte()
+ if err != nil {
+ return
+ }
+ if nextByte != '\n' {
+ err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
+ break
+ }
+ err = device.IpcGetOperation(buffered.Writer)
+ default:
+ device.log.Errorf("invalid UAPI operation: %v", op)
+ return
}
- case "get=1\n":
- err = device.IpcGetOperation(buffered.Writer)
+ // write status
+ var status *IPCError
if err != nil && !errors.As(err, &status) {
- // should never happen
- device.log.Error.Println("Invalid UAPI error:", err)
- status = &IPCError{1}
+ // shouldn't happen
+ status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
}
-
- default:
- device.log.Error.Println("Invalid UAPI operation:", op)
- return
- }
-
- // write status
-
- if status != nil {
- device.log.Error.Println(status)
- fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
+ if status != nil {
+ device.log.Errorf("%v", status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ buffered.Flush()
}
}
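
The uapi.go changes above also add two string-based convenience wrappers, IpcGet() and IpcSet(uapiConf string), around the get/set operations. The sketch below shows how a caller inside the package might drive them with the cross-platform configuration protocol's key=value lines; it assumes dev is an already-constructed *Device, and the hex placeholders must be replaced with real 64-character keys before the config would be accepted.

// exampleConfigure is an illustrative sketch, not part of this commit.
func exampleConfigure(dev *Device) error {
	conf := "private_key=<64 hex chars>\n" + // placeholder, must be a real key
		"listen_port=51820\n" +
		"replace_peers=true\n" +
		"public_key=<64 hex chars>\n" + // placeholder, must be a real key
		"allowed_ip=10.0.0.2/32\n" +
		"persistent_keepalive_interval=25\n"
	if err := dev.IpcSet(conf); err != nil {
		return err
	}
	// IpcGet returns the full device state in the same key=value format.
	state, err := dev.IpcGet()
	if err != nil {
		return err
	}
	_ = state
	return nil
}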
diff --git a/device/version.go b/device/version.go
deleted file mode 100644
index 0877595..0000000
--- a/device/version.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package device
-
-const WireGuardGoVersion = "0.0.20200320"