Diffstat (limited to 'device/peer.go')
-rw-r--r--  device/peer.go  267
1 file changed, 131 insertions, 136 deletions
diff --git a/device/peer.go b/device/peer.go
index 91d975a..47a2f14 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -1,37 +1,35 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package device
import (
- "encoding/base64"
+ "container/list"
"errors"
- "fmt"
"sync"
"sync/atomic"
"time"
-)
-const (
- PeerRoutineNumber = 3
+ "golang.zx2c4.com/wireguard/conn"
)
type Peer struct {
- isRunning AtomicBool
- sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
- keypairs Keypairs
- handshake Handshake
- device *Device
- endpoint Endpoint
- persistentKeepaliveInterval uint16
-
- // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
- stats struct {
- txBytes uint64 // bytes send to peer (endpoint)
- rxBytes uint64 // bytes received from peer
- lastHandshakeNano int64 // nano seconds since epoch
+ isRunning atomic.Bool
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ stopping sync.WaitGroup // routines pending stop
+ txBytes atomic.Uint64 // bytes sent to peer (endpoint)
+ rxBytes atomic.Uint64 // bytes received from peer
+ lastHandshakeNano atomic.Int64 // nanoseconds since epoch
+
+ endpoint struct {
+ sync.Mutex
+ val conn.Endpoint
+ clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
+ disableRoaming bool
}
timers struct {
@@ -40,40 +38,32 @@ type Peer struct {
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
- handshakeAttempts uint32
- needAnotherKeepalive AtomicBool
- sentLastMinuteHandshake AtomicBool
+ handshakeAttempts atomic.Uint32
+ needAnotherKeepalive atomic.Bool
+ sentLastMinuteHandshake atomic.Bool
}
- signals struct {
- newKeypairArrived chan struct{}
- flushNonceQueue chan struct{}
+ state struct {
+ sync.Mutex // protects against concurrent Start/Stop
}
queue struct {
- nonce chan *QueueOutboundElement // nonce / pre-handshake queue
- outbound chan *QueueOutboundElement // sequential ordering of work
- inbound chan *QueueInboundElement // sequential ordering of work
- packetInNonceQueueIsAwaitingKey AtomicBool
- }
-
- routines struct {
- sync.Mutex // held when stopping / starting routines
- starting sync.WaitGroup // routines pending start
- stopping sync.WaitGroup // routines pending stop
- stop chan struct{} // size 0, stop all go routines in peer
+ staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
+ outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
+ inbound *autodrainingInboundQueue // sequential ordering of tun writing
}
- cookieGenerator CookieGenerator
+ cookieGenerator CookieGenerator
+ trieEntries list.List
+ persistentKeepaliveInterval atomic.Uint32
}
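
The struct reshuffle above replaces plain integers, whose 64-bit alignment had to be maintained by hand on 32-bit platforms, with the typed atomics added in Go 1.19. A minimal sketch of the two patterns side by side, using a hypothetical stats struct:

package main

import (
	"fmt"
	"sync/atomic"
)

// Old pattern: a bare uint64 that callers must only touch through
// atomic.AddUint64/LoadUint64, and that must be 64-bit aligned by hand.
type statsOld struct {
	txBytes uint64
}

// New pattern: atomic.Uint64 is always safely aligned and can only be
// accessed atomically, so the alignment comment disappears.
type statsNew struct {
	txBytes atomic.Uint64
}

func main() {
	var o statsOld
	atomic.AddUint64(&o.txBytes, 148)

	var n statsNew
	n.txBytes.Add(148)
	fmt.Println(atomic.LoadUint64(&o.txBytes), n.txBytes.Load())
}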
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
- if device.isClosed.Get() {
+ if device.isClosed() {
return nil, errors.New("device closed")
}
// lock resources
-
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -81,136 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
defer device.peers.Unlock()
// check if over limit
-
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
-
peer := new(Peer)
- peer.Lock()
- defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
- peer.isRunning.Set(false)
+ peer.queue.outbound = newAutodrainingOutboundQueue(device)
+ peer.queue.inbound = newAutodrainingInboundQueue(device)
+ peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
// map public key
-
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
// pre-compute DH
-
handshake := &peer.handshake
handshake.mutex.Lock()
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
- ssIsZero := isZero(handshake.precomputedStaticStatic[:])
+ handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.remoteStatic = pk
handshake.mutex.Unlock()
// reset endpoint
+ peer.endpoint.Lock()
+ peer.endpoint.val = nil
+ peer.endpoint.disableRoaming = false
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.Unlock()
- peer.endpoint = nil
-
- // conditionally add
-
- if !ssIsZero {
- device.peers.keyMap[pk] = peer
- } else {
- return nil, nil
- }
-
- // start peer
+ // init timers
+ peer.timersInit()
- if peer.device.isUp.Get() {
- peer.Start()
- }
+ // add
+ device.peers.keyMap[pk] = peer
return peer, nil
}
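
With the conditional add gone, every structurally valid key is registered and a low-order public key no longer produces the surprising nil, nil result; the error from sharedSecret is deliberately discarded. A hedged usage sketch (the caller, dev, and publicKey are hypothetical, assuming an initialized *Device):

	peer, err := dev.NewPeer(publicKey)
	if err != nil {
		// One of: "device closed", "too many peers", "adding existing peer".
		return err
	}
	// Unlike the old code, peer is never nil on a nil error, and it stays
	// idle until Start is called (no more auto-start when the device is up).
	peer.Start()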
-func (peer *Peer) SendBuffer(buffer []byte) error {
+func (peer *Peer) SendBuffers(buffers [][]byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
- if peer.device.net.bind == nil {
- return errors.New("no bind")
+ if peer.device.isClosed() {
+ return nil
}
- peer.RLock()
- defer peer.RUnlock()
-
- if peer.endpoint == nil {
+ peer.endpoint.Lock()
+ endpoint := peer.endpoint.val
+ if endpoint == nil {
+ peer.endpoint.Unlock()
return errors.New("no known endpoint for peer")
}
+ if peer.endpoint.clearSrcOnTx {
+ endpoint.ClearSrc()
+ peer.endpoint.clearSrcOnTx = false
+ }
+ peer.endpoint.Unlock()
- err := peer.device.net.bind.Send(buffer, peer.endpoint)
+ err := peer.device.net.bind.Send(buffers, endpoint)
if err == nil {
- atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
+ var totalLen uint64
+ for _, b := range buffers {
+ totalLen += uint64(len(b))
+ }
+ peer.txBytes.Add(totalLen)
}
return err
}
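
SendBuffers supersedes the single-packet SendBuffer so one bind.Send call can push a whole batch to the UDP socket (sendmmsg-style on Linux), and the tx counter is bumped once with the summed length. A self-contained sketch of just the accounting, with made-up packet sizes:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var txBytes atomic.Uint64
	buffers := [][]byte{make([]byte, 148), make([]byte, 96), make([]byte, 32)}

	// Sum once, add once: one atomic op per batch instead of per packet.
	var totalLen uint64
	for _, b := range buffers {
		totalLen += uint64(len(b))
	}
	txBytes.Add(totalLen)
	fmt.Println(txBytes.Load()) // 276
}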
func (peer *Peer) String() string {
- base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
- abbreviatedKey := "invalid"
- if len(base64Key) == 44 {
- abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
+ // The awful goo that follows is identical to:
+ //
+ // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
+ // return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ //
+ // except that it is considerably more efficient.
+ src := peer.handshake.remoteStatic
+ b64 := func(input byte) byte {
+ return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
}
- return fmt.Sprintf("peer(%s)", abbreviatedKey)
+ b := []byte("peer(____…____)")
+ const first = len("peer(")
+ const second = len("peer(____…")
+ b[first+0] = b64((src[0] >> 2) & 63)
+ b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
+ b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
+ b[first+3] = b64(src[2] & 63)
+ b[second+0] = b64(src[29] & 63)
+ b[second+1] = b64((src[30] >> 2) & 63)
+ b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
+ b[second+3] = b64((src[31] << 2) & 63)
+ return string(b)
}
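
In the branchless encoder, each ((k - input) >> 8) & m term evaluates to 0 when input <= k and to m when input > k, so the expression accumulates exactly the offset the standard base64 alphabet needs for each of its four ranges. A standalone check against the stdlib table:

package main

import (
	"encoding/base64"
	"fmt"
)

func b64(input byte) byte {
	return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
}

func main() {
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	for i := byte(0); i < 64; i++ {
		if b64(i) != alphabet[i] {
			fmt.Println("mismatch at", i)
			return
		}
	}
	// Same alphabet as the standard encoder.
	fmt.Println(base64.StdEncoding.EncodeToString([]byte{0, 16, 131})) // "ABCD"
	fmt.Println("all 64 values match")
}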
func (peer *Peer) Start() {
-
// should never start a peer on a closed device
-
- if peer.device.isClosed.Get() {
+ if peer.device.isClosed() {
return
}
// prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- if peer.isRunning.Get() {
+ if peer.isRunning.Load() {
return
}
device := peer.device
- device.log.Debug.Println(peer, "- Starting...")
+ device.log.Verbosef("%v - Starting", peer)
// reset routine state
+ peer.stopping.Wait()
+ peer.stopping.Add(2)
- peer.routines.starting.Wait()
- peer.routines.stopping.Wait()
- peer.routines.stop = make(chan struct{})
- peer.routines.starting.Add(PeerRoutineNumber)
- peer.routines.stopping.Add(PeerRoutineNumber)
-
- // prepare queues
+ peer.handshake.mutex.Lock()
+ peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ peer.handshake.mutex.Unlock()
- peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+ peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
- peer.timersInit()
- peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
- peer.signals.newKeypairArrived = make(chan struct{}, 1)
- peer.signals.flushNonceQueue = make(chan struct{}, 1)
+ peer.timersStart()
- // wait for routines to start
+ device.flushInboundQueue(peer.queue.inbound)
+ device.flushOutboundQueue(peer.queue.outbound)
- go peer.RoutineNonce()
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
+ // Use the device batch size, not the bind batch size, as the device size is
+ // the size of the batch pools.
+ batchSize := peer.device.BatchSize()
+ go peer.RoutineSequentialSender(batchSize)
+ go peer.RoutineSequentialReceiver(batchSize)
- peer.routines.starting.Wait()
- peer.isRunning.Set(true)
+ peer.isRunning.Store(true)
}
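
The wg.Add(1) on the shared encryption queue is a reference count: the queue's channel is closed only once every producer, including each running peer, has called Done. A reduced sketch of that idiom, with hypothetical names and an int channel standing in for the element type:

package main

import (
	"fmt"
	"sync"
)

// refQueue closes its channel only after all registered writers detach.
type refQueue struct {
	c  chan int
	wg sync.WaitGroup
}

func newRefQueue(writers int) *refQueue {
	q := &refQueue{c: make(chan int, 16)}
	q.wg.Add(writers) // register up front so the closer can't race us
	go func() {
		q.wg.Wait()
		close(q.c) // safe: no writer remains
	}()
	return q
}

func main() {
	q := newRefQueue(1) // one producer, as in peer.Start
	q.c <- 42
	q.wg.Done() // mirrors peer.Stop: no more writes from us
	for v := range q.c {
		fmt.Println(v)
	}
}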
func (peer *Peer) ZeroAndFlushAll() {
@@ -222,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.next.Load())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.next.Store(nil)
keypairs.Unlock()
// clear handshake state
@@ -236,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() {
handshake.Clear()
handshake.mutex.Unlock()
- peer.FlushNonceQueue()
+ peer.FlushStagedPackets()
}
func (peer *Peer) ExpireCurrentKeypairs() {
@@ -244,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() {
handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
- handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ handshake.mutex.Unlock()
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
- keypairs.current.sendNonce = RejectAfterMessages
+ keypairs.current.sendNonce.Store(RejectAfterMessages)
}
- if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ if next := keypairs.next.Load(); next != nil {
+ next.sendNonce.Store(RejectAfterMessages)
}
keypairs.Unlock()
}
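
Storing RejectAfterMessages into sendNonce deletes no keys; it makes the next transmission see the keypair as exhausted, which forces a fresh handshake. A toy sketch of that check, using a tiny stand-in limit (the protocol value is close to 2^64):

package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

const rejectAfterMessages = 8 // tiny stand-in; the real constant is near 2^64

type keypair struct{ sendNonce atomic.Uint64 }

func (kp *keypair) nextNonce() (uint64, error) {
	n := kp.sendNonce.Add(1) - 1
	if n >= rejectAfterMessages {
		return 0, errors.New("keypair exhausted, handshake required")
	}
	return n, nil
}

func main() {
	var kp keypair
	fmt.Println(kp.nextNonce()) // 0 <nil>
	kp.sendNonce.Store(rejectAfterMessages) // what ExpireCurrentKeypairs does
	fmt.Println(kp.nextNonce()) // 0 keypair exhausted, handshake required
}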
func (peer *Peer) Stop() {
-
- // prevent simultaneous start/stop operations
+ peer.state.Lock()
+ defer peer.state.Unlock()
if !peer.isRunning.Swap(false) {
return
}
- peer.routines.starting.Wait()
-
- peer.routines.Lock()
- defer peer.routines.Unlock()
-
- peer.device.log.Debug.Println(peer, "- Stopping...")
+ peer.device.log.Verbosef("%v - Stopping", peer)
peer.timersStop()
-
- // stop & wait for ongoing peer routines
-
- close(peer.routines.stop)
- peer.routines.stopping.Wait()
-
- // close queues
-
- close(peer.queue.nonce)
- close(peer.queue.outbound)
- close(peer.queue.inbound)
+ // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
+ peer.queue.inbound.c <- nil
+ peer.queue.outbound.c <- nil
+ peer.stopping.Wait()
+ peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
peer.ZeroAndFlushAll()
}
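
Instead of closing the per-peer channels, which would panic any producer still holding a reference, Stop pushes a nil element into each queue; the sequential routines treat nil as a request to exit and call Done on the stopping WaitGroup. A compressed sketch of that handshake, with a *int element standing in for the queue containers:

package main

import (
	"fmt"
	"sync"
)

type peer struct {
	stopping sync.WaitGroup
	inbound  chan *int
	outbound chan *int
}

func (p *peer) routine(c chan *int) {
	defer p.stopping.Done()
	for elem := range c {
		if elem == nil { // sentinel from Stop: exit without closing the channel
			return
		}
		// ... process elem ...
	}
}

func main() {
	p := &peer{inbound: make(chan *int, 1), outbound: make(chan *int, 1)}
	p.stopping.Add(2) // matches the Add(2) in Start: one per routine
	go p.routine(p.inbound)
	go p.routine(p.outbound)

	p.inbound <- nil // mirrors Stop: one sentinel per queue
	p.outbound <- nil
	p.stopping.Wait()
	fmt.Println("both routines stopped")
}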
-var RoamingDisabled bool
+func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.disableRoaming {
+ return
+ }
+ peer.endpoint.clearSrcOnTx = false
+ peer.endpoint.val = endpoint
+}
-func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
- if RoamingDisabled {
+func (peer *Peer) markEndpointSrcForClearing() {
+ peer.endpoint.Lock()
+ defer peer.endpoint.Unlock()
+ if peer.endpoint.val == nil {
return
}
- peer.Lock()
- peer.endpoint = endpoint
- peer.Unlock()
+ peer.endpoint.clearSrcOnTx = true
}