From 833597b585f460aaa17bad93ad59290ec282e77e Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 22 Sep 2018 06:29:02 +0200 Subject: More pooling --- device.go | 39 +++++---------------------- pools.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ receive.go | 40 +++++++++++++++------------ send.go | 34 ++++++++++++++++++----- 4 files changed, 148 insertions(+), 56 deletions(-) create mode 100644 pools.go diff --git a/device.go b/device.go index bbcd0fc..7cf9ba2 100644 --- a/device.go +++ b/device.go @@ -19,8 +19,6 @@ const ( DeviceRoutineNumberAdditional = 2 ) -var preallocatedBuffers = 0 - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -68,8 +66,12 @@ type Device struct { } pool struct { - messageBuffers *sync.Pool - reuseChan chan interface{} + messageBufferPool *sync.Pool + messageBufferReuseChan chan *[MaxMessageSize]byte + inboundElementPool *sync.Pool + inboundElementReuseChan chan *QueueInboundElement + outboundElementPool *sync.Pool + outboundElementReuseChan chan *QueueOutboundElement } queue struct { @@ -245,22 +247,6 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if preallocatedBuffers == 0 { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) - } else { - return (<-device.pool.reuseChan).(*[MaxMessageSize]byte) - } -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if preallocatedBuffers == 0 { - device.pool.messageBuffers.Put(msg) - } else { - device.pool.reuseChan <- msg - } -} - func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { device := new(Device) @@ -285,18 +271,7 @@ func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { device.indexTable.Init() device.allowedips.Reset() - if preallocatedBuffers == 0 { - device.pool.messageBuffers = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - } else { - device.pool.reuseChan = make(chan interface{}, preallocatedBuffers) - for i := 0; i < preallocatedBuffers; i += 1 { - device.pool.reuseChan <- new([MaxMessageSize]byte) - } - } + device.PopulatePools() // create queues diff --git a/pools.go b/pools.go new file mode 100644 index 0000000..fe219f4 --- /dev/null +++ b/pools.go @@ -0,0 +1,91 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved. + */ + +package main + +import "sync" + +var preallocatedBuffers = 0 + +func (device *Device) PopulatePools() { + if preallocatedBuffers == 0 { + device.pool.messageBufferPool = &sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + device.pool.inboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueInboundElement) + }, + } + device.pool.outboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueOutboundElement) + }, + } + } else { + device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, preallocatedBuffers) + for i := 0; i < preallocatedBuffers; i += 1 { + device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) + } + device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, preallocatedBuffers) + for i := 0; i < preallocatedBuffers; i += 1 { + device.pool.inboundElementReuseChan <- new(QueueInboundElement) + } + device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, preallocatedBuffers) + for i := 0; i < preallocatedBuffers; i += 1 { + device.pool.outboundElementReuseChan <- new(QueueOutboundElement) + } + } +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + if preallocatedBuffers == 0 { + return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) + } else { + return <-device.pool.messageBufferReuseChan + } +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + if preallocatedBuffers == 0 { + device.pool.messageBufferPool.Put(msg) + } else { + device.pool.messageBufferReuseChan <- msg + } +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + if preallocatedBuffers == 0 { + return device.pool.inboundElementPool.Get().(*QueueInboundElement) + } else { + return <-device.pool.inboundElementReuseChan + } +} + +func (device *Device) PutInboundElement(msg *QueueInboundElement) { + if preallocatedBuffers == 0 { + device.pool.inboundElementPool.Put(msg) + } else { + device.pool.inboundElementReuseChan <- msg + } +} + +func (device *Device) GetOutboundElement() *QueueOutboundElement { + if preallocatedBuffers == 0 { + return device.pool.outboundElementPool.Get().(*QueueOutboundElement) + } else { + return <-device.pool.outboundElementReuseChan + } +} + +func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { + if preallocatedBuffers == 0 { + device.pool.outboundElementPool.Put(msg) + } else { + device.pool.outboundElementReuseChan <- msg + } +} diff --git a/receive.go b/receive.go index 9bf3af3..ab86913 100644 --- a/receive.go +++ b/receive.go @@ -55,6 +55,7 @@ func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueIn return false } default: + device.PutInboundElement(element) return false } } @@ -168,15 +169,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { } // create work element - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keypair: keypair, - dropped: AtomicFalse, - endpoint: endpoint, - } + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffer + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.endpoint = endpoint + elem.counter = 0 + elem.mutex = sync.Mutex{} elem.mutex.Lock() // add to decryption queues @@ -246,6 +247,7 @@ func (device *Device) RoutineDecryption() { // check if dropped if elem.IsDropped() { + device.PutInboundElement(elem) continue } @@ -280,7 +282,6 @@ func (device *Device) RoutineDecryption() { elem.Drop() device.PutMessageBuffer(elem.buffer) elem.buffer = nil - elem.mutex.Unlock() } elem.mutex.Unlock() } @@ -487,12 +488,16 @@ func (peer *Peer) RoutineSequentialReceiver() { logDebug := device.log.Debug var elem *QueueInboundElement + var ok bool defer func() { logDebug.Println(peer, "- Routine: sequential receiver - stopped") peer.routines.stopping.Done() - if elem != nil && elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) + if elem != nil { + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) } }() @@ -501,8 +506,11 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.routines.starting.Done() for { - if elem != nil && elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) + if elem != nil { + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) } select { @@ -510,7 +518,7 @@ func (peer *Peer) RoutineSequentialReceiver() { case <-peer.routines.stop: return - case elem, ok := <-peer.queue.inbound: + case elem, ok = <-peer.queue.inbound: if !ok { return @@ -621,9 +629,7 @@ func (peer *Peer) RoutineSequentialReceiver() { offset := MessageTransportOffsetContent atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write( - elem.buffer[:offset+len(elem.packet)], - offset) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) if err != nil { logError.Println("Failed to write packet to TUN device:", err) } diff --git a/send.go b/send.go index 24e2f39..fa84043 100644 --- a/send.go +++ b/send.go @@ -52,10 +52,14 @@ type QueueOutboundElement struct { } func (device *Device) NewOutboundElement() *QueueOutboundElement { - return &QueueOutboundElement{ - dropped: AtomicFalse, - buffer: device.GetMessageBuffer(), - } + elem := device.GetOutboundElement() + elem.dropped = AtomicFalse + elem.buffer = device.GetMessageBuffer() + elem.mutex = sync.Mutex{} + elem.nonce = 0 + elem.keypair = nil + elem.peer = nil + return elem } func (elem *QueueOutboundElement) Drop() { @@ -75,6 +79,7 @@ func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundEle select { case old := <-queue: device.PutMessageBuffer(old.buffer) + device.PutOutboundElement(old) default: } } @@ -94,6 +99,7 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, } default: element.peer.device.PutMessageBuffer(element.buffer) + element.peer.device.PutOutboundElement(element) } } @@ -111,6 +117,7 @@ func (peer *Peer) SendKeepalive() bool { return true default: peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) return false } } @@ -236,8 +243,6 @@ func (peer *Peer) keepKeyFreshSending() { */ func (device *Device) RoutineReadFromTUN() { - elem := device.NewOutboundElement() - logDebug := device.log.Debug logError := device.log.Error @@ -249,7 +254,14 @@ func (device *Device) RoutineReadFromTUN() { logDebug.Println("Routine: TUN reader - started") device.state.starting.Done() + var elem *QueueOutboundElement + for { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + elem = device.NewOutboundElement() // read packet @@ -262,6 +274,7 @@ func (device *Device) RoutineReadFromTUN() { device.Close() } device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) return } @@ -304,7 +317,7 @@ func (device *Device) RoutineReadFromTUN() { peer.SendHandshakeInitiation(false) } addToNonceQueue(peer.queue.nonce, elem, device) - elem = device.NewOutboundElement() + elem = nil } } } @@ -339,6 +352,7 @@ func (peer *Peer) RoutineNonce() { select { case elem := <-peer.queue.nonce: device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) default: return } @@ -399,11 +413,13 @@ func (peer *Peer) RoutineNonce() { case <-peer.signals.flushNonceQueue: device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) flush() goto NextPacket case <-peer.routines.stop: device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) return } } @@ -419,6 +435,7 @@ func (peer *Peer) RoutineNonce() { if elem.nonce >= RejectAfterMessages { atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) goto NextPacket } @@ -468,6 +485,7 @@ func (device *Device) RoutineEncryption() { // check if dropped if elem.IsDropped() { + device.PutOutboundElement(elem) continue } @@ -544,6 +562,7 @@ func (peer *Peer) RoutineSequentialSender() { elem.mutex.Lock() if elem.IsDropped() { + device.PutOutboundElement(elem) continue } @@ -555,6 +574,7 @@ func (peer *Peer) RoutineSequentialSender() { length := uint64(len(elem.packet)) err := peer.SendBuffer(elem.packet) device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) if err != nil { logError.Println(peer, "- Failed to send data packet", err) continue -- cgit v1.2.3-59-g8ed1b