diff options
Diffstat (limited to 'device/pools.go')
-rw-r--r-- | device/pools.go | 157 |
1 files changed, 94 insertions, 63 deletions
diff --git a/device/pools.go b/device/pools.go index 98f4ef1..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,89 +1,120 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -import "sync" +import ( + "sync" + "sync/atomic" +) -func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count atomic.Uint32 + max uint32 +} + +func NewWaitPool(max uint32, new func() any) *WaitPool { + p := &WaitPool{pool: sync.Pool{New: new}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +func (p *WaitPool) Get() any { + if p.max != 0 { + p.lock.Lock() + for p.count.Load() >= p.max { + p.cond.Wait() } + p.count.Add(1) + p.lock.Unlock() } + return p.pool.Get() } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan +func (p *WaitPool) Put(x any) { + p.pool.Put(x) + if p.max == 0 { + return } + p.count.Add(^uint32(0)) + p.cond.Signal() } -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } +func (device *Device) PopulatePools() { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} + }) + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new([MaxMessageSize]byte) + }) + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueInboundElement) + }) + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueOutboundElement) + }) } -func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) +} + +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElement(msg *QueueInboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(msg) - } else { - device.pool.inboundElementReuseChan <- msg +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + return device.pool.inboundElements.Get().(*QueueInboundElement) +} + +func (device *Device) PutInboundElement(elem *QueueInboundElement) { + elem.clearPointers() + device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } + return device.pool.outboundElements.Get().(*QueueOutboundElement) } -func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(msg) - } else { - device.pool.outboundElementReuseChan <- msg - } +func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + elem.clearPointers() + device.pool.outboundElements.Put(elem) } |