aboutsummaryrefslogtreecommitdiffstats
path: root/device/pools.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/pools.go')
-rw-r--r--device/pools.go157
1 files changed, 94 insertions, 63 deletions
diff --git a/device/pools.go b/device/pools.go
index 98f4ef1..94f3dc7 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -1,89 +1,120 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
-import "sync"
+import (
+ "sync"
+ "sync/atomic"
+)
-func (device *Device) PopulatePools() {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool = &sync.Pool{
- New: func() interface{} {
- return new([MaxMessageSize]byte)
- },
- }
- device.pool.inboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueInboundElement)
- },
- }
- device.pool.outboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueOutboundElement)
- },
- }
- } else {
- device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
- }
- device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.inboundElementReuseChan <- new(QueueInboundElement)
- }
- device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
- device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+type WaitPool struct {
+ pool sync.Pool
+ cond sync.Cond
+ lock sync.Mutex
+ count atomic.Uint32
+ max uint32
+}
+
+func NewWaitPool(max uint32, new func() any) *WaitPool {
+ p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+ p.cond = sync.Cond{L: &p.lock}
+ return p
+}
+
+func (p *WaitPool) Get() any {
+ if p.max != 0 {
+ p.lock.Lock()
+ for p.count.Load() >= p.max {
+ p.cond.Wait()
}
+ p.count.Add(1)
+ p.lock.Unlock()
}
+ return p.pool.Get()
}
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
- } else {
- return <-device.pool.messageBufferReuseChan
+func (p *WaitPool) Put(x any) {
+ p.pool.Put(x)
+ if p.max == 0 {
+ return
}
+ p.count.Add(^uint32(0))
+ p.cond.Signal()
}
-func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool.Put(msg)
- } else {
- device.pool.messageBufferReuseChan <- msg
- }
+func (device *Device) PopulatePools() {
+ device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueInboundElement, 0, device.BatchSize())
+ return &QueueInboundElementsContainer{elems: s}
+ })
+ device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ s := make([]*QueueOutboundElement, 0, device.BatchSize())
+ return &QueueOutboundElementsContainer{elems: s}
+ })
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new([MaxMessageSize]byte)
+ })
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueInboundElement)
+ })
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
+ return new(QueueOutboundElement)
+ })
}
-func (device *Device) GetInboundElement() *QueueInboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.inboundElementPool.Get().(*QueueInboundElement)
- } else {
- return <-device.pool.inboundElementReuseChan
+func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
+ c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
+}
+
+func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.inboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
+ c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
+ c.Mutex = sync.Mutex{}
+ return c
}
-func (device *Device) PutInboundElement(msg *QueueInboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.inboundElementPool.Put(msg)
- } else {
- device.pool.inboundElementReuseChan <- msg
+func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
+ for i := range c.elems {
+ c.elems[i] = nil
}
+ c.elems = c.elems[:0]
+ device.pool.outboundElementsContainer.Put(c)
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+ device.pool.messageBuffers.Put(msg)
+}
+
+func (device *Device) GetInboundElement() *QueueInboundElement {
+ return device.pool.inboundElements.Get().(*QueueInboundElement)
+}
+
+func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ elem.clearPointers()
+ device.pool.inboundElements.Put(elem)
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
- } else {
- return <-device.pool.outboundElementReuseChan
- }
+ return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
-func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.outboundElementPool.Put(msg)
- } else {
- device.pool.outboundElementReuseChan <- msg
- }
+func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
+ elem.clearPointers()
+ device.pool.outboundElements.Put(elem)
}