diff options
Diffstat (limited to 'device/receive.go')
-rw-r--r-- | device/receive.go | 645 |
1 files changed, 267 insertions, 378 deletions
diff --git a/device/receive.go b/device/receive.go index 7d0693e..1392957 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,73 +1,52 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package device import ( - "bytes" "encoding/binary" + "errors" "net" - "strconv" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" ) type QueueHandshakeElement struct { msgType uint32 packet []byte - endpoint Endpoint + endpoint conn.Endpoint buffer *[MaxMessageSize]byte } type QueueInboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 keypair *Keypair - endpoint Endpoint -} - -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) + endpoint conn.Endpoint } -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { - select { - case inboundQueue <- element: - select { - case decryptionQueue <- element: - return true - default: - element.Drop() - element.Unlock() - return false - } - default: - device.PutInboundElement(element) - return false - } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement } -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { - select { - case queue <- element: - return true - default: - return false - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueInboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.endpoint = nil } /* Called when a new authenticated message has been received @@ -75,12 +54,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { + if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) + peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } @@ -90,188 +69,189 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { - - logDebug := device.log.Debug +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { + recvName := recv.PrettyName() defer func() { - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.net.stopping.Done() }() - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") - device.net.starting.Done() + device.log.Verbosef("Routine: receive incoming %s - started", recvName) // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( - err error - size int - endpoint Endpoint + bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + bufs = make([][]byte, maxBatchSize) + err error + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) + deathSpiral int + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for { - - // read next datagram + for i := range bufsArrs { + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + } - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") + defer func() { + for i := 0; i < maxBatchSize; i++ { + if bufsArrs[i] != nil { + device.PutMessageBuffer(bufsArrs[i]) + } } + }() + for { + count, err = recv(bufs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) + if errors.Is(err, net.ErrClosed) { + return + } + device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) + if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { + return + } + if deathSpiral < 10 { + deathSpiral++ + time.Sleep(time.Second / 3) + continue + } return } + deathSpiral = 0 - if size < MinMessageSize { - continue - } - - // check size of packet + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // check size of packet - var okay bool + packet := bufsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - switch msgType { + switch msgType { - // check if transport + // check if transport - case MessageTransportType: + case MessageTransportType: - // check size + // check size - if len(packet) < MessageTransportSize { - continue - } + if len(packet) < MessageTransportSize { + continue + } - // lookup key pair + // lookup key pair - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - // check keypair expiry + // check keypair expiry - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { - buffer = device.GetMessageBuffer() + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = bufsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() + elemsByPeer[peer] = elemsForPeer } - } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + continue - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + default: + device.log.Verbosef("Received message with unknown type") + continue + } - default: - logDebug.Println("Received message with unknown type") + select { + case device.queue.handshake.c <- QueueHandshakeElement{ + msgType: msgType, + buffer: bufsArrs[i], + packet: packet, + endpoint: endpoints[i], + }: + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + default: + } } - - if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { - buffer = device.GetMessageBuffer() + for peer, elemsContainer := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) } + delete(elemsByPeer, peer) } } } -func (device *Device) RoutineDecryption() { - +func (device *Device) RoutineDecryption(id int) { var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: decryption worker - stopped") - device.state.stopping.Done() - }() - logDebug.Println("Routine: decryption worker - started") - device.state.starting.Done() - - for { - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) + device.log.Verbosef("Routine: decryption worker %d - started", id) + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - // decrypt and release to consumer - var err error elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], @@ -279,51 +259,23 @@ func (device *Device) RoutineDecryption() { nil, ) if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) + elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } /* Handles incoming packets related to handshake */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem QueueHandshakeElement - var ok bool - +func (device *Device) RoutineHandshake(id int) { defer func() { - logDebug.Println("Routine: handshake worker - stopped") - device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } + device.log.Verbosef("Routine: handshake worker %d - stopped", id) + device.queue.encryption.wg.Done() }() + device.log.Verbosef("Routine: handshake worker %d - started", id) - logDebug.Println("Routine: handshake worker - started") - device.state.starting.Done() - - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } + for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -334,11 +286,10 @@ func (device *Device) RoutineHandshake() { // unmarshal packet var reply MessageCookieReply - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &reply) + err := reply.unmarshal(elem.packet) if err != nil { - logDebug.Println("Failed to decode cookie reply") - return + device.log.Verbosef("Failed to decode cookie reply") + goto skip } // lookup peer from index @@ -346,27 +297,27 @@ func (device *Device) RoutineHandshake() { entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { - continue + goto skip } // consume reply - if peer := entry.peer; peer.isRunning.Get() { - logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if peer := entry.peer; peer.isRunning.Load() { + device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { - logDebug.Println("Could not decrypt invalid cookie response") + device.log.Verbosef("Could not decrypt invalid cookie response") } } - continue + goto skip case MessageInitiationType, MessageResponseType: // check mac fields and maybe ratelimit if !device.cookieChecker.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue + device.log.Verbosef("Received packet with invalid mac1") + goto skip } // endpoints destination address is the source of the datagram @@ -377,19 +328,19 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) - continue + goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue + goto skip } } default: - logError.Println("Invalid packet ended up in the handshake queue") - continue + device.log.Errorf("Invalid packet ended up in the handshake queue") + goto skip } // handle handshake initiation/response content @@ -400,22 +351,18 @@ func (device *Device) RoutineHandshake() { // unmarshal var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { - logError.Println("Failed to decode initiation message") - continue + device.log.Errorf("Failed to decode initiation message") + goto skip } // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) + goto skip } // update timers @@ -426,8 +373,8 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake initiation") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake initiation", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() @@ -436,29 +383,25 @@ func (device *Device) RoutineHandshake() { // unmarshal var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) + err := msg.unmarshal(elem.packet) if err != nil { - logError.Println("Failed to decode response message") - continue + device.log.Errorf("Failed to decode response message") + goto skip } // consume response peer := device.ConsumeMessageResponse(&msg) if peer == nil { - logInfo.Println( - "Received invalid response message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) + goto skip } // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake response") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake response", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) // update timers @@ -470,178 +413,124 @@ func (device *Device) RoutineHandshake() { err = peer.BeginSymmetricSession() if err != nil { - logError.Println(peer, "- Failed to derive keypair:", err) - continue + device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) + goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } } + skip: + device.PutMessageBuffer(elem.buffer) } } -func (peer *Peer) RoutineSequentialReceiver() { - +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem *QueueInboundElement - defer func() { - logDebug.Println(peer, "- Routine: sequential receiver - stopped") - peer.routines.stopping.Done() - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - } + device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - logDebug.Println(peer, "- Routine: sequential receiver - started") - - peer.routines.starting.Done() - - for { - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - elem = nil - } + bufs := make([][]byte, 0, maxBatchSize) - var elemOk bool - select { - case <-peer.routines.stop: + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption - - elem.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - - // check for keepalive - - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue } - peer.timersDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { + if elem.packet == nil { + // decryption failed continue } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue } - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue + validTailPacket = i + if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) + peer.timersHandshakeComplete() + peer.SendStagedPackets() } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } + dataPacketReceived = true - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - // verify IPv6 source + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer, - ) + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) continue } - default: - logInfo.Println("Packet with invalid IP version from", peer) - continue + bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if len(peer.queue.inbound) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Error.Printf("Unable to flush packets: %v", err) + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } + if len(bufs) > 0 { + _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } + bufs = bufs[:0] + device.PutInboundElementsContainer(elemsContainer) } } |