aboutsummaryrefslogtreecommitdiffstats
path: root/device/device.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--device/device.go45
1 files changed, 29 insertions, 16 deletions
diff --git a/device/device.go b/device/device.go
index f96e277..83c33ee 100644
--- a/device/device.go
+++ b/device/device.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -68,9 +68,11 @@ type Device struct {
cookieChecker CookieChecker
pool struct {
- messageBuffers *WaitPool
- inboundElements *WaitPool
- outboundElements *WaitPool
+ inboundElementsContainer *WaitPool
+ outboundElementsContainer *WaitPool
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
@@ -265,7 +267,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
expiredPeers = append(expiredPeers, peer)
}
@@ -295,6 +297,7 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.indexTable.Init()
+
device.PopulatePools()
// create queues
@@ -322,6 +325,19 @@ func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
return device
}
+// BatchSize returns the BatchSize for the device as a whole which is the max of
+// the bind batch size and the tun batch size. The batch size reported by device
+// is the size used to construct memory pools, and is the allowed batch size for
+// the lifetime of the device.
+func (device *Device) BatchSize() int {
+ size := device.net.bind.BatchSize()
+ dSize := device.tun.device.BatchSize()
+ if size < dSize {
+ size = dSize
+ }
+ return size
+}
+
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
@@ -354,6 +370,8 @@ func (device *Device) RemoveAllPeers() {
func (device *Device) Close() {
device.state.Lock()
defer device.state.Unlock()
+ device.ipcMutex.Lock()
+ defer device.ipcMutex.Unlock()
if device.isClosed() {
return
}
@@ -443,11 +461,7 @@ func (device *Device) BindSetMark(mark uint32) error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -472,11 +486,13 @@ func (device *Device) BindUpdate() error {
var err error
var recvFns []conn.ReceiveFunc
netc := &device.net
+
recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
}
+
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
@@ -495,11 +511,7 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ peer.markEndpointSrcForClearing()
}
device.peers.RUnlock()
@@ -507,8 +519,9 @@ func (device *Device) BindUpdate() error {
device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ batchSize := netc.bind.BatchSize()
for _, fn := range recvFns {
- go device.RoutineReceiveIncoming(fn)
+ go device.RoutineReceiveIncoming(batchSize, fn)
}
device.log.Verbosef("UDP bind has been updated")