Diffstat (limited to 'device')
-rw-r--r--  device/alignment_test.go          |  65
-rw-r--r--  device/allowedips.go              |  89
-rw-r--r--  device/allowedips_rand_test.go    |   7
-rw-r--r--  device/allowedips_test.go         |  10
-rw-r--r--  device/bind_test.go               |  12
-rw-r--r--  device/channels.go                |  40
-rw-r--r--  device/constants.go               |   2
-rw-r--r--  device/cookie.go                  |   7
-rw-r--r--  device/cookie_test.go             |   7
-rw-r--r--  device/device.go                  |  90
-rw-r--r--  device/device_test.go             |  85
-rw-r--r--  device/endpoint_test.go           |  40
-rw-r--r--  device/indextable.go              |   2
-rw-r--r--  device/ip.go                      |   2
-rw-r--r--  device/kdf_test.go                |   4
-rw-r--r--  device/keypair.go                 |  15
-rw-r--r--  device/logger.go                  |  10
-rw-r--r--  device/misc.go                    |  41
-rw-r--r--  device/mobilequirks.go            |  11
-rw-r--r--  device/noise-helpers.go           |  12
-rw-r--r--  device/noise-protocol.go          |  86
-rw-r--r--  device/noise-types.go             |   2
-rw-r--r--  device/noise_test.go              |  12
-rw-r--r--  device/peer.go                    | 134
-rw-r--r--  device/pools.go                   |  58
-rw-r--r--  device/pools_test.go              |  77
-rw-r--r--  device/queueconstants_android.go  |   8
-rw-r--r--  device/queueconstants_default.go  |   8
-rw-r--r--  device/queueconstants_ios.go      |   4
-rw-r--r--  device/queueconstants_windows.go  |   2
-rw-r--r--  device/race_disabled_test.go      |   4
-rw-r--r--  device/race_enabled_test.go       |   4
-rw-r--r--  device/receive.go                 | 385
-rw-r--r--  device/send.go                    | 355
-rw-r--r--  device/sticky_default.go          |   2
-rw-r--r--  device/sticky_linux.go            |  45
-rw-r--r--  device/timers.go                  |  55
-rw-r--r--  device/tun.go                     |   5
-rw-r--r--  device/uapi.go                    |  82
39 files changed, 1044 insertions(+), 835 deletions(-)
diff --git a/device/alignment_test.go b/device/alignment_test.go
deleted file mode 100644
index a918112..0000000
--- a/device/alignment_test.go
+++ /dev/null
@@ -1,65 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "reflect"
- "testing"
- "unsafe"
-)
-
-func checkAlignment(t *testing.T, name string, offset uintptr) {
- t.Helper()
- if offset%8 != 0 {
- t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
- }
-}
-
-// TestPeerAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestPeerAlignment(t *testing.T) {
- var p Peer
-
- typ := reflect.TypeOf(&p).Elem()
- t.Logf("Peer type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
-
- checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
- checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
-}
-
-// TestDeviceAlignment checks that atomically-accessed fields are
-// aligned to 64-bit boundaries, as required by the atomic package.
-//
-// Unfortunately, violating this rule on 32-bit platforms results in a
-// hard segfault at runtime.
-func TestDeviceAlignment(t *testing.T) {
- var d Device
-
- typ := reflect.TypeOf(&d).Elem()
- t.Logf("Device type size: %d, with fields:", typ.Size())
- for i := 0; i < typ.NumField(); i++ {
- field := typ.Field(i)
- t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
- field.Name,
- field.Offset,
- field.Type.Size(),
- field.Type.Align(),
- )
- }
- checkAlignment(t, "Device.rate.underLoadUntil", unsafe.Offsetof(d.rate)+unsafe.Offsetof(d.rate.underLoadUntil))
-}
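The alignment tests are deleted because the tree moves to Go 1.19's sync/atomic wrapper types, which the runtime keeps 64-bit aligned even on 32-bit platforms, so there is no longer anything for reflection to check. A minimal standalone sketch of the difference (struct names are illustrative, not from this repo):

package main

import "sync/atomic"

// Old style: a bare uint64 used with atomic.AddUint64 had to be placed
// early in the struct so 32-bit builds kept it 8-byte aligned; the deleted
// tests enforced that ordering via reflection.
type statsOld struct {
    txBytes uint64 // atomic; had to stay 64-bit aligned by hand
}

// New style: atomic.Uint64 carries its own alignment guarantee, so field
// order no longer matters and the tests can go away.
type statsNew struct {
    name    string
    txBytes atomic.Uint64 // safe anywhere in the struct
}

func main() {
    var s statsNew
    s.txBytes.Add(1500)
    _ = s.txBytes.Load()
}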
diff --git a/device/allowedips.go b/device/allowedips.go
index c08399b..fa46f97 100644
--- a/device/allowedips.go
+++ b/device/allowedips.go
@@ -1,15 +1,17 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"container/list"
+ "encoding/binary"
"errors"
"math/bits"
"net"
+ "net/netip"
"sync"
"unsafe"
)
@@ -26,49 +28,28 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
- bits net.IP
+ bits []byte
perPeerElem *list.Element
}
-func isLittleEndian() bool {
- one := uint32(1)
- return *(*byte)(unsafe.Pointer(&one)) != 0
-}
-
-func swapU32(i uint32) uint32 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes32(i)
-}
-
-func swapU64(i uint64) uint64 {
- if !isLittleEndian() {
- return i
- }
-
- return bits.ReverseBytes64(i)
-}
-
-func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
+func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
- a := (*uint32)(unsafe.Pointer(&ip1[0]))
- b := (*uint32)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
- return uint8(bits.LeadingZeros32(swapU32(x)))
+ a := binary.BigEndian.Uint32(ip1)
+ b := binary.BigEndian.Uint32(ip2)
+ x := a ^ b
+ return uint8(bits.LeadingZeros32(x))
} else if size == net.IPv6len {
- a := (*uint64)(unsafe.Pointer(&ip1[0]))
- b := (*uint64)(unsafe.Pointer(&ip2[0]))
- x := *a ^ *b
+ a := binary.BigEndian.Uint64(ip1)
+ b := binary.BigEndian.Uint64(ip2)
+ x := a ^ b
if x != 0 {
- return uint8(bits.LeadingZeros64(swapU64(x)))
+ return uint8(bits.LeadingZeros64(x))
}
- a = (*uint64)(unsafe.Pointer(&ip1[8]))
- b = (*uint64)(unsafe.Pointer(&ip2[8]))
- x = *a ^ *b
- return 64 + uint8(bits.LeadingZeros64(swapU64(x)))
+ a = binary.BigEndian.Uint64(ip1[8:])
+ b = binary.BigEndian.Uint64(ip2[8:])
+ x = a ^ b
+ return 64 + uint8(bits.LeadingZeros64(x))
} else {
panic("Wrong size bit string")
}
@@ -85,7 +66,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
-func (node *trieEntry) choose(ip net.IP) byte {
+func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +85,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
-func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
+func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +98,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
-func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
+func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +188,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
-func (node *trieEntry) lookup(ip net.IP) *Peer {
+func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +210,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
-func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
+func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
- if !cb(node.bits, node.cidr) {
+ a, _ := netip.AddrFromSlice(node.bits)
+ if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +265,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
-func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
+func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
- switch len(ip) {
- case net.IPv6len:
- parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
- case net.IPv4len:
- parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
- default:
+ if prefix.Addr().Is6() {
+ ip := prefix.Addr().As16()
+ parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else if prefix.Addr().Is4() {
+ ip := prefix.Addr().As4()
+ parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
+ } else {
panic(errors.New("inserting unknown address type"))
}
}
-func (table *AllowedIPs) Lookup(address []byte) *Peer {
+func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
- switch len(address) {
+ switch len(ip) {
case net.IPv6len:
- return table.IPv6.lookup(address)
+ return table.IPv6.lookup(ip)
case net.IPv4len:
- return table.IPv4.lookup(address)
+ return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}
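With the conversion above, AllowedIPs takes a netip.Prefix on insert while lookup keeps operating on the raw 4- or 16-byte slice pulled from a packet header, and commonBits now uses encoding/binary instead of unsafe, endianness-dependent loads. A hedged usage sketch (exampleAllowedIPs and the addresses are illustrative; the *Peer is assumed to come from Device.NewPeer):

package device

import "net/netip"

// exampleAllowedIPs routes a prefix to a peer, then resolves an address
// from a packet back to that peer using the new netip-based API.
func exampleAllowedIPs(table *AllowedIPs, peer *Peer) *Peer {
    table.Insert(netip.MustParsePrefix("192.168.4.0/24"), peer)
    addr := netip.MustParseAddr("192.168.4.10").As4()
    return table.Lookup(addr[:]) // returns peer
}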
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
index 16de170..07065c3 100644
--- a/device/allowedips_rand_test.go
+++ b/device/allowedips_rand_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,6 +8,7 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"sort"
"testing"
)
@@ -93,14 +94,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr4[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
- allowedIPs.Insert(addr6[:], cidr, peers[index])
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
index 2059a88..cde068e 100644
--- a/device/allowedips_test.go
+++ b/device/allowedips_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -8,6 +8,7 @@ package device
import (
"math/rand"
"net"
+ "net/netip"
"testing"
)
@@ -18,7 +19,6 @@ type testPairCommonBits struct {
}
func TestCommonBits(t *testing.T) {
-
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
@@ -39,7 +39,7 @@ func TestCommonBits(t *testing.T) {
}
}
-func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
root := parentIndirection{&trie, 2}
@@ -98,7 +98,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
- allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +208,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
- allowedIPs.Insert(addr, cidr, peer)
+ allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
diff --git a/device/bind_test.go b/device/bind_test.go
index 1c0d247..302a521 100644
--- a/device/bind_test.go
+++ b/device/bind_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -26,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil
}
-func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
-func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
+func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
- copy(buff, datagram.msg)
+ copy(buf, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
@@ -51,6 +51,6 @@ func (b *DummyBind) Close() error {
return nil
}
-func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
+func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
return nil
}
diff --git a/device/channels.go b/device/channels.go
index 6f4370a..e526f6b 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *QueueInboundElement
+ c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil value.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,10 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@@ -101,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -111,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -120,10 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elem := <-q.c:
- elem.Lock()
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
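The autodraining queues keep their shape here, only switching element type to the batched containers. The underlying pattern, sketched standalone below, is that the queue is a struct wrapping the channel so a finalizer can drain whatever is still buffered once the queue becomes unreachable (the Println stands in for returning elements to the device pools):

package main

import (
    "fmt"
    "runtime"
)

// autodrainingQueue wraps the channel in a struct so runtime.SetFinalizer
// can drain leftover elements when the queue is garbage-collected.
type autodrainingQueue struct {
    c chan int
}

func newAutodrainingQueue() *autodrainingQueue {
    q := &autodrainingQueue{c: make(chan int, 16)}
    runtime.SetFinalizer(q, func(q *autodrainingQueue) {
        for {
            select {
            case v := <-q.c:
                fmt.Println("drained", v)
            default:
                return
            }
        }
    })
    return q
}

func main() {
    q := newAutodrainingQueue()
    q.c <- 1
    q = nil
    runtime.GC() // the finalizer may run at this point and drain the element
}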
diff --git a/device/constants.go b/device/constants.go
index 53a2526..59854a1 100644
--- a/device/constants.go
+++ b/device/constants.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/cookie.go b/device/cookie.go
index b02b769..876f05d 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -83,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
-func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
+func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
st.RLock()
defer st.RUnlock()
@@ -119,7 +119,6 @@ func (st *CookieChecker) CreateReply(
recv uint32,
src []byte,
) (*MessageCookieReply, error) {
-
st.RLock()
// refresh cookie secret
@@ -204,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
-
if err != nil {
return false
}
@@ -215,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
}
func (st *CookieGenerator) AddMacs(msg []byte) {
-
size := len(msg)
smac2 := size - blake2s.Size128
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 02e01d1..4f1e50a 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
)
func TestCookieMAC1(t *testing.T) {
-
// setup generator / checker
var (
@@ -132,12 +131,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1 := []byte{192, 168, 13, 37, 40, 01}
+ srcBad1 := []byte{192, 168, 13, 37, 40, 1}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2 := []byte{192, 168, 13, 38, 40, 01}
+ srcBad2 := []byte{192, 168, 13, 38, 40, 1}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/device/device.go b/device/device.go
index 5644c8a..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -30,7 +30,7 @@ type Device struct {
// will become the actual state; Up can fail.
// The device can also change state multiple times between time of check and time of use.
// Unsynchronized uses of state must therefore be advisory/best-effort only.
- state uint32 // actually a deviceState, but typed uint32 for convenience
+ state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
// stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
// mu protects state changes.
@@ -44,6 +44,7 @@ type Device struct {
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
+ brokenRoaming bool
}
staticIdentity struct {
@@ -52,24 +53,26 @@ type Device struct {
publicKey NoisePublicKey
}
- rate struct {
- underLoadUntil int64
- limiter ratelimiter.Ratelimiter
- }
-
peers struct {
sync.RWMutex // protects keyMap
keyMap map[NoisePublicKey]*Peer
}
+ rate struct {
+ underLoadUntil atomic.Int64
+ limiter ratelimiter.Ratelimiter
+ }
+
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -80,7 +83,7 @@ type Device struct {
tun struct {
device tun.Device
- mtu int32
+ mtu atomic.Int32
}
ipcMutex sync.RWMutex
@@ -92,10 +95,9 @@ type Device struct {
// There are three states: down, up, closed.
// Transitions:
//
-// down -----+
-// ↑↓ ↓
-// up -> closed
-//
+// down -----+
+// ↑↓ ↓
+// up -> closed
type deviceState uint32
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
@@ -108,7 +110,7 @@ const (
// deviceState returns device.state.state as a deviceState
// See those docs for how to interpret this value.
func (device *Device) deviceState() deviceState {
- return deviceState(atomic.LoadUint32(&device.state.state))
+ return deviceState(device.state.state.Load())
}
// isClosed reports whether the device is closed (or is closing).
@@ -147,14 +149,14 @@ func (device *Device) changeState(want deviceState) (err error) {
case old:
return nil
case deviceStateUp:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+ device.state.state.Store(uint32(deviceStateUp))
err = device.upLocked()
if err == nil {
break
}
fallthrough // up failed; bring the device all the way back down
case deviceStateDown:
- atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+ device.state.state.Store(uint32(deviceStateDown))
errDown := device.downLocked()
if err == nil {
err = errDown
@@ -172,10 +174,15 @@ func (device *Device) upLocked() error {
return err
}
+ // The IPC set operation waits for peers to be created before calling Start() on them,
+ // so if there's a concurrent IPC set request happening, we should wait for it to complete.
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
+
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -212,11 +219,11 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
if underLoad {
- atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
return true
}
// check if recently under load
- return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
+ return device.rate.underLoadUntil.Load() > now.UnixNano()
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
@@ -260,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -276,7 +283,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
- device.state.state = uint32(deviceStateDown)
+ device.state.state.Store(uint32(deviceStateDown))
device.closed = make(chan struct{})
device.log = logger
device.net.bind = bind
@@ -286,10 +293,11 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
mtu = DefaultMTU
}
- device.tun.mtu = int32(mtu)
+ device.tun.mtu.Store(int32(mtu))
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -317,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the BatchSize for the device as a whole, which is the max of
+// the bind batch size and the tun batch size. The batch size reported by the device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -349,10 +370,12 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
- atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
+ device.state.state.Store(uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
device.tun.device.Close()
@@ -438,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -467,11 +486,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -490,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -502,8 +519,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")
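The device state is now an atomic.Uint32 holding a deviceState value, so deviceState() readers stay lock-free while transitions remain serialized under device.state.Lock. A standalone sketch of that pattern (type and field names abbreviated):

package main

import (
    "fmt"
    "sync/atomic"
)

// deviceState is a uint32 enum kept in an atomic.Uint32: readers never need
// the state mutex; writers still serialize transitions under a lock and
// publish the new state with Store.
type deviceState uint32

const (
    deviceStateDown deviceState = iota
    deviceStateUp
    deviceStateClosed
)

type dev struct {
    state atomic.Uint32
}

func (d *dev) deviceState() deviceState { return deviceState(d.state.Load()) }

func main() {
    var d dev
    d.state.Store(uint32(deviceStateUp))
    fmt.Println(d.deviceState() == deviceStateUp) // true
}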
diff --git a/device/device_test.go b/device/device_test.go
index 29daeb9..fff172b 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,7 +11,8 @@ import (
"fmt"
"io"
"math/rand"
- "net"
+ "net/netip"
+ "os"
"runtime"
"runtime/pprof"
"sync"
@@ -21,6 +22,7 @@ import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
+ "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -48,7 +50,7 @@ func uapiCfg(cfg ...string) string {
// genConfigs generates a pair of configs that connect to each other.
// The configs use distinct, probably-usable ports.
-func genConfigs(tb testing.TB) (cfgs [2]string, endpointCfgs [2]string) {
+func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
var key1, key2 NoisePrivateKey
_, err := rand.Read(key1[:])
if err != nil {
@@ -96,7 +98,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
- ip net.IP
+ ip netip.Addr
}
type SendDirection bool
@@ -159,7 +161,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
- p.ip = net.IPv4(1, 0, 0, byte(i+1))
+ p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
@@ -307,6 +309,17 @@ func TestConcurrencySafety(t *testing.T) {
}
})
+ // Perform bind updates and keepalive sends concurrently with tunnel use.
+ t.Run("bindUpdate and keepalive", func(t *testing.T) {
+ const iters = 10
+ for i := 0; i < iters; i++ {
+ for _, peer := range pair {
+ peer.dev.BindUpdate()
+ peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
+ }
+ }
+ })
+
close(done)
}
@@ -333,7 +346,7 @@ func BenchmarkThroughput(b *testing.B) {
// Measure how long it takes to receive b.N packets,
// starting when we receive the first packet.
- var recv uint64
+ var recv atomic.Uint64
var elapsed time.Duration
var wg sync.WaitGroup
wg.Add(1)
@@ -342,7 +355,7 @@ func BenchmarkThroughput(b *testing.B) {
var start time.Time
for {
<-pair[0].tun.Inbound
- new := atomic.AddUint64(&recv, 1)
+ new := recv.Add(1)
if new == 1 {
start = time.Now()
}
@@ -358,7 +371,7 @@ func BenchmarkThroughput(b *testing.B) {
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
pingc := pair[1].tun.Outbound
var sent uint64
- for atomic.LoadUint64(&recv) != uint64(b.N) {
+ for recv.Load() != uint64(b.N) {
sent++
pingc <- ping
}
@@ -405,3 +418,59 @@ func goroutineLeakCheck(t *testing.T) {
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
})
}
+
+type fakeBindSized struct {
+ size int
+}
+
+func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
+ return nil, 0, nil
+}
+func (b *fakeBindSized) Close() error { return nil }
+func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
+func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
+func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
+func (b *fakeBindSized) BatchSize() int { return b.size }
+
+type fakeTUNDeviceSized struct {
+ size int
+}
+
+func (t *fakeTUNDeviceSized) File() *os.File { return nil }
+func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+ return 0, nil
+}
+func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
+func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
+func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
+func (t *fakeTUNDeviceSized) Close() error { return nil }
+func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
+
+func TestBatchSize(t *testing.T) {
+ d := Device{}
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 1, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{1}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{1}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+
+ d.net.bind = &fakeBindSized{128}
+ d.tun.device = &fakeTUNDeviceSized{128}
+ if want, got := 128, d.BatchSize(); got != want {
+ t.Errorf("expected batch size %d, got %d", want, got)
+ }
+}
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
index 57c361c..93a4998 100644
--- a/device/endpoint_test.go
+++ b/device/endpoint_test.go
@@ -1,53 +1,49 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
- "net"
+ "net/netip"
)
type DummyEndpoint struct {
- src [16]byte
- dst [16]byte
+ src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
- var end DummyEndpoint
- if _, err := rand.Read(end.src[:]); err != nil {
+ var src, dst [16]byte
+ if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
- _, err := rand.Read(end.dst[:])
- return &end, err
+ _, err := rand.Read(dst[:])
+ return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
- var addr net.UDPAddr
- addr.IP = e.SrcIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
- var addr net.UDPAddr
- addr.IP = e.DstIP()
- addr.Port = 1000
- return addr.String()
+ return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
-func (e *DummyEndpoint) SrcToBytes() []byte {
- return e.src[:]
+func (e *DummyEndpoint) DstToBytes() []byte {
+ out := e.DstIP().AsSlice()
+ out = append(out, byte(1000&0xff))
+ out = append(out, byte((1000>>8)&0xff))
+ return out
}
-func (e *DummyEndpoint) DstIP() net.IP {
- return e.dst[:]
+func (e *DummyEndpoint) DstIP() netip.Addr {
+ return e.dst
}
-func (e *DummyEndpoint) SrcIP() net.IP {
- return e.src[:]
+func (e *DummyEndpoint) SrcIP() netip.Addr {
+ return e.src
}
diff --git a/device/indextable.go b/device/indextable.go
index 15b8b0d..00ade7d 100644
--- a/device/indextable.go
+++ b/device/indextable.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/ip.go b/device/ip.go
index 50c5040..eaf2363 100644
--- a/device/ip.go
+++ b/device/ip.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/kdf_test.go b/device/kdf_test.go
index b14aa6d..f9c76d6 100644
--- a/device/kdf_test.go
+++ b/device/kdf_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -20,7 +20,7 @@ type KDFTest struct {
t2 string
}
-func assertEquals(t *testing.T, a string, b string) {
+func assertEquals(t *testing.T, a, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
diff --git a/device/keypair.go b/device/keypair.go
index 788c947..e3540d7 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -10,7 +10,6 @@ import (
"sync"
"sync/atomic"
"time"
- "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -23,7 +22,7 @@ import (
*/
type Keypair struct {
- sendNonce uint64 // accessed atomically
+ sendNonce atomic.Uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.Filter
@@ -37,15 +36,7 @@ type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
- next *Keypair
-}
-
-func (kp *Keypairs) storeNext(next *Keypair) {
- atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
-}
-
-func (kp *Keypairs) loadNext() *Keypair {
- return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+ next atomic.Pointer[Keypair]
}
func (kp *Keypairs) Current() *Keypair {
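Keypairs.next becomes an atomic.Pointer[Keypair], which makes the deleted storeNext/loadNext unsafe.Pointer shims unnecessary. A minimal sketch of the replacement primitive (Go 1.19+; the keypair struct here is a stand-in):

package main

import (
    "fmt"
    "sync/atomic"
)

type keypair struct{ id int }

func main() {
    // atomic.Pointer[T] gives type-safe, lock-free load/store of *keypair.
    var next atomic.Pointer[keypair]
    next.Store(&keypair{id: 1})
    if kp := next.Load(); kp != nil {
        fmt.Println("next keypair:", kp.id)
    }
    next.Store(nil) // rotate: next has been consumed
}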
diff --git a/device/logger.go b/device/logger.go
index c79c971..22b0df0 100644
--- a/device/logger.go
+++ b/device/logger.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -16,8 +16,8 @@ import (
// They do not require a trailing newline in the format.
// If nil, that level of logging will be silent.
type Logger struct {
- Verbosef func(format string, args ...interface{})
- Errorf func(format string, args ...interface{})
+ Verbosef func(format string, args ...any)
+ Errorf func(format string, args ...any)
}
// Log levels for use with NewLogger.
@@ -28,14 +28,14 @@ const (
)
// Function for use in Logger for discarding logged lines.
-func DiscardLogf(format string, args ...interface{}) {}
+func DiscardLogf(format string, args ...any) {}
// NewLogger constructs a Logger that writes to stdout.
// It logs at the specified log level and above.
// It decorates log lines with the log level, date, time, and prepend.
func NewLogger(level int, prepend string) *Logger {
logger := &Logger{DiscardLogf, DiscardLogf}
- logf := func(prefix string) func(string, ...interface{}) {
+ logf := func(prefix string) func(string, ...any) {
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
}
if level >= LogLevelVerbose {
diff --git a/device/misc.go b/device/misc.go
deleted file mode 100644
index 4126704..0000000
--- a/device/misc.go
+++ /dev/null
@@ -1,41 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "sync/atomic"
-)
-
-/* Atomic Boolean */
-
-const (
- AtomicFalse = int32(iota)
- AtomicTrue
-)
-
-type AtomicBool struct {
- int32
-}
-
-func (a *AtomicBool) Get() bool {
- return atomic.LoadInt32(&a.int32) == AtomicTrue
-}
-
-func (a *AtomicBool) Swap(val bool) bool {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
-}
-
-func (a *AtomicBool) Set(val bool) {
- flag := AtomicFalse
- if val {
- flag = AtomicTrue
- }
- atomic.StoreInt32(&a.int32, flag)
-}
diff --git a/device/mobilequirks.go b/device/mobilequirks.go
index f27d9d7..0a0080e 100644
--- a/device/mobilequirks.go
+++ b/device/mobilequirks.go
@@ -1,16 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
+// though it will try to cope, possibly racily, if called after.
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
+ device.net.brokenRoaming = true
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- peer.disableRoaming = peer.endpoint != nil
- peer.Unlock()
+ peer.endpoint.Lock()
+ peer.endpoint.disableRoaming = peer.endpoint.val != nil
+ peer.endpoint.Unlock()
}
device.peers.RUnlock()
}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index b6b51fd..c2f356b 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -9,6 +9,7 @@ import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
+ "errors"
"hash"
"golang.org/x/crypto/blake2s"
@@ -94,9 +95,14 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
return
}
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+var errInvalidPublicKey = errors.New("invalid public key")
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
- return ss
+ if isZero(ss[:]) {
+ return ss, errInvalidPublicKey
+ }
+ return ss, nil
}
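sharedSecret now fails loudly: an all-zero X25519 result, produced by a low-order or otherwise invalid public key, comes back as errInvalidPublicKey instead of relying on every caller to isZero-check. A standalone sketch of the same check (sharedSecret below is a local stand-in, not the package function):

package main

import (
    "errors"
    "fmt"

    "golang.org/x/crypto/curve25519"
)

// sharedSecret rejects an all-zero DH output once, centrally, rather than
// at each call site.
func sharedSecret(sk, pk *[32]byte) ([32]byte, error) {
    var ss [32]byte
    curve25519.ScalarMult(&ss, sk, pk)
    for _, b := range ss {
        if b != 0 {
            return ss, nil
        }
    }
    return ss, errors.New("invalid public key")
}

func main() {
    var sk, pk [32]byte // the zero point is one such invalid public key
    if _, err := sharedSecret(&sk, &pk); err != nil {
        fmt.Println(err) // invalid public key
    }
}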
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index 0212b7d..e8f6145 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -138,11 +138,11 @@ var (
ZeroNonce [chacha20poly1305.NonceSize]byte
)
-func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
-func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
@@ -175,8 +175,6 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- var errZeroECDHResult = errors.New("ECDH returned all zeros")
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -204,9 +202,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- if isZero(ss[:]) {
- return nil, errZeroECDHResult
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
}
var key [chacha20poly1305.KeySize]byte
KDF2(
@@ -221,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
// encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
- return nil, errZeroECDHResult
+ return nil, errInvalidPublicKey
}
KDF2(
&handshake.chainKey,
@@ -264,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
- var err error
var peerPK NoisePublicKey
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- if isZero(ss[:]) {
+ ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
@@ -282,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// lookup peer
peer := device.LookupPeer(peerPK)
- if peer == nil {
+ if peer == nil || !peer.isRunning.Load() {
return nil
}
@@ -384,12 +381,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.mixKey(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
+ ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ if err != nil {
+ return nil, err
+ }
+ handshake.mixKey(ss[:])
// add preshared key
@@ -406,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mixHash(tau[:])
- func() {
- aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.mixHash(msg.Empty[:])
- }()
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
handshake.state = handshakeResponseCreated
@@ -436,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
ok := func() bool {
-
// lock handshake state
handshake.mutex.RLock()
@@ -456,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
- func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
- mixKey(&chainKey, &chainKey, ss[:])
- setZero(ss[:])
- }()
+ ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ if err != nil {
+ return false
+ }
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
// add preshared key (psk)
@@ -484,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
@@ -582,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.loadNext()
+ next := keypairs.next.Load()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -596,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.storeNext(keypair)
+ keypairs.next.Store(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -608,18 +608,18 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.loadNext() != receivedKeypair {
+ if keypairs.next.Load() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.loadNext()
- keypairs.storeNext(nil)
+ keypairs.current = keypairs.next.Load()
+ keypairs.next.Store(nil)
return true
}
diff --git a/device/noise-types.go b/device/noise-types.go
index 6e850e7..e850359 100644
--- a/device/noise-types.go
+++ b/device/noise-types.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/noise_test.go b/device/noise_test.go
index 807ca2d..2dd5324 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -24,10 +24,10 @@ func TestCurveWrappers(t *testing.T) {
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
+ ss1, err1 := sk1.sharedSecret(pk2)
+ ss2, err2 := sk2.sharedSecret(pk1)
- if ss1 != ss2 {
+ if ss1 != ss2 || err1 != nil || err2 != nil {
t.Fatal("Failed to compute shared secet")
}
}
@@ -71,6 +71,8 @@ func TestNoiseHandshake(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ peer1.Start()
+ peer2.Start()
assertEqual(
t,
@@ -146,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.loadNext()
+ key1 := peer1.keypairs.next.Load()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index c8b825d..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -16,36 +16,31 @@ import (
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint conn.Endpoint
- stopping sync.WaitGroup // routines pending stop
-
- // These fields are accessed with atomic operations, which must be
- // 64-bit aligned even on 32-bit platforms. Go guarantees that an
- // allocated struct will be 64-bit aligned. So we place
- // atomically-accessed fields up front, so that they can share in
- // this alignment before smaller fields throw it off.
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes sent to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nanoseconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
- disableRoaming bool
-
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
state struct {
@@ -53,14 +48,14 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
trieEntries list.List
- persistentKeepaliveInterval uint32 // accessed atomically
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -82,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -100,26 +93,27 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
- peer.endpoint = nil
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
+
+ // init timers
+ peer.timersInit()
// add
device.peers.keyMap[pk] = peer
- // start peer
- peer.timersInit()
- if peer.device.isUp() {
- peer.Start()
- }
-
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -127,16 +121,25 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -177,7 +180,7 @@ func (peer *Peer) Start() {
peer.state.Lock()
defer peer.state.Unlock()
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
@@ -198,10 +201,14 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
- peer.isRunning.Set(true)
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
+
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@@ -213,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.loadNext())
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -241,11 +248,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- next := keypairs.loadNext()
- atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
@@ -271,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if peer.disableRoaming {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}
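The peer-wide RWMutex is gone; the endpoint and its flags live behind their own small mutex, and BindUpdate no longer calls ClearSrc inline but sets clearSrcOnTx so the clearing happens on the next send. A standalone sketch of that deferral (a string stands in for conn.Endpoint):

package main

import (
    "fmt"
    "sync"
)

// endpoint mirrors the new embedded struct: a flag marks the cached source
// for clearing, and the actual ClearSrc happens on the next transmit,
// under one short-lived lock.
type endpoint struct {
    sync.Mutex
    val          string // stand-in for conn.Endpoint
    clearSrcOnTx bool
}

func (e *endpoint) markForClearing() {
    e.Lock()
    defer e.Unlock()
    if e.val != "" {
        e.clearSrcOnTx = true
    }
}

func (e *endpoint) send() {
    e.Lock()
    if e.clearSrcOnTx {
        fmt.Println("ClearSrc() before tx") // stand-in for val.ClearSrc()
        e.clearSrcOnTx = false
    }
    e.Unlock()
    // ... transmit ...
}

func main() {
    e := &endpoint{val: "203.0.113.7:51820"}
    e.markForClearing()
    e.send()
}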
diff --git a/device/pools.go b/device/pools.go
index f1d1fa0..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -14,49 +14,85 @@ type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
- count uint32
+ count atomic.Uint32
max uint32
}
-func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
+func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}
-func (p *WaitPool) Get() interface{} {
+func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
- for atomic.LoadUint32(&p.count) >= p.max {
+ for p.count.Load() >= p.max {
p.cond.Wait()
}
- atomic.AddUint32(&p.count, 1)
+ p.count.Add(1)
p.lock.Unlock()
}
return p.pool.Get()
}
-func (p *WaitPool) Put(x interface{}) {
+func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
- atomic.AddUint32(&p.count, ^uint32(0))
+ p.count.Add(^uint32(0))
p.cond.Signal()
}
func (device *Device) PopulatePools() {
- device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
- device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement)
})
- device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement)
})
}
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
+ }
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
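
PopulatePools above fixes the WaitPool contract: each pool is seeded with a constructor and capped at PreallocatedBuffersPerPool outstanding items, Get blocks once the cap is reached, and Put signals exactly one waiter. A minimal usage sketch against the definitions above:

    // Get blocks while PreallocatedBuffersPerPool items are checked out.
    pool := NewWaitPool(PreallocatedBuffersPerPool, func() any {
        return new([MaxMessageSize]byte)
    })
    buf := pool.Get().(*[MaxMessageSize]byte)
    // ... fill and consume buf ...
    pool.Put(buf) // decrements the count and wakes one waiting Get
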
diff --git a/device/pools_test.go b/device/pools_test.go
index b0840e1..82d7493 100644
--- a/device/pools_test.go
+++ b/device/pools_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -17,29 +17,31 @@ import (
func TestWaitPool(t *testing.T) {
t.Skip("Currently disabled")
var wg sync.WaitGroup
- trials := int32(100000)
+ var trials atomic.Int32
+ startTrials := int32(100000)
if raceEnabled {
// This test can be very slow with -race.
- trials /= 10
+ startTrials /= 10
}
+ trials.Store(startTrials)
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
- p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
- max := uint32(0)
+ var max atomic.Uint32
updateMax := func() {
- count := atomic.LoadUint32(&p.count)
+ count := p.count.Load()
if count > p.max {
t.Errorf("count (%d) > max (%d)", count, p.max)
}
for {
- old := atomic.LoadUint32(&max)
+ old := max.Load()
if count <= old {
break
}
- if atomic.CompareAndSwapUint32(&max, old, count) {
+ if max.CompareAndSwap(old, count) {
break
}
}
@@ -47,7 +49,7 @@ func TestWaitPool(t *testing.T) {
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
updateMax()
x := p.Get()
updateMax()
@@ -59,25 +61,74 @@ func TestWaitPool(t *testing.T) {
}()
}
wg.Wait()
- if max != p.max {
+ if max.Load() != p.max {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
}
}
func BenchmarkWaitPool(b *testing.B) {
var wg sync.WaitGroup
- trials := int32(b.N)
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
workers := runtime.NumCPU() + 2
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
- p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+ p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
- for atomic.AddInt32(&trials, -1) > 0 {
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkWaitPoolEmpty(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := NewWaitPool(0, func() any { return make([]byte, 16) })
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
+ x := p.Get()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ p.Put(x)
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+func BenchmarkSyncPool(b *testing.B) {
+ var wg sync.WaitGroup
+ var trials atomic.Int32
+ trials.Store(int32(b.N))
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ b.Skip("Not enough cores")
+ }
+ p := sync.Pool{New: func() any { return make([]byte, 16) }}
+ wg.Add(workers)
+ b.ResetTimer()
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for trials.Add(-1) > 0 {
x := p.Get()
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
p.Put(x)
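
The three benchmarks are deliberately parallel: BenchmarkWaitPool exercises the capped pool under contention, BenchmarkWaitPoolEmpty passes a max of 0 (which, per Get and Put above, skips the lock-and-count path entirely), and BenchmarkSyncPool is the bare sync.Pool baseline, so comparing the three isolates the cost of the cap machinery. They can be run with go test -run '^$' -bench Pool ./device/.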
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 7abc33a..25f700a 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -1,17 +1,19 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
/* Reduce memory consumption for Android */
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
- MaxSegmentSize = 2200
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 4096
)
diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go
index d5c6927..ea763d0 100644
--- a/device/queueconstants_default.go
+++ b/device/queueconstants_default.go
@@ -1,14 +1,16 @@
-// +build !android,!ios,!windows
+//go:build !android && !ios && !windows
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
+import "golang.zx2c4.com/wireguard/conn"
+
const (
- QueueStagedSize = 128
+ QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go
index 36c8704..acd3cec 100644
--- a/device/queueconstants_ios.go
+++ b/device/queueconstants_ios.go
@@ -1,8 +1,8 @@
-// +build ios
+//go:build ios
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go
index e330a6b..1eee32b 100644
--- a/device/queueconstants_windows.go
+++ b/device/queueconstants_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go
index 65fac7e..bb5c450 100644
--- a/device/race_disabled_test.go
+++ b/device/race_disabled_test.go
@@ -1,8 +1,8 @@
-//+build !race
+//go:build !race
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go
index f8ccac3..4e9daea 100644
--- a/device/race_enabled_test.go
+++ b/device/race_enabled_test.go
@@ -1,8 +1,8 @@
-//+build race
+//go:build race
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
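
The constraint rewrites in this and the preceding files are the mechanical migration to Go 1.17's //go:build syntax; both forms express the same condition:

    // Old: comma means AND within a line.
    // +build !android,!ios,!windows
    // New: explicit boolean expression, kept in sync by gofmt.
    //go:build !android && !ios && !windows
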
diff --git a/device/receive.go b/device/receive.go
index 5857481..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -11,13 +11,11 @@ import (
"errors"
"net"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
)
@@ -29,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -37,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
+}
+
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@@ -53,12 +55,12 @@ func (elem *QueueInboundElement) clearPointers() {
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
- if peer.timers.sentLastMinuteHandshake.Get() {
+ if peer.timers.sentLastMinuteHandshake.Load() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
- peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.timers.sentLastMinuteHandshake.Store(true)
peer.SendHandshakeInitiation(false)
}
}
@@ -68,7 +70,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
+func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -81,20 +83,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
- buffer := device.GetMessageBuffer()
-
var (
+ bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
+ bufs = make([][]byte, maxBatchSize)
err error
- size int
- endpoint conn.Endpoint
+ sizes = make([]int, maxBatchSize)
+ count int
+ endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
- for {
- size, endpoint, err = recv(buffer[:])
+ for i := range bufsArrs {
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
+ }
+
+ defer func() {
+ for i := 0; i < maxBatchSize; i++ {
+ if bufsArrs[i] != nil {
+ device.PutMessageBuffer(bufsArrs[i])
+ }
+ }
+ }()
+ for {
+ count, err = recv(bufs, sizes, endpoints)
if err != nil {
- device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -105,101 +120,119 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
- buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
- if size < MinMessageSize {
- continue
- }
+ // handle each packet in the batch
+ for i, size := range sizes[:count] {
+ if size < MinMessageSize {
+ continue
+ }
- // check size of packet
+ // check size of packet
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ packet := bufsArrs[i][:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- var okay bool
+ switch msgType {
- switch msgType {
+ // check if transport
- // check if transport
+ case MessageTransportType:
- case MessageTransportType:
+ // check size
- // check size
+ if len(packet) < MessageTransportSize {
+ continue
+ }
- if len(packet) < MessageTransportSize {
- continue
- }
+ // lookup key pair
- // lookup key pair
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
- receiver := binary.LittleEndian.Uint32(
- packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
- )
- value := device.indexTable.Lookup(receiver)
- keypair := value.keypair
- if keypair == nil {
- continue
- }
+ // check keypair expiry
- // check keypair expiry
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
- if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = bufsArrs[i]
+ elem.keypair = keypair
+ elem.endpoint = endpoints[i]
+ elem.counter = 0
+
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
continue
- }
-
- // create work element
- peer := value.peer
- elem := device.GetInboundElement()
- elem.packet = packet
- elem.buffer = buffer
- elem.keypair = keypair
- elem.endpoint = endpoint
- elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
-
- // add to decryption queues
- if peer.isRunning.Get() {
- peer.queue.inbound.c <- elem
- device.queue.decryption.c <- elem
- buffer = device.GetMessageBuffer()
- } else {
- device.PutInboundElement(elem)
- }
- continue
- // otherwise it is a fixed size & handshake related packet
+ // otherwise it is a fixed size & handshake related packet
- case MessageInitiationType:
- okay = len(packet) == MessageInitiationSize
+ case MessageInitiationType:
+ if len(packet) != MessageInitiationSize {
+ continue
+ }
- case MessageResponseType:
- okay = len(packet) == MessageResponseSize
+ case MessageResponseType:
+ if len(packet) != MessageResponseSize {
+ continue
+ }
- case MessageCookieReplyType:
- okay = len(packet) == MessageCookieReplySize
+ case MessageCookieReplyType:
+ if len(packet) != MessageCookieReplySize {
+ continue
+ }
- default:
- device.log.Verbosef("Received message with unknown type")
- }
+ default:
+ device.log.Verbosef("Received message with unknown type")
+ continue
+ }
- if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
- buffer: buffer,
+ buffer: bufsArrs[i],
packet: packet,
- endpoint: endpoint,
+ endpoint: endpoints[i],
}:
- buffer = device.GetMessageBuffer()
+ bufsArrs[i] = device.GetMessageBuffer()
+ bufs[i] = bufsArrs[i][:]
default:
}
}
+ for peer, elemsContainer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
+ } else {
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ device.PutInboundElementsContainer(elemsContainer)
+ }
+ delete(elemsByPeer, peer)
+ }
}
}
@@ -209,26 +242,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
- for elem := range device.queue.decryption.c {
- // split message into fields
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
-
- // decrypt and release to consumer
- var err error
- elem.counter = binary.LittleEndian.Uint64(counter)
- // copy counter to nonce
- binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- elem.packet, err = elem.keypair.receive.Open(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
- elem.packet = nil
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
+ // split message into fields
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // decrypt and release to consumer
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.packet = nil
+ }
}
- elem.Unlock()
+ elemsContainer.Unlock()
}
}
@@ -269,7 +304,7 @@ func (device *Device) RoutineHandshake(id int) {
// consume reply
- if peer := entry.peer; peer.isRunning.Get() {
+ if peer := entry.peer; peer.isRunning.Load() {
device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
device.log.Verbosef("Could not decrypt invalid cookie response")
@@ -342,7 +377,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake initiation", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
peer.SendHandshakeResponse()
@@ -370,7 +405,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SetEndpointFromPacket(elem.endpoint)
device.log.Verbosef("%v - Received handshake response", peer)
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ peer.rxBytes.Add(uint64(len(elem.packet)))
// update timers
@@ -395,7 +430,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
-func (peer *Peer) RoutineSequentialReceiver() {
+func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -403,89 +438,103 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
- for elem := range peer.queue.inbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
}
- var err error
- elem.Lock()
- if elem.packet == nil {
- // decryption failed
- goto skip
- }
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
+ if elem.packet == nil {
+ // decryption failed
+ continue
+ }
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- goto skip
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- peer.SetEndpointFromPacket(elem.endpoint)
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- peer.SendStagedPackets()
- }
+ validTailPacket = i
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
+ peer.timersHandshakeComplete()
+ peer.SendStagedPackets()
+ }
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
+ if len(elem.packet) == 0 {
+ device.log.Verbosef("%v - Receiving keepalive packet", peer)
+ continue
+ }
+ dataPacketReceived = true
- if len(elem.packet) == 0 {
- device.log.Verbosef("%v - Receiving keepalive packet", peer)
- goto skip
- }
- peer.timersDataReceived()
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
+ continue
+ }
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
- goto skip
- }
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+ elem.packet = elem.packet[:length]
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.Lookup(src) != peer {
+ device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
+ continue
+ }
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- goto skip
- }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- goto skip
- }
- elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.Lookup(src) != peer {
- device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
- goto skip
+ default:
+ device.log.Verbosef("Packet with invalid IP version from %v", peer)
+ continue
}
- default:
- device.log.Verbosef("Packet with invalid IP version from %v", peer)
- goto skip
+ bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
- _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed() {
- device.log.Errorf("Failed to write packet to TUN device: %v", err)
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
}
- if len(peer.queue.inbound.c) == 0 {
- err = device.tun.device.Flush()
- if err != nil {
- peer.device.log.Errorf("Unable to flush packets: %v", err)
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
+ if len(bufs) > 0 {
+ _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
+ if err != nil && !device.isClosed() {
+ device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- skip:
- device.PutMessageBuffer(elem.buffer)
- device.PutInboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutInboundElement(elem)
+ }
+ bufs = bufs[:0]
+ device.PutInboundElementsContainer(elemsContainer)
}
}
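
receive.go now drives a vectored read loop: one recv call fills up to maxBatchSize buffers, packets are grouped into per-peer locked containers, and decryption workers unlock each container only after processing every element in it, preserving per-peer ordering without a per-packet mutex. A sketch of the batched callback shape this assumes (the authoritative ReceiveFunc definition lives in golang.zx2c4.com/wireguard/conn):

    // Sketch: recv fills bufs[i], records lengths in sizes[i] and packet
    // sources in eps[i], and returns how many entries are valid.
    type ReceiveFunc func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error)
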
diff --git a/device/send.go b/device/send.go
index b05c69e..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,12 +12,13 @@ import (
"net"
"os"
"sync"
- "sync/atomic"
"time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/tun"
)
/* Outbound flow
@@ -45,7 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -53,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@@ -76,14 +80,17 @@ func (elem *QueueOutboundElement) clearPointers() {
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() {
- if len(peer.queue.staged) == 0 && peer.isRunning.Get() {
+ if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@@ -91,7 +98,7 @@ func (peer *Peer) SendKeepalive() {
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.handshakeAttempts.Store(0)
}
peer.handshake.mutex.RLock()
@@ -117,8 +124,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
return err
}
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -126,7 +133,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
}
@@ -148,8 +155,8 @@ func (peer *Peer) SendHandshakeResponse() error {
return err
}
- var buff [MessageResponseSize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
@@ -164,7 +171,8 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- err = peer.SendBuffer(packet)
+ // TODO: allocation could be avoided
+ err = peer.SendBuffers([][]byte{packet})
if err != nil {
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
}
@@ -181,10 +189,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
return err
}
- var buff [MessageCookieReplySize]byte
- writer := bytes.NewBuffer(buff[:0])
+ var buf [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buf[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ // TODO: allocation could be avoided
+ device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
return nil
}
@@ -193,17 +202,12 @@ func (peer *Peer) keepKeyFreshSending() {
if keypair == nil {
return
}
- nonce := atomic.LoadUint64(&keypair.sendNonce)
+ nonce := keypair.sendNonce.Load()
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
-/* Reads packets from the TUN and inserts
- * into staged queue for peer
- *
- * Obs. Single instance per TUN device
- */
func (device *Device) RoutineReadFromTUN() {
defer func() {
device.log.Verbosef("Routine: TUN reader - stopped")
@@ -213,82 +217,123 @@ func (device *Device) RoutineReadFromTUN() {
device.log.Verbosef("Routine: TUN reader - started")
- var elem *QueueOutboundElement
+ var (
+ batchSize = device.BatchSize()
+ readErr error
+ elems = make([]*QueueOutboundElement, batchSize)
+ bufs = make([][]byte, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
+ count = 0
+ sizes = make([]int, batchSize)
+ offset = MessageTransportHeaderSize
+ )
+
+ for i := range elems {
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
+ }
- for {
- if elem != nil {
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ defer func() {
+ for _, elem := range elems {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
}
- elem = device.NewOutboundElement()
-
- // read packet
-
- offset := MessageTransportHeaderSize
- size, err := device.tun.device.Read(elem.buffer[:], offset)
+ }()
- if err != nil {
- if !device.isClosed() {
- if !errors.Is(err, os.ErrClosed) {
- device.log.Errorf("Failed to read packet from TUN device: %v", err)
- }
- go device.Close()
+ for {
+ // read packets
+ count, readErr = device.tun.device.Read(bufs, sizes, offset)
+ for i := 0; i < count; i++ {
+ if sizes[i] < 1 {
+ continue
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
- return
- }
- if size == 0 || size > MaxContentSize {
- continue
- }
+ elem := elems[i]
+ elem.packet = bufs[i][offset : offset+sizes[i]]
- elem.packet = elem.buffer[offset : offset+size]
+ // lookup peer
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case 4:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.Lookup(dst)
- // lookup peer
+ case 6:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.Lookup(dst)
- var peer *Peer
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
+ default:
+ device.log.Verbosef("Received packet with unknown IP version")
}
- dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.allowedips.Lookup(dst)
- case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
+ if peer == nil {
continue
}
- dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.allowedips.Lookup(dst)
-
- default:
- device.log.Verbosef("Received packet with unknown IP version")
+ elemsForPeer, ok := elemsByPeer[peer]
+ if !ok {
+ elemsForPeer = device.GetOutboundElementsContainer()
+ elemsByPeer[peer] = elemsForPeer
+ }
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
+ elems[i] = device.NewOutboundElement()
+ bufs[i] = elems[i].buffer[:]
}
- if peer == nil {
- continue
+ for peer, elemsForPeer := range elemsByPeer {
+ if peer.isRunning.Load() {
+ peer.StagePackets(elemsForPeer)
+ peer.SendStagedPackets()
+ } else {
+ for _, elem := range elemsForPeer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsForPeer)
+ }
+ delete(elemsByPeer, peer)
}
- if peer.isRunning.Get() {
- peer.StagePacket(elem)
- elem = nil
- peer.SendStagedPackets()
+
+ if readErr != nil {
+ if errors.Is(readErr, tun.ErrTooManySegments) {
+ // TODO: record stat for this
+ // This will happen if MSS is surprisingly small (< 576)
+ // coincident with reasonably high throughput.
+ device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
+ continue
+ }
+ if !device.isClosed() {
+ if !errors.Is(readErr, os.ErrClosed) {
+ device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
+ }
+ go device.Close()
+ }
+ return
}
}
}
-func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
- case peer.queue.staged <- elem:
+ case peer.queue.staged <- elems:
return
default:
}
select {
case tooOld := <-peer.queue.staged:
- peer.device.PutMessageBuffer(tooOld.buffer)
- peer.device.PutOutboundElement(tooOld)
+ for _, elem := range tooOld.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@@ -301,32 +346,59 @@ top:
}
keypair := peer.keypairs.Current()
- if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
+ if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
peer.SendHandshakeInitiation(false)
return
}
for {
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case elem := <-peer.queue.staged:
- elem.peer = peer
- elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
- if elem.nonce >= RejectAfterMessages {
- atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
- peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans
- goto top
+ case elemsContainer := <-peer.queue.staged:
+ i := 0
+ for _, elem := range elemsContainer.elems {
+ elem.peer = peer
+ elem.nonce = keypair.sendNonce.Add(1) - 1
+ if elem.nonce >= RejectAfterMessages {
+ keypair.sendNonce.Store(RejectAfterMessages)
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
+ }
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
+ continue
+ } else {
+ elemsContainer.elems[i] = elem
+ i++
+ }
+
+ elem.keypair = keypair
}
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- elem.keypair = keypair
- elem.Lock()
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
+ }
+
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ goto top
+ }
// add to parallel and sequential queue
- if peer.isRunning.Get() {
- peer.queue.outbound.c <- elem
- peer.device.queue.encryption.c <- elem
+ if peer.isRunning.Load() {
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
} else {
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
+ }
+
+ if elemsContainerOOO != nil {
+ goto top
}
default:
return
@@ -337,9 +409,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elem := <-peer.queue.staged:
- peer.device.PutMessageBuffer(elem.buffer)
- peer.device.PutOutboundElement(elem)
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ }
+ peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@@ -373,41 +448,38 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
- for elem := range device.queue.encryption.c {
- // populate header fields
- header := elem.buffer[:MessageTransportHeaderSize]
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
+ // populate header fields
+ header := elem.buffer[:MessageTransportHeaderSize]
- fieldType := header[0:4]
- fieldReceiver := header[4:8]
- fieldNonce := header[8:16]
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
- // pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
- elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
+ // pad content to multiple of 16
+ paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
+ elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
- // encrypt content and release to consumer
+ // encrypt content and release to consumer
- binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keypair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
- elem.Unlock()
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keypair.send.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ }
+ elemsContainer.Unlock()
}
}
-/* Sequentially reads packets from queue and sends to endpoint
- *
- * Obs. Single instance per peer.
- * The routine terminates then the outbound queue is closed.
- */
-func (peer *Peer) RoutineSequentialSender() {
+func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
device := peer.device
defer func() {
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
@@ -415,36 +487,57 @@ func (peer *Peer) RoutineSequentialSender() {
}()
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
- for elem := range peer.queue.outbound.c {
- if elem == nil {
+ bufs := make([][]byte, 0, maxBatchSize)
+
+ for elemsContainer := range peer.queue.outbound.c {
+ bufs = bufs[:0]
+ if elemsContainer == nil {
return
}
- elem.Lock()
- if !peer.isRunning.Get() {
+ if !peer.isRunning.Load() {
// peer has been stopped; return re-usable elems to the shared pool.
// This is an optimization only. It is possible for the peer to be stopped
// immediately after this check, in which case, elem will get processed.
- // The timers and SendBuffer code are resilient to a few stragglers.
+ // The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
continue
}
+ dataSent := false
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
+ if len(elem.packet) != MessageKeepaliveSize {
+ dataSent = true
+ }
+ bufs = append(bufs, elem.packet)
+ }
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
- // send message and return buffer to pool
-
- err := peer.SendBuffer(elem.packet)
- if len(elem.packet) != MessageKeepaliveSize {
+ err := peer.SendBuffers(bufs)
+ if dataSent {
peer.timersDataSent()
}
- device.PutMessageBuffer(elem.buffer)
- device.PutOutboundElement(elem)
+ for _, elem := range elemsContainer.elems {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ device.PutOutboundElementsContainer(elemsContainer)
+ if err != nil {
+ var errGSO conn.ErrUDPGSODisabled
+ if errors.As(err, &errGSO) {
+ device.log.Verbosef(err.Error())
+ err = errGSO.RetryErr
+ }
+ }
if err != nil {
- device.log.Errorf("%v - Failed to send data packet: %v", peer, err)
+ device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
continue
}
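
The GSO branch at the tail of the sender is worth spelling out. Assuming conn.ErrUDPGSODisabled wraps a send that the bind has already retried with UDP GSO turned off, and that RetryErr carries that retry's outcome, the unwrap below means a successful retry logs only a Verbosef note while a genuine second failure still reaches Errorf:

    var errGSO conn.ErrUDPGSODisabled
    if errors.As(err, &errGSO) {
        device.log.Verbosef(err.Error()) // note the GSO downgrade, keep going
        err = errGSO.RetryErr            // nil if the retry succeeded
    }
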
diff --git a/device/sticky_default.go b/device/sticky_default.go
index 1cc52f6..1038256 100644
--- a/device/sticky_default.go
+++ b/device/sticky_default.go
@@ -1,4 +1,4 @@
-// +build !linux
+//go:build !linux
package device
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index 6193ea3..6057ff1 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
@@ -25,7 +25,10 @@ import (
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
- if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ if !conn.StdNetSupportsStickySockets {
+ return nil, nil
+ }
+ if _, ok := bind.(*conn.StdNetBind); !ok {
return nil, nil
}
@@ -107,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
if !ok {
break
}
- pePtr.peer.Lock()
- if &pePtr.peer.endpoint != pePtr.endpoint {
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.Lock()
+ if &pePtr.peer.endpoint.val != pePtr.endpoint {
+ pePtr.peer.endpoint.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
- pePtr.peer.Unlock()
+ if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
+ pePtr.peer.endpoint.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
- pePtr.peer.Unlock()
+ pePtr.peer.endpoint.clearSrcOnTx = true
+ pePtr.peer.endpoint.Unlock()
}
attr = attr[attrhdr.Len:]
}
@@ -131,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
- peer.RLock()
- if peer.endpoint == nil {
- peer.RUnlock()
+ peer.endpoint.Lock()
+ if peer.endpoint.val == nil {
+ peer.endpoint.Unlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
+ nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
if nativeEP == nil {
- peer.RUnlock()
+ peer.endpoint.Unlock()
continue
}
- if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
- peer.RUnlock()
+ if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
+ peer.endpoint.Unlock()
break
}
nlmsg := struct {
@@ -169,12 +172,12 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_DST,
},
- nativeEP.Dst4().Addr,
+ nativeEP.DstIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
- nativeEP.Src4().Src,
+ nativeEP.SrcIP().As4(),
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
@@ -185,10 +188,10 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
- endpoint: &peer.endpoint,
+ endpoint: &peer.endpoint.val,
}
reqPeerLock.Unlock()
- peer.RUnlock()
+ peer.endpoint.Unlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
@@ -204,7 +207,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
}
func createNetlinkRouteSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
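
The one-line socket change is deliberate: OR-ing unix.SOCK_CLOEXEC into the socket type sets close-on-exec atomically at creation, so the netlink fd cannot leak into a child process forked between socket() and any later fcntl(FD_CLOEXEC).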
diff --git a/device/timers.go b/device/timers.go
index ee191e5..d4a4ed4 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
@@ -8,12 +8,14 @@
package device
import (
- "math/rand"
"sync"
- "sync/atomic"
"time"
+ _ "unsafe"
)
+//go:linkname fastrandn runtime.fastrandn
+func fastrandn(n uint32) uint32
+
// A Timer manages time-based aspects of the WireGuard protocol.
// Timer roughly copies the interface of the Linux kernel's struct timer_list.
type Timer struct {
@@ -71,11 +73,11 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp()
+ return peer.isRunning.Load() && peer.device != nil && peer.device.isUp()
}
func expiredRetransmitHandshake(peer *Peer) {
- if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+ if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes {
peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
@@ -94,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
- atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
- peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+ peer.timers.handshakeAttempts.Add(1)
+ peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1)
/* We clear the endpoint src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(true)
}
@@ -110,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
- if peer.timers.needAnotherKeepalive.Get() {
- peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timers.needAnotherKeepalive.Load() {
+ peer.timers.needAnotherKeepalive.Store(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
@@ -121,13 +119,8 @@ func expiredSendKeepalive(peer *Peer) {
func expiredNewHandshake(peer *Peer) {
peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint src address, in case this is the cause of trouble. */
- peer.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.Unlock()
+ peer.markEndpointSrcForClearing()
peer.SendHandshakeInitiation(false)
-
}
func expiredZeroKeyMaterial(peer *Peer) {
@@ -136,7 +129,7 @@ func expiredZeroKeyMaterial(peer *Peer) {
}
func expiredPersistentKeepalive(peer *Peer) {
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ if peer.persistentKeepaliveInterval.Load() > 0 {
peer.SendKeepalive()
}
}
@@ -144,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
- peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -154,7 +147,7 @@ func (peer *Peer) timersDataReceived() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
- peer.timers.needAnotherKeepalive.Set(true)
+ peer.timers.needAnotherKeepalive.Store(true)
}
}
}
@@ -176,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
- peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs)))
}
}
@@ -185,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.lastHandshakeNano.Store(time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
@@ -199,7 +192,7 @@ func (peer *Peer) timersSessionDerived() {
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
- keepalive := atomic.LoadUint32(&peer.persistentKeepaliveInterval)
+ keepalive := peer.persistentKeepaliveInterval.Load()
if keepalive > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second)
}
@@ -214,9 +207,9 @@ func (peer *Peer) timersInit() {
}
func (peer *Peer) timersStart() {
- atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
- peer.timers.sentLastMinuteHandshake.Set(false)
- peer.timers.needAnotherKeepalive.Set(false)
+ peer.timers.handshakeAttempts.Store(0)
+ peer.timers.sentLastMinuteHandshake.Store(false)
+ peer.timers.needAnotherKeepalive.Store(false)
}
func (peer *Peer) timersStop() {
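
The jitter swap in this file is a contention fix with a known trade-off: the package-level math/rand functions serialize on a global mutex, whereas runtime.fastrandn is per-thread and lock-free, and handshake jitter needs to be cheap, not cryptographically strong. The cost is the go:linkname binding (hence the blank unsafe import) to an unexported runtime symbol, which can break across Go releases.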
diff --git a/device/tun.go b/device/tun.go
index 4af9548..2a2ace9 100644
--- a/device/tun.go
+++ b/device/tun.go
@@ -1,13 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"fmt"
- "sync/atomic"
"golang.zx2c4.com/wireguard/tun"
)
@@ -33,7 +32,7 @@ func (device *Device) RoutineTUNEventReader() {
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
mtu = MaxContentSize
}
- old := atomic.SwapInt32(&device.tun.mtu, int32(mtu))
+ old := device.tun.mtu.Swap(int32(mtu))
if int(old) != mtu {
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
}
diff --git a/device/uapi.go b/device/uapi.go
index 66ecd48..d81dae3 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -12,10 +12,10 @@ import (
"fmt"
"io"
"net"
+ "net/netip"
"strconv"
"strings"
"sync"
- "sync/atomic"
"time"
"golang.zx2c4.com/wireguard/ipc"
@@ -38,12 +38,12 @@ func (s IPCError) ErrorCode() int64 {
return s.code
}
-func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError {
+func ipcErrorf(code int64, msg string, args ...any) *IPCError {
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
}
var byteBufferPool = &sync.Pool{
- New: func() interface{} { return new(bytes.Buffer) },
+ New: func() any { return new(bytes.Buffer) },
}
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
@@ -55,7 +55,7 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
buf := byteBufferPool.Get().(*bytes.Buffer)
buf.Reset()
defer byteBufferPool.Put(buf)
- sendf := func(format string, args ...interface{}) {
+ sendf := func(format string, args ...any) {
fmt.Fprintf(buf, format, args...)
buf.WriteByte('\n')
}
@@ -72,7 +72,6 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
}
func() {
-
// lock required resources
device.net.RLock()
@@ -98,31 +97,31 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- // serialize each peer state
-
for _, peer := range device.peers.keyMap {
- peer.RLock()
- defer peer.RUnlock()
-
+ // Serialize peer state.
+ peer.handshake.mutex.RLock()
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
+ peer.handshake.mutex.RUnlock()
sendf("protocol_version=1")
- if peer.endpoint != nil {
- sendf("endpoint=%s", peer.endpoint.DstToString())
+ peer.endpoint.Lock()
+ if peer.endpoint.val != nil {
+ sendf("endpoint=%s", peer.endpoint.val.DstToString())
}
+ peer.endpoint.Unlock()
- nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ nano := peer.lastHandshakeNano.Load()
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
sendf("last_handshake_time_sec=%d", secs)
sendf("last_handshake_time_nsec=%d", nano)
- sendf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))
- sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
- sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
+ sendf("tx_bytes=%d", peer.txBytes.Load())
+ sendf("rx_bytes=%d", peer.rxBytes.Load())
+ sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
- device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
- sendf("allowed_ip=%s/%d", ip.String(), cidr)
+ device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
+ sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -156,14 +155,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
+ peer.handlePostConfig()
return nil
}
- parts := strings.Split(line, "=")
- if len(parts) != 2 {
- return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q, found %d =-separated parts, want 2", line, len(parts))
+ key, value, ok := strings.Cut(line, "=")
+ if !ok {
+ return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line)
}
- key := parts[0]
- value := parts[1]
if key == "public_key" {
if deviceConfig {
@@ -254,10 +252,21 @@ type ipcSetPeer struct {
*Peer // Peer is the current peer being operated on
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
created bool // created reports whether this is a newly created peer
+ pkaOn bool // pkaOn reports whether the peer had persistent keepalive turned on
}
func (peer *ipcSetPeer) handlePostConfig() {
- if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() {
+ if peer.Peer == nil || peer.dummy {
+ return
+ }
+ if peer.created {
+ peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
+ }
+ if peer.device.isUp() {
+ peer.Start()
+ if peer.pkaOn {
+ peer.SendKeepalive()
+ }
peer.SendStagedPackets()
}
}
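
Two stdlib modernizations recur in this file's parsing hunks: strings.Cut replaces the Split-into-two-parts dance above, and netip.ParsePrefix replaces net.ParseCIDR plus mask arithmetic in the allowed_ip case below. A self-contained sketch of both, with an illustrative input line:

    package main

    import (
        "fmt"
        "net/netip"
        "strings"
    )

    func main() {
        // Illustrative line; uapi.go reads these from the IPC stream.
        key, value, ok := strings.Cut("allowed_ip=192.168.4.0/24", "=")
        if !ok {
            return // uapi.go turns this into an IpcErrorProtocol
        }
        prefix, err := netip.ParsePrefix(value) // no Mask.Size() needed
        fmt.Println(key, prefix, err)           // allowed_ip 192.168.4.0/24 <nil>
    }
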
@@ -334,9 +343,9 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}
- peer.Lock()
- defer peer.Unlock()
- peer.endpoint = endpoint
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ peer.endpoint.val = endpoint
case "persistent_keepalive_interval":
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
@@ -346,17 +355,10 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
}
- old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs))
+ old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
// Send an immediate keepalive if we're turning it on and it wasn't on before.
- if old == 0 && secs != 0 {
- if err != nil {
- return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
- }
- if device.isUp() && !peer.dummy {
- peer.SendKeepalive()
- }
- }
+ peer.pkaOn = old == 0 && secs != 0
case "replace_allowed_ips":
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
@@ -370,16 +372,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
-
- _, network, err := net.ParseCIDR(value)
+ prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
- ones, _ := network.Mask.Size()
- device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
+ device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {