aboutsummaryrefslogtreecommitdiffstats
path: root/device/peer.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/peer.go')
-rw-r--r--device/peer.go134
1 files changed, 75 insertions, 59 deletions
diff --git a/device/peer.go b/device/peer.go
index c8b825d..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
@@ -16,36 +16,31 @@ import (
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint conn.Endpoint
- stopping sync.WaitGroup // routines pending stop
-
- // These fields are accessed with atomic operations, which must be
- // 64-bit aligned even on 32-bit platforms. Go guarantees that an
- // allocated struct will be 64-bit aligned. So we place
- // atomically-accessed fields up front, so that they can share in
- // this alignment before smaller fields throw it off.
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes send to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nano seconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
- disableRoaming bool
-
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
state struct {
@@ -53,14 +48,14 @@ type Peer struct {
}
queue struct {
- staged chan *QueueOutboundElement // staged packets before a handshake is available
- outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
- inbound *autodrainingInboundQueue // sequential ordering of tun writing
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
cookieGenerator CookieGenerator
trieEntries list.List
- persistentKeepaliveInterval uint32 // accessed atomically
+ persistentKeepaliveInterval atomic.Uint32
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
@@ -82,14 +77,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.queue.outbound = newAutodrainingOutboundQueue(device)
peer.queue.inbound = newAutodrainingInboundQueue(device)
- peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
_, ok := device.peers.keyMap[pk]
@@ -100,26 +93,27 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
- peer.endpoint = nil
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
+
+ // init timers
+ peer.timersInit()
// add
device.peers.keyMap[pk] = peer
- // start peer
- peer.timersInit()
- if peer.device.isUp() {
- peer.Start()
- }
-
return peer, nil
}
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
@@ -127,16 +121,25 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
@@ -177,7 +180,7 @@ func (peer *Peer) Start() {
peer.state.Lock()
defer peer.state.Unlock()
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
@@ -198,10 +201,14 @@ func (peer *Peer) Start() {
device.flushInboundQueue(peer.queue.inbound)
device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
- peer.isRunning.Set(true)
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
+
+ peer.isRunning.Store(true)
}
func (peer *Peer) ZeroAndFlushAll() {
@@ -213,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.loadNext())
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.storeNext(nil)
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -241,11 +248,10 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- atomic.StoreUint64(&keypairs.current.sendNonce, RejectAfterMessages)
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- next := keypairs.loadNext()
- atomic.StoreUint64(&next.sendNonce, RejectAfterMessages)
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
@@ -271,10 +277,20 @@ func (peer *Peer) Stop() {
}
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
- if peer.disableRoaming {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
+
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}