diff options
Diffstat (limited to 'device/peer.go')
-rw-r--r-- | device/peer.go | 267 |
1 files changed, 131 insertions, 136 deletions
diff --git a/device/peer.go b/device/peer.go index 91d975a..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,37 +1,35 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "encoding/base64" + "container/list" "errors" - "fmt" "sync" "sync/atomic" "time" -) -const ( - PeerRoutineNumber = 3 + "golang.zx2c4.com/wireguard/conn" ) type Peer struct { - isRunning AtomicBool - sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer - keypairs Keypairs - handshake Handshake - device *Device - endpoint Endpoint - persistentKeepaliveInterval uint16 - - // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch + isRunning atomic.Bool + keypairs Keypairs + handshake Handshake + device *Device + stopping sync.WaitGroup // routines pending stop + txBytes atomic.Uint64 // bytes send to peer (endpoint) + rxBytes atomic.Uint64 // bytes received from peer + lastHandshakeNano atomic.Int64 // nano seconds since epoch + + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool } timers struct { @@ -40,40 +38,32 @@ type Peer struct { newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool + handshakeAttempts atomic.Uint32 + needAnotherKeepalive atomic.Bool + sentLastMinuteHandshake atomic.Bool } - signals struct { - newKeypairArrived chan struct{} - flushNonceQueue chan struct{} + state struct { + sync.Mutex // protects against concurrent Start/Stop } queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - packetInNonceQueueIsAwaitingKey AtomicBool - } - - routines struct { - sync.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } - cookieGenerator CookieGenerator + cookieGenerator CookieGenerator + trieEntries list.List + persistentKeepaliveInterval atomic.Uint32 } func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { - if device.isClosed.Get() { + if device.isClosed() { return nil, errors.New("device closed") } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -81,136 +71,144 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device - peer.isRunning.Set(false) + peer.queue.outbound = newAutodrainingOutboundQueue(device) + peer.queue.inbound = newAutodrainingInboundQueue(device) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) - ssIsZero := isZero(handshake.precomputedStaticStatic[:]) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() // reset endpoint + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() - peer.endpoint = nil - - // conditionally add - - if !ssIsZero { - device.peers.keyMap[pk] = peer - } else { - return nil, nil - } - - // start peer + // init timers + peer.timersInit() - if peer.device.isUp.Get() { - peer.Start() - } + // add + device.peers.keyMap[pk] = peer return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() - if peer.device.net.bind == nil { - return errors.New("no bind") + if peer.device.isClosed() { + return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { - atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } func (peer *Peer) String() string { - base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + // The awful goo that follows is identical to: + // + // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43] + // return fmt.Sprintf("peer(%s)", abbreviatedKey) + // + // except that it is considerably more efficient. + src := peer.handshake.remoteStatic + b64 := func(input byte) byte { + return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) } - return fmt.Sprintf("peer(%s)", abbreviatedKey) + b := []byte("peer(____…____)") + const first = len("peer(") + const second = len("peer(____…") + b[first+0] = b64((src[0] >> 2) & 63) + b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) + b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) + b[first+3] = b64(src[2] & 63) + b[second+0] = b64(src[29] & 63) + b[second+1] = b64((src[30] >> 2) & 63) + b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) + b[second+3] = b64((src[31] << 2) & 63) + return string(b) } func (peer *Peer) Start() { - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return } // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() - peer.routines.Lock() - defer peer.routines.Unlock() - - if peer.isRunning.Get() { + if peer.isRunning.Load() { return } device := peer.device - device.log.Debug.Println(peer, "- Starting...") + device.log.Verbosef("%v - Starting", peer) // reset routine state + peer.stopping.Wait() + peer.stopping.Add(2) - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) - - // prepare queues + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.handshake.mutex.Unlock() - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes - peer.timersInit() - peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - peer.signals.newKeypairArrived = make(chan struct{}, 1) - peer.signals.flushNonceQueue = make(chan struct{}, 1) + peer.timersStart() - // wait for routines to start + device.flushInboundQueue(peer.queue.inbound) + device.flushOutboundQueue(peer.queue.outbound) - go peer.RoutineNonce() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) - peer.routines.starting.Wait() - peer.isRunning.Set(true) + peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { @@ -222,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state @@ -236,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() { handshake.Clear() handshake.mutex.Unlock() - peer.FlushNonceQueue() + peer.FlushStagedPackets() } func (peer *Peer) ExpireCurrentKeypairs() { @@ -244,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() { handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() - handshake.mutex.Unlock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - keypairs.current.sendNonce = RejectAfterMessages + keypairs.current.sendNonce.Store(RejectAfterMessages) } - if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + if next := keypairs.next.Load(); next != nil { + next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.starting.Wait() - - peer.routines.Lock() - defer peer.routines.Unlock() - - peer.device.log.Debug.Println(peer, "- Stopping...") + peer.device.log.Verbosef("%v - Stopping", peer) peer.timersStop() - - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) + // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. + peer.queue.inbound.c <- nil + peer.queue.outbound.c <- nil + peer.stopping.Wait() + peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.ZeroAndFlushAll() } -var RoamingDisabled bool +func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} -func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { - if RoamingDisabled { +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } |