diff options
Diffstat (limited to 'device/send.go')
-rw-r--r-- | device/send.go | 355 |
1 files changed, 224 insertions, 131 deletions
diff --git a/device/send.go b/device/send.go index b05c69e..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -12,12 +12,13 @@ import ( "net" "os" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -45,7 +46,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -53,10 +53,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -76,14 +80,17 @@ func (elem *QueueOutboundElement) clearPointers() { /* Queues a keepalive if no packets are queued for peer */ func (peer *Peer) SendKeepalive() { - if len(peer.queue.staged) == 0 && peer.isRunning.Get() { + if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -91,7 +98,7 @@ func (peer *Peer) SendKeepalive() { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() @@ -117,8 +124,8 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageInitiationSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) @@ -126,7 +133,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -148,8 +155,8 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - var buff [MessageResponseSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageResponseSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, response) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) @@ -164,7 +171,8 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -181,10 +189,11 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - var buff [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -193,17 +202,12 @@ func (peer *Peer) keepKeyFreshSending() { if keypair == nil { return } - nonce := atomic.LoadUint64(&keypair.sendNonce) + nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } -/* Reads packets from the TUN and inserts - * into staged queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { defer func() { device.log.Verbosef("Routine: TUN reader - stopped") @@ -213,82 +217,123 @@ func (device *Device) RoutineReadFromTUN() { device.log.Verbosef("Routine: TUN reader - started") - var elem *QueueOutboundElement + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + bufs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) + }() - if err != nil { - if !device.isClosed() { - if !errors.Is(err, os.ErrClosed) { - device.log.Errorf("Failed to read packet from TUN device: %v", err) - } - go device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(bufs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = bufs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.Lookup(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.Lookup(dst) - - default: - device.log.Verbosef("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsContainer() + elemsByPeer[peer] = elemsForPeer + } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range elemsForPeer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsForPeer) + } + delete(elemsByPeer, peer) } - if peer.isRunning.Get() { - peer.StagePacket(elem) - elem = nil - peer.SendStagedPackets() + + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue + } + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() + } + return } } } -func (peer *Peer) StagePacket(elem *QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { - case peer.queue.staged <- elem: + case peer.queue.staged <- elems: return default: } select { case tooOld := <-peer.queue.staged: - peer.device.PutMessageBuffer(tooOld.buffer) - peer.device.PutOutboundElement(tooOld) + for _, elem := range tooOld.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -301,32 +346,59 @@ top: } keypair := peer.keypairs.Current() - if keypair == nil || atomic.LoadUint64(&keypair.sendNonce) >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { peer.SendHandshakeInitiation(false) return } for { + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elem := <-peer.queue.staged: - elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 - if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) - peer.StagePacket(elem) // XXX: Out of order, but we can't front-load go chans - goto top + case elemsContainer := <-peer.queue.staged: + i := 0 + for _, elem := range elemsContainer.elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() + } + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) + continue + } else { + elemsContainer.elems[i] = elem + i++ + } + + elem.keypair = keypair } + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - elem.keypair = keypair - elem.Lock() + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans + } + + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) + goto top + } // add to parallel and sequential queue - if peer.isRunning.Get() { - peer.queue.outbound.c <- elem - peer.device.queue.encryption.c <- elem + if peer.isRunning.Load() { + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(elemsContainer) + } + + if elemsContainerOOO != nil { + goto top } default: return @@ -337,9 +409,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elem := <-peer.queue.staged: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -373,41 +448,38 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu))) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elemsContainer.Unlock() } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device defer func() { defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) @@ -415,36 +487,57 @@ func (peer *Peer) RoutineSequentialSender() { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - for elem := range peer.queue.outbound.c { - if elem == nil { + bufs := make([][]byte, 0, maxBatchSize) + + for elemsContainer := range peer.queue.outbound.c { + bufs = bufs[:0] + if elemsContainer == nil { return } - elem.Lock() - if !peer.isRunning.Get() { + if !peer.isRunning.Load() { // peer has been stopped; return re-usable elems to the shared pool. // This is an optimization only. It is possible for the peer to be stopped // immediately after this check, in which case, elem will get processed. - // The timers and SendBuffer code are resilient to a few stragglers. + // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } continue } + dataSent := false + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + if len(elem.packet) != MessageKeepaliveSize { + dataSent = true + } + bufs = append(bufs, elem.packet) + } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) - if len(elem.packet) != MessageKeepaliveSize { + err := peer.SendBuffers(bufs) + if dataSent { peer.timersDataSent() } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { - device.log.Errorf("%v - Failed to send data packet: %v", peer, err) + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue } |