aboutsummaryrefslogtreecommitdiffstats
path: root/device
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-10-02 14:48:28 -0700
committerJason A. Donenfeld <Jason@zx2c4.com>2023-10-10 15:07:36 +0200
commit1ec454f253c068f74ba7a7aea34546c9819493c0 (patch)
treeb7f3af5cb9487c892cc4d2390c10f0ca7f5e86b7 /device
parenttun: reduce redundant checksumming in tcpGRO() (diff)
downloadwireguard-go-1ec454f253c068f74ba7a7aea34546c9819493c0.tar.xz
wireguard-go-1ec454f253c068f74ba7a7aea34546c9819493c0.zip
device: move Queue{In,Out}boundElement Mutex to container type
Queue{In,Out}boundElement locking can contribute to significant overhead via sync.Mutex.lockSlow() in some environments. These types are passed throughout the device package as elements in a slice, so move the per-element Mutex to a container around the slice. Reviewed-by: Maisem Ali <maisem@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to '')
-rw-r--r--device/channels.go32
-rw-r--r--device/device.go10
-rw-r--r--device/peer.go8
-rw-r--r--device/pools.go44
-rw-r--r--device/receive.go43
-rw-r--r--device/send.go95
6 files changed, 121 insertions, 111 deletions
diff --git a/device/channels.go b/device/channels.go
index 40ee5c9..e526f6b 100644
--- a/device/channels.go
+++ b/device/channels.go
@@ -19,13 +19,13 @@ import (
// call wg.Done to remove the initial reference.
// When the refcount hits 0, the queue's channel is closed.
type outboundQueue struct {
- c chan *[]*QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
wg sync.WaitGroup
}
func newOutboundQueue() *outboundQueue {
q := &outboundQueue{
- c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
q.wg.Add(1)
go func() {
@@ -37,13 +37,13 @@ func newOutboundQueue() *outboundQueue {
// A inboundQueue is similar to an outboundQueue; see those docs.
type inboundQueue struct {
- c chan *[]*QueueInboundElement
+ c chan *QueueInboundElementsContainer
wg sync.WaitGroup
}
func newInboundQueue() *inboundQueue {
q := &inboundQueue{
- c: make(chan *[]*QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
q.wg.Add(1)
go func() {
@@ -72,7 +72,7 @@ func newHandshakeQueue() *handshakeQueue {
}
type autodrainingInboundQueue struct {
- c chan *[]*QueueInboundElement
+ c chan *QueueInboundElementsContainer
}
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
@@ -81,7 +81,7 @@ type autodrainingInboundQueue struct {
// some other means, such as sending a sentinel nil values.
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
q := &autodrainingInboundQueue{
- c: make(chan *[]*QueueInboundElement, QueueInboundSize),
+ c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
}
runtime.SetFinalizer(q, device.flushInboundQueue)
return q
@@ -90,13 +90,13 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
for {
select {
- case elems := <-q.c:
- for _, elem := range *elems {
- elem.Lock()
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
- device.PutInboundElementsSlice(elems)
+ device.PutInboundElementsContainer(elemsContainer)
default:
return
}
@@ -104,7 +104,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
}
type autodrainingOutboundQueue struct {
- c chan *[]*QueueOutboundElement
+ c chan *QueueOutboundElementsContainer
}
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
@@ -114,7 +114,7 @@ type autodrainingOutboundQueue struct {
// All sends to the channel must be best-effort, because there may be no receivers.
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
q := &autodrainingOutboundQueue{
- c: make(chan *[]*QueueOutboundElement, QueueOutboundSize),
+ c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
}
runtime.SetFinalizer(q, device.flushOutboundQueue)
return q
@@ -123,13 +123,13 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
for {
select {
- case elems := <-q.c:
- for _, elem := range *elems {
- elem.Lock()
+ case elemsContainer := <-q.c:
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
- device.PutOutboundElementsSlice(elems)
+ device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
diff --git a/device/device.go b/device/device.go
index 1af9fe0..f9557a0 100644
--- a/device/device.go
+++ b/device/device.go
@@ -68,11 +68,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
- outboundElementsSlice *WaitPool
- inboundElementsSlice *WaitPool
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
diff --git a/device/peer.go b/device/peer.go
index 0ac4896..2fb5da6 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -45,9 +45,9 @@ type Peer struct {
}
queue struct {
- staged chan *[]*QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
@@ -81,7 +81,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *[]*QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
diff --git a/device/pools.go b/device/pools.go
index 02a5d6a..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -46,13 +46,13 @@ func (p *WaitPool) Put(x any) {
}
func (device *Device) PopulatePools() {
- device.pool.outboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
- s := make([]*QueueOutboundElement, 0, device.BatchSize())
- return &s
- })
- device.pool.inboundElementsSlice = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
- return &s
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
@@ -65,28 +65,32 @@ func (device *Device) PopulatePools() {
})
}
-func (device *Device) GetOutboundElementsSlice() *[]*QueueOutboundElement {
- return device.pool.outboundElementsSlice.Get().(*[]*QueueOutboundElement)
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
}
-func (device *Device) PutOutboundElementsSlice(s *[]*QueueOutboundElement) {
- for i := range *s {
- (*s)[i] = nil
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
- *s = (*s)[:0]
- device.pool.outboundElementsSlice.Put(s)
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
}
-func (device *Device) GetInboundElementsSlice() *[]*QueueInboundElement {
- return device.pool.inboundElementsSlice.Get().(*[]*QueueInboundElement)
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
}
-func (device *Device) PutInboundElementsSlice(s *[]*QueueInboundElement) {
- for i := range *s {
- (*s)[i] = nil
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
- *s = (*s)[:0]
- device.pool.inboundElementsSlice.Put(s)
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
diff --git a/device/receive.go b/device/receive.go
index f0f37a1..4b32dc5 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -35,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
+}
+
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
@@ -87,7 +91,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
- elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
@@ -170,15 +174,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
- elemsForPeer = device.GetInboundElementsSlice()
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
- *elemsForPeer = append(*elemsForPeer, elem)
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
@@ -217,16 +220,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
default:
}
}
- for peer, elems := range elemsByPeer {
+ for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
- peer.queue.inbound.c <- elems
- device.queue.decryption.c <- elems
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
} else {
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
- device.PutInboundElementsSlice(elems)
+ device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
@@ -239,8 +242,8 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
- for elems := range device.queue.decryption.c {
- for _, elem := range *elems {
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
@@ -259,8 +262,8 @@ func (device *Device) RoutineDecryption(id int) {
if err != nil {
elem.packet = nil
}
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
@@ -437,12 +440,12 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
- for elems := range peer.queue.inbound.c {
- if elems == nil {
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
}
- for _, elem := range *elems {
- elem.Lock()
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
@@ -515,11 +518,11 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
- device.PutInboundElementsSlice(elems)
+ device.PutInboundElementsContainer(elemsContainer)
}
}
diff --git a/device/send.go b/device/send.go
index e838c4e..769720a 100644
--- a/device/send.go
+++ b/device/send.go
@@ -46,7 +46,6 @@ import (
*/
type QueueOutboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
@@ -54,10 +53,14 @@ type QueueOutboundElement struct {
peer *Peer // related peer
}
+type QueueOutboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueOutboundElement
+}
+
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.buffer = device.GetMessageBuffer()
- elem.Mutex = sync.Mutex{}
elem.nonce = 0
// keypair and peer were cleared (if necessary) by clearPointers.
return elem
@@ -79,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() {
func (peer *Peer) SendKeepalive() {
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
elem := peer.device.NewOutboundElement()
- elems := peer.device.GetOutboundElementsSlice()
- *elems = append(*elems, elem)
+ elemsContainer := peer.device.GetOutboundElementsContainer()
+ elemsContainer.elems = append(elemsContainer.elems, elem)
select {
- case peer.queue.staged <- elems:
+ case peer.queue.staged <- elemsContainer:
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
- peer.device.PutOutboundElementsSlice(elems)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
}
peer.SendStagedPackets()
@@ -219,7 +222,7 @@ func (device *Device) RoutineReadFromTUN() {
readErr error
elems = make([]*QueueOutboundElement, batchSize)
bufs = make([][]byte, batchSize)
- elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize)
+ elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
count = 0
sizes = make([]int, batchSize)
offset = MessageTransportHeaderSize
@@ -276,10 +279,10 @@ func (device *Device) RoutineReadFromTUN() {
}
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
- elemsForPeer = device.GetOutboundElementsSlice()
+ elemsForPeer = device.GetOutboundElementsContainer()
elemsByPeer[peer] = elemsForPeer
}
- *elemsForPeer = append(*elemsForPeer, elem)
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
elems[i] = device.NewOutboundElement()
bufs[i] = elems[i].buffer[:]
}
@@ -289,11 +292,11 @@ func (device *Device) RoutineReadFromTUN() {
peer.StagePackets(elemsForPeer)
peer.SendStagedPackets()
} else {
- for _, elem := range *elemsForPeer {
+ for _, elem := range elemsForPeer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
- device.PutOutboundElementsSlice(elemsForPeer)
+ device.PutOutboundElementsContainer(elemsForPeer)
}
delete(elemsByPeer, peer)
}
@@ -317,7 +320,7 @@ func (device *Device) RoutineReadFromTUN() {
}
}
-func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
+func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
for {
select {
case peer.queue.staged <- elems:
@@ -326,11 +329,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) {
}
select {
case tooOld := <-peer.queue.staged:
- for _, elem := range *tooOld {
+ for _, elem := range tooOld.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
- peer.device.PutOutboundElementsSlice(tooOld)
+ peer.device.PutOutboundElementsContainer(tooOld)
default:
}
}
@@ -349,52 +352,52 @@ top:
}
for {
- var elemsOOO *[]*QueueOutboundElement
+ var elemsContainerOOO *QueueOutboundElementsContainer
select {
- case elems := <-peer.queue.staged:
+ case elemsContainer := <-peer.queue.staged:
i := 0
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
elem.peer = peer
elem.nonce = keypair.sendNonce.Add(1) - 1
if elem.nonce >= RejectAfterMessages {
keypair.sendNonce.Store(RejectAfterMessages)
- if elemsOOO == nil {
- elemsOOO = peer.device.GetOutboundElementsSlice()
+ if elemsContainerOOO == nil {
+ elemsContainerOOO = peer.device.GetOutboundElementsContainer()
}
- *elemsOOO = append(*elemsOOO, elem)
+ elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
continue
} else {
- (*elems)[i] = elem
+ elemsContainer.elems[i] = elem
i++
}
elem.keypair = keypair
- elem.Lock()
}
- *elems = (*elems)[:i]
+ elemsContainer.Lock()
+ elemsContainer.elems = elemsContainer.elems[:i]
- if elemsOOO != nil {
- peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans
+ if elemsContainerOOO != nil {
+ peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
}
- if len(*elems) == 0 {
- peer.device.PutOutboundElementsSlice(elems)
+ if len(elemsContainer.elems) == 0 {
+ peer.device.PutOutboundElementsContainer(elemsContainer)
goto top
}
// add to parallel and sequential queue
if peer.isRunning.Load() {
- peer.queue.outbound.c <- elems
- peer.device.queue.encryption.c <- elems
+ peer.queue.outbound.c <- elemsContainer
+ peer.device.queue.encryption.c <- elemsContainer
} else {
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
- peer.device.PutOutboundElementsSlice(elems)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
}
- if elemsOOO != nil {
+ if elemsContainerOOO != nil {
goto top
}
default:
@@ -406,12 +409,12 @@ top:
func (peer *Peer) FlushStagedPackets() {
for {
select {
- case elems := <-peer.queue.staged:
- for _, elem := range *elems {
+ case elemsContainer := <-peer.queue.staged:
+ for _, elem := range elemsContainer.elems {
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
}
- peer.device.PutOutboundElementsSlice(elems)
+ peer.device.PutOutboundElementsContainer(elemsContainer)
default:
return
}
@@ -445,8 +448,8 @@ func (device *Device) RoutineEncryption(id int) {
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
device.log.Verbosef("Routine: encryption worker %d - started", id)
- for elems := range device.queue.encryption.c {
- for _, elem := range *elems {
+ for elemsContainer := range device.queue.encryption.c {
+ for _, elem := range elemsContainer.elems {
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
@@ -471,8 +474,8 @@ func (device *Device) RoutineEncryption(id int) {
elem.packet,
nil,
)
- elem.Unlock()
}
+ elemsContainer.Unlock()
}
}
@@ -486,9 +489,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
- for elems := range peer.queue.outbound.c {
+ for elemsContainer := range peer.queue.outbound.c {
bufs = bufs[:0]
- if elems == nil {
+ if elemsContainer == nil {
return
}
if !peer.isRunning.Load() {
@@ -498,16 +501,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
// The timers and SendBuffers code are resilient to a few stragglers.
// TODO: rework peer shutdown order to ensure
// that we never accidentally keep timers alive longer than necessary.
- for _, elem := range *elems {
- elem.Lock()
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
continue
}
dataSent := false
- for _, elem := range *elems {
- elem.Lock()
+ elemsContainer.Lock()
+ for _, elem := range elemsContainer.elems {
if len(elem.packet) != MessageKeepaliveSize {
dataSent = true
}
@@ -521,11 +524,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
if dataSent {
peer.timersDataSent()
}
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
- device.PutOutboundElementsSlice(elems)
+ device.PutOutboundElementsContainer(elemsContainer)
if err != nil {
var errGSO conn.ErrUDPGSODisabled
if errors.As(err, &errGSO) {