diff options
Diffstat (limited to 'device/send.go')
-rw-r--r-- | device/send.go | 143 |
1 files changed, 77 insertions, 66 deletions
diff --git a/device/send.go b/device/send.go index d22bf26..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -17,6 +17,7 @@ import ( "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun" ) @@ -45,7 +46,6 @@ import ( */ type QueueOutboundElement struct { - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -53,10 +53,14 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 // keypair and peer were cleared (if necessary) by clearPointers. return elem @@ -78,15 +82,15 @@ func (elem *QueueOutboundElement) clearPointers() { func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { elem := peer.device.NewOutboundElement() - elems := peer.device.GetOutboundElementsSlice() - *elems = append(*elems, elem) + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case peer.queue.staged <- elems: + case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } } peer.SendStagedPackets() @@ -218,7 +222,7 @@ func (device *Device) RoutineReadFromTUN() { readErr error elems = make([]*QueueOutboundElement, batchSize) bufs = make([][]byte, batchSize) - elemsByPeer = make(map[*Peer]*[]*QueueOutboundElement, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 sizes = make([]int, batchSize) offset = MessageTransportHeaderSize @@ -275,10 +279,10 @@ func (device *Device) RoutineReadFromTUN() { } elemsForPeer, ok := elemsByPeer[peer] if !ok { - elemsForPeer = device.GetOutboundElementsSlice() + elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } - *elemsForPeer = append(*elemsForPeer, elem) + elemsForPeer.elems = append(elemsForPeer.elems, elem) elems[i] = device.NewOutboundElement() bufs[i] = elems[i].buffer[:] } @@ -288,11 +292,11 @@ func (device *Device) RoutineReadFromTUN() { peer.StagePackets(elemsForPeer) peer.SendStagedPackets() } else { - for _, elem := range *elemsForPeer { + for _, elem := range elemsForPeer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elemsForPeer) + device.PutOutboundElementsContainer(elemsForPeer) } delete(elemsByPeer, peer) } @@ -316,7 +320,7 @@ func (device *Device) RoutineReadFromTUN() { } } -func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { for { select { case peer.queue.staged <- elems: @@ -325,11 +329,11 @@ func (peer *Peer) StagePackets(elems *[]*QueueOutboundElement) { } select { case tooOld := <-peer.queue.staged: - for _, elem := range *tooOld { + for _, elem := range tooOld.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(tooOld) + peer.device.PutOutboundElementsContainer(tooOld) default: } } @@ -348,54 +352,52 @@ top: } for { - var elemsOOO *[]*QueueOutboundElement + var elemsContainerOOO *QueueOutboundElementsContainer select { - case elems := <-peer.queue.staged: + case elemsContainer := <-peer.queue.staged: i := 0 - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { elem.peer = peer elem.nonce = keypair.sendNonce.Add(1) - 1 if elem.nonce >= RejectAfterMessages { keypair.sendNonce.Store(RejectAfterMessages) - if elemsOOO == nil { - elemsOOO = peer.device.GetOutboundElementsSlice() + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } - *elemsOOO = append(*elemsOOO, elem) + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) continue } else { - (*elems)[i] = elem + elemsContainer.elems[i] = elem i++ } elem.keypair = keypair - elem.Lock() } - *elems = (*elems)[:i] + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - if elemsOOO != nil { - peer.StagePackets(elemsOOO) // XXX: Out of order, but we can't front-load go chans + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans } - if len(*elems) == 0 { - peer.device.PutOutboundElementsSlice(elems) + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) goto top } // add to parallel and sequential queue if peer.isRunning.Load() { - peer.queue.outbound.c <- elems - for _, elem := range *elems { - peer.device.queue.encryption.c <- elem - } + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer } else { - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) } - if elemsOOO != nil { + if elemsContainerOOO != nil { goto top } default: @@ -407,12 +409,12 @@ top: func (peer *Peer) FlushStagedPackets() { for { select { - case elems := <-peer.queue.staged: - for _, elem := range *elems { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { peer.device.PutMessageBuffer(elem.buffer) peer.device.PutOutboundElement(elem) } - peer.device.PutOutboundElementsSlice(elems) + peer.device.PutOutboundElementsContainer(elemsContainer) default: return } @@ -446,32 +448,34 @@ func (device *Device) RoutineEncryption(id int) { defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) device.log.Verbosef("Routine: encryption worker %d - started", id) - for elem := range device.queue.encryption.c { - // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { + // populate header fields + header := elem.buffer[:MessageTransportHeaderSize] - fieldType := header[0:4] - fieldReceiver := header[4:8] - fieldNonce := header[8:16] + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] - binary.LittleEndian.PutUint32(fieldType, MessageTransportType) - binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) - binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) - // pad content to multiple of 16 - paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) - elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) + // pad content to multiple of 16 + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) - // encrypt content and release to consumer + // encrypt content and release to consumer - binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) - elem.packet = elem.keypair.send.Seal( - header, - nonce[:], - elem.packet, - nil, - ) - elem.Unlock() + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + } + elemsContainer.Unlock() } } @@ -485,9 +489,9 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { bufs := make([][]byte, 0, maxBatchSize) - for elems := range peer.queue.outbound.c { + for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] - if elems == nil { + if elemsContainer == nil { return } if !peer.isRunning.Load() { @@ -497,16 +501,16 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // The timers and SendBuffers code are resilient to a few stragglers. // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } continue } dataSent := false - for _, elem := range *elems { - elem.Lock() + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { dataSent = true } @@ -520,11 +524,18 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { if dataSent { peer.timersDataSent() } - for _, elem := range *elems { + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) } - device.PutOutboundElementsSlice(elems) + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr + } + } if err != nil { device.log.Errorf("%v - Failed to send data packets: %v", peer, err) continue |