From 9263014ed3f0a97800c893cb7346cc5109fc9e27 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Fri, 29 Jan 2021 14:54:11 +0100 Subject: device: simplify peer queue locking Signed-off-by: Jason A. Donenfeld --- device/device.go | 28 +++++++++---------- device/peer.go | 57 +++++++++---------------------------- device/receive.go | 84 ++++++++++++------------------------------------------- device/send.go | 48 ++++++++++++++++--------------- 4 files changed, 70 insertions(+), 147 deletions(-) (limited to 'device') diff --git a/device/device.go b/device/device.go index 7c8da9c..08db244 100644 --- a/device/device.go +++ b/device/device.go @@ -75,8 +75,8 @@ type Device struct { } queue struct { - encryption *encryptionQueue - decryption *decryptionQueue + encryption *outboundQueue + decryption *inboundQueue handshake chan QueueHandshakeElement } @@ -92,21 +92,21 @@ type Device struct { ipcMutex sync.RWMutex } -// An encryptionQueue is a channel of QueueOutboundElements awaiting encryption. -// An encryptionQueue is ref-counted using its wg field. -// An encryptionQueue created with newEncryptionQueue has one reference. +// An outboundQueue is a channel of QueueOutboundElements awaiting encryption. +// An outboundQueue is ref-counted using its wg field. +// An outboundQueue created with newOutboundQueue has one reference. // Every additional writer must call wg.Add(1). // Every completed writer must call wg.Done(). // When no further writers will be added, // call wg.Done to remove the initial reference. // When the refcount hits 0, the queue's channel is closed. -type encryptionQueue struct { +type outboundQueue struct { c chan *QueueOutboundElement wg sync.WaitGroup } -func newEncryptionQueue() *encryptionQueue { - q := &encryptionQueue{ +func newOutboundQueue() *outboundQueue { + q := &outboundQueue{ c: make(chan *QueueOutboundElement, QueueOutboundSize), } q.wg.Add(1) @@ -117,14 +117,14 @@ func newEncryptionQueue() *encryptionQueue { return q } -// A decryptionQueue is similar to an encryptionQueue; see those docs. -type decryptionQueue struct { +// A inboundQueue is similar to an outboundQueue; see those docs. +type inboundQueue struct { c chan *QueueInboundElement wg sync.WaitGroup } -func newDecryptionQueue() *decryptionQueue { - q := &decryptionQueue{ +func newInboundQueue() *inboundQueue { + q := &inboundQueue{ c: make(chan *QueueInboundElement, QueueInboundSize), } q.wg.Add(1) @@ -323,8 +323,8 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { // create queues device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = newEncryptionQueue() - device.queue.decryption = newDecryptionQueue() + device.queue.encryption = newOutboundQueue() + device.queue.decryption = newInboundQueue() // prepare signals diff --git a/device/peer.go b/device/peer.go index b385519..76f9a96 100644 --- a/device/peer.go +++ b/device/peer.go @@ -25,6 +25,7 @@ type Peer struct { endpoint conn.Endpoint persistentKeepaliveInterval uint32 // accessed atomically firstTrieEntry *trieEntry + stopping sync.WaitGroup // routines pending stop // These fields are accessed with atomic operations, which must be // 64-bit aligned even on 32-bit platforms. Go guarantees that an @@ -53,14 +54,8 @@ type Peer struct { queue struct { sync.RWMutex staged chan *QueueOutboundElement // staged packets before a handshake is available - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - } - - routines struct { - sync.Mutex // held when stopping routines - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer + outbound chan *QueueOutboundElement // sequential ordering of udp transmission + inbound chan *QueueInboundElement // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -72,7 +67,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -80,13 +74,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) peer.Lock() defer peer.Unlock() @@ -95,14 +87,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) @@ -110,16 +100,13 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil // add - device.peers.keyMap[pk] = peer device.peers.empty.Set(false) // start peer - if peer.device.isUp.Get() { peer.Start() } @@ -164,17 +151,14 @@ func (peer *Peer) String() string { } func (peer *Peer) Start() { - // should never start a peer on a closed device - if peer.device.isClosed.Get() { return } // prevent simultaneous start/stop operations - - peer.routines.Lock() - defer peer.routines.Unlock() + peer.queue.Lock() + defer peer.queue.Unlock() if peer.isRunning.Get() { return @@ -184,23 +168,19 @@ func (peer *Peer) Start() { device.log.Verbosef("%v - Starting...", peer) // reset routine state - - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.stopping.Add(1) + peer.stopping.Wait() + peer.stopping.Add(2) // prepare queues - peer.queue.Lock() - peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - peer.queue.Unlock() + if peer.queue.staged == nil { + peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) + } peer.timersInit() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - // wait for routines to start - go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() @@ -254,31 +234,20 @@ func (peer *Peer) ExpireCurrentKeypairs() { } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.queue.Lock() + defer peer.queue.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.Lock() - defer peer.routines.Unlock() - peer.device.log.Verbosef("%v - Stopping...", peer) peer.timersStop() - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - peer.queue.Lock() close(peer.queue.inbound) close(peer.queue.outbound) - peer.queue.Unlock() + peer.stopping.Wait() peer.ZeroAndFlushAll() } diff --git a/device/receive.go b/device/receive.go index d513a21..abaf5af 100644 --- a/device/receive.go +++ b/device/receive.go @@ -174,7 +174,6 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { elem.Lock() // add to decryption queues - peer.queue.RLock() if peer.isRunning.Get() { peer.queue.inbound <- elem @@ -433,52 +432,25 @@ func (device *Device) RoutineHandshake() { func (peer *Peer) RoutineSequentialReceiver() { device := peer.device - var elem *QueueInboundElement - defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) - peer.routines.stopping.Done() - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) - } + peer.stopping.Done() }() - device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) - elem = nil - } - - var elemOk bool - select { - case <-peer.routines.stop: - return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption + for elem := range peer.queue.inbound { + var err error elem.Lock() if elem.packet == nil { // decryption failed - continue + goto skip } - // check for replay if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue + goto skip } - // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair if peer.ReceivedWithKeypair(elem.keypair) { peer.timersHandshakeComplete() peer.SendStagedPackets() @@ -489,83 +461,63 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.timersAnyAuthenticatedPacketReceived() atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - // check for keepalive - if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) - continue + goto skip } peer.timersDataReceived() - // verify source and strip padding - switch elem.packet[0] >> 4 { case ipv4.Version: - - // strip padding - if len(elem.packet) < ipv4.HeaderLen { - continue + goto skip } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] length := binary.BigEndian.Uint16(field) if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue + goto skip } - elem.packet = elem.packet[:length] - - // verify IPv4 source - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.allowedips.LookupIPv4(src) != peer { device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - continue + goto skip } case ipv6.Version: - - // strip padding - if len(elem.packet) < ipv6.HeaderLen { - continue + goto skip } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) length += ipv6.HeaderLen if int(length) > len(elem.packet) { - continue + goto skip } - elem.packet = elem.packet[:length] - - // verify IPv6 source - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.allowedips.LookupIPv6(src) != peer { device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - continue + goto skip } default: device.log.Verbosef("Packet with invalid IP version from %v", peer) - continue + goto skip } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) if err != nil && !device.isClosed.Get() { device.log.Errorf("Failed to write packet to TUN device: %v", err) } if len(peer.queue.inbound) == 0 { - err := device.tun.device.Flush() + err = device.tun.device.Flush() if err != nil { peer.device.log.Errorf("Unable to flush packets: %v", err) } } + skip: + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } } diff --git a/device/send.go b/device/send.go index 04d2001..5261c2f 100644 --- a/device/send.go +++ b/device/send.go @@ -74,22 +74,17 @@ func (elem *QueueOutboundElement) clearPointers() { /* Queues a keepalive if no packets are queued for peer */ func (peer *Peer) SendKeepalive() { - var elem *QueueOutboundElement - peer.queue.RLock() - if len(peer.queue.staged) != 0 || !peer.isRunning.Get() { - goto out - } - elem = peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.staged <- elem: - peer.device.log.Verbosef("%v - Sending keepalive packet", peer) - default: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + if len(peer.queue.staged) == 0 && peer.isRunning.Get() { + elem := peer.device.NewOutboundElement() + elem.packet = nil + select { + case peer.queue.staged <- elem: + peer.device.log.Verbosef("%v - Sending keepalive packet", peer) + default: + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } } -out: - peer.queue.RUnlock() peer.SendStagedPackets() } @@ -176,7 +171,6 @@ func (peer *Peer) SendHandshakeResponse() error { } func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { - device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) @@ -297,6 +291,8 @@ func (peer *Peer) StagePacket(elem *QueueOutboundElement) { } func (peer *Peer) SendStagedPackets() { + peer.device.queue.encryption.wg.Add(1) + defer peer.device.queue.encryption.wg.Done() top: if len(peer.queue.staged) == 0 || !peer.device.isUp.Get() { return @@ -307,8 +303,6 @@ top: peer.SendHandshakeInitiation(false) return } - peer.device.queue.encryption.wg.Add(1) - defer peer.device.queue.encryption.wg.Done() for { select { @@ -325,8 +319,15 @@ top: elem.Lock() // add to parallel and sequential queue - peer.queue.outbound <- elem - peer.device.queue.encryption.c <- elem + peer.queue.RLock() + if peer.isRunning.Get() { + peer.queue.outbound <- elem + peer.device.queue.encryption.c <- elem + } else { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.queue.RUnlock() default: return } @@ -410,10 +411,11 @@ func (device *Device) RoutineEncryption() { * The routine terminates then the outbound queue is closed. */ func (peer *Peer) RoutineSequentialSender() { - device := peer.device - - defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + defer func() { + defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + peer.stopping.Done() + }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) for elem := range peer.queue.outbound { -- cgit v1.2.3-59-g8ed1b