aboutsummaryrefslogtreecommitdiffstats
path: root/device/peer.go
diff options
context:
space:
mode:
Diffstat (limited to 'device/peer.go')
-rw-r--r--device/peer.go270
1 files changed, 270 insertions, 0 deletions
diff --git a/device/peer.go b/device/peer.go
new file mode 100644
index 0000000..af3ef9d
--- /dev/null
+++ b/device/peer.go
@@ -0,0 +1,270 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+)
+
+const (
+ PeerRoutineNumber = 3
+)
+
+type Peer struct {
+ isRunning AtomicBool
+ sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ endpoint Endpoint
+ persistentKeepaliveInterval uint16
+
+ // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
+ stats struct {
+ txBytes uint64 // bytes send to peer (endpoint)
+ rxBytes uint64 // bytes received from peer
+ lastHandshakeNano int64 // nano seconds since epoch
+ }
+
+ timers struct {
+ retransmitHandshake *Timer
+ sendKeepalive *Timer
+ newHandshake *Timer
+ zeroKeyMaterial *Timer
+ persistentKeepalive *Timer
+ handshakeAttempts uint32
+ needAnotherKeepalive AtomicBool
+ sentLastMinuteHandshake AtomicBool
+ }
+
+ signals struct {
+ newKeypairArrived chan struct{}
+ flushNonceQueue chan struct{}
+ }
+
+ queue struct {
+ nonce chan *QueueOutboundElement // nonce / pre-handshake queue
+ outbound chan *QueueOutboundElement // sequential ordering of work
+ inbound chan *QueueInboundElement // sequential ordering of work
+ packetInNonceQueueIsAwaitingKey AtomicBool
+ }
+
+ routines struct {
+ sync.Mutex // held when stopping / starting routines
+ starting sync.WaitGroup // routines pending start
+ stopping sync.WaitGroup // routines pending stop
+ stop chan struct{} // size 0, stop all go routines in peer
+ }
+
+ cookieGenerator CookieGenerator
+}
+
+func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+
+ if device.isClosed.Get() {
+ return nil, errors.New("device closed")
+ }
+
+ // lock resources
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ device.peers.Lock()
+ defer device.peers.Unlock()
+
+ // check if over limit
+
+ if len(device.peers.keyMap) >= MaxPeers {
+ return nil, errors.New("too many peers")
+ }
+
+ // create peer
+
+ peer := new(Peer)
+ peer.Lock()
+ defer peer.Unlock()
+
+ peer.cookieGenerator.Init(pk)
+ peer.device = device
+ peer.isRunning.Set(false)
+
+ // map public key
+
+ _, ok := device.peers.keyMap[pk]
+ if ok {
+ return nil, errors.New("adding existing peer")
+ }
+ device.peers.keyMap[pk] = peer
+
+ // pre-compute DH
+
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ handshake.remoteStatic = pk
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.mutex.Unlock()
+
+ // reset endpoint
+
+ peer.endpoint = nil
+
+ // start peer
+
+ if peer.device.isUp.Get() {
+ peer.Start()
+ }
+
+ return peer, nil
+}
+
+func (peer *Peer) SendBuffer(buffer []byte) error {
+ peer.device.net.RLock()
+ defer peer.device.net.RUnlock()
+
+ if peer.device.net.bind == nil {
+ return errors.New("no bind")
+ }
+
+ peer.RLock()
+ defer peer.RUnlock()
+
+ if peer.endpoint == nil {
+ return errors.New("no known endpoint for peer")
+ }
+
+ return peer.device.net.bind.Send(buffer, peer.endpoint)
+}
+
+func (peer *Peer) String() string {
+ base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ abbreviatedKey := "invalid"
+ if len(base64Key) == 44 {
+ abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
+ }
+ return fmt.Sprintf("peer(%s)", abbreviatedKey)
+}
+
+func (peer *Peer) Start() {
+
+ // should never start a peer on a closed device
+
+ if peer.device.isClosed.Get() {
+ return
+ }
+
+ // prevent simultaneous start/stop operations
+
+ peer.routines.Lock()
+ defer peer.routines.Unlock()
+
+ if peer.isRunning.Get() {
+ return
+ }
+
+ device := peer.device
+ device.log.Debug.Println(peer, "- Starting...")
+
+ // reset routine state
+
+ peer.routines.starting.Wait()
+ peer.routines.stopping.Wait()
+ peer.routines.stop = make(chan struct{})
+ peer.routines.starting.Add(PeerRoutineNumber)
+ peer.routines.stopping.Add(PeerRoutineNumber)
+
+ // prepare queues
+
+ peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+
+ peer.timersInit()
+ peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ peer.signals.newKeypairArrived = make(chan struct{}, 1)
+ peer.signals.flushNonceQueue = make(chan struct{}, 1)
+
+ // wait for routines to start
+
+ go peer.RoutineNonce()
+ go peer.RoutineSequentialSender()
+ go peer.RoutineSequentialReceiver()
+
+ peer.routines.starting.Wait()
+ peer.isRunning.Set(true)
+}
+
+func (peer *Peer) ZeroAndFlushAll() {
+ device := peer.device
+
+ // clear key pairs
+
+ keypairs := &peer.keypairs
+ keypairs.Lock()
+ device.DeleteKeypair(keypairs.previous)
+ device.DeleteKeypair(keypairs.current)
+ device.DeleteKeypair(keypairs.next)
+ keypairs.previous = nil
+ keypairs.current = nil
+ keypairs.next = nil
+ keypairs.Unlock()
+
+ // clear handshake state
+
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.Clear()
+ handshake.mutex.Unlock()
+
+ peer.FlushNonceQueue()
+}
+
+func (peer *Peer) Stop() {
+
+ // prevent simultaneous start/stop operations
+
+ if !peer.isRunning.Swap(false) {
+ return
+ }
+
+ peer.routines.starting.Wait()
+
+ peer.routines.Lock()
+ defer peer.routines.Unlock()
+
+ peer.device.log.Debug.Println(peer, "- Stopping...")
+
+ peer.timersStop()
+
+ // stop & wait for ongoing peer routines
+
+ close(peer.routines.stop)
+ peer.routines.stopping.Wait()
+
+ // close queues
+
+ close(peer.queue.nonce)
+ close(peer.queue.outbound)
+ close(peer.queue.inbound)
+
+ peer.ZeroAndFlushAll()
+}
+
+var roamingDisabled bool
+
+func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
+ if roamingDisabled {
+ return
+ }
+ peer.Lock()
+ peer.endpoint = endpoint
+ peer.Unlock()
+}