diff options
Diffstat (limited to 'device/pools.go')
-rw-r--r-- | device/pools.go | 58 |
1 files changed, 47 insertions, 11 deletions
diff --git a/device/pools.go b/device/pools.go index f1d1fa0..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -14,49 +14,85 @@ type WaitPool struct { pool sync.Pool cond sync.Cond lock sync.Mutex - count uint32 + count atomic.Uint32 max uint32 } -func NewWaitPool(max uint32, new func() interface{}) *WaitPool { +func NewWaitPool(max uint32, new func() any) *WaitPool { p := &WaitPool{pool: sync.Pool{New: new}, max: max} p.cond = sync.Cond{L: &p.lock} return p } -func (p *WaitPool) Get() interface{} { +func (p *WaitPool) Get() any { if p.max != 0 { p.lock.Lock() - for atomic.LoadUint32(&p.count) >= p.max { + for p.count.Load() >= p.max { p.cond.Wait() } - atomic.AddUint32(&p.count, 1) + p.count.Add(1) p.lock.Unlock() } return p.pool.Get() } -func (p *WaitPool) Put(x interface{}) { +func (p *WaitPool) Put(x any) { p.pool.Put(x) if p.max == 0 { return } - atomic.AddUint32(&p.count, ^uint32(0)) + p.count.Add(^uint32(0)) p.cond.Signal() } func (device *Device) PopulatePools() { - device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} + }) + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new([MaxMessageSize]byte) }) - device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueInboundElement) }) - device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} { + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { return new(QueueOutboundElement) }) } +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil + } + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) +} + +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil + } + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) +} + func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) } |