Diffstat (limited to 'device/receive.go')
-rw-r--r--  device/receive.go | 103
1 file changed, 59 insertions(+), 44 deletions(-)
diff --git a/device/receive.go b/device/receive.go
index e24d29f..1ab3e29 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -27,7 +27,6 @@ type QueueHandshakeElement struct {
}
type QueueInboundElement struct {
- sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
@@ -35,6 +34,11 @@ type QueueInboundElement struct {
endpoint conn.Endpoint
}
+type QueueInboundElementsContainer struct {
+ sync.Mutex
+ elems []*QueueInboundElement
+}
+
// clearPointers clears elem fields that contain pointers.
// This makes the garbage collector's life easier and
// avoids accidentally keeping other objects around unnecessarily.
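
The pool accessors introduced below, GetInboundElementsContainer and PutInboundElementsContainer, are defined outside device/receive.go, so they do not appear in this diff. A minimal sketch of what such a pool could look like, assuming a sync.Pool-backed free list; the method names mirror the call sites below, but the real implementation (likely alongside the device's other pools in device/pools.go) may differ:

package device

import "sync"

// Hypothetical container pool; illustrative only, not the upstream code.
var inboundContainerPool = sync.Pool{
	New: func() any { return new(QueueInboundElementsContainer) },
}

func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
	c := inboundContainerPool.Get().(*QueueInboundElementsContainer)
	c.elems = c.elems[:0] // reuse the backing array, start from an empty batch
	return c
}

func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
	for i := range c.elems {
		c.elems[i] = nil // drop element pointers so the GC can reclaim them
	}
	inboundContainerPool.Put(c)
}
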
@@ -87,7 +91,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
- elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
+ elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
)
for i := range bufsArrs {
@@ -170,15 +174,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
- elem.Mutex = sync.Mutex{}
- elem.Lock()
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
- elemsForPeer = device.GetInboundElementsSlice()
+ elemsForPeer = device.GetInboundElementsContainer()
+ elemsForPeer.Lock()
elemsByPeer[peer] = elemsForPeer
}
- *elemsForPeer = append(*elemsForPeer, elem)
+ elemsForPeer.elems = append(elemsForPeer.elems, elem)
bufsArrs[i] = device.GetMessageBuffer()
bufs[i] = bufsArrs[i][:]
continue
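
A receive batch is bucketed into at most one container per peer, and each container is locked the moment it is created, before it is published anywhere. A self-contained toy of the grouping step, with hypothetical types, to show the shape of the map-based batching:

package main

import "fmt"

type peer struct{ name string }
type element struct {
	owner   *peer
	payload string
}
type container struct{ elems []*element }

// groupByPeer mirrors the elemsByPeer map above: elements received in one
// batch are collected into a single container per destination peer.
func groupByPeer(in []*element) map[*peer]*container {
	byPeer := make(map[*peer]*container)
	for _, e := range in {
		c, ok := byPeer[e.owner]
		if !ok {
			c = &container{}
			byPeer[e.owner] = c
		}
		c.elems = append(c.elems, e)
	}
	return byPeer
}

func main() {
	a, b := &peer{"A"}, &peer{"B"}
	in := []*element{{a, "p1"}, {b, "p2"}, {a, "p3"}}
	for p, c := range groupByPeer(in) {
		fmt.Printf("peer %s: %d packet(s)\n", p.name, len(c.elems))
	}
}
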
@@ -217,18 +220,16 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive
default:
}
}
- for peer, elems := range elemsByPeer {
+ for peer, elemsContainer := range elemsByPeer {
if peer.isRunning.Load() {
- peer.queue.inbound.c <- elems
- for _, elem := range *elems {
- device.queue.decryption.c <- elem
- }
+ peer.queue.inbound.c <- elemsContainer
+ device.queue.decryption.c <- elemsContainer
} else {
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
- device.PutInboundElementsSlice(elems)
+ device.PutInboundElementsContainer(elemsContainer)
}
delete(elemsByPeer, peer)
}
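
The handoff is now two channel sends per peer per batch (one to the peer's inbound queue, one to the shared decryption queue) instead of one send per element, so channel operations no longer scale with packet count. The queue fields used here would now carry containers; a sketch of the assumed declarations, which live outside this diff and whose real names may differ:

type inboundQueue struct {
	c chan *QueueInboundElementsContainer // per-peer, drained in order
}

type decryptionQueue struct {
	c chan *QueueInboundElementsContainer // shared by all decryption workers
}
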
@@ -241,26 +242,28 @@ func (device *Device) RoutineDecryption(id int) {
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
device.log.Verbosef("Routine: decryption worker %d - started", id)
- for elem := range device.queue.decryption.c {
- // split message into fields
- counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
- content := elem.packet[MessageTransportOffsetContent:]
-
- // decrypt and release to consumer
- var err error
- elem.counter = binary.LittleEndian.Uint64(counter)
- // copy counter to nonce
- binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
- elem.packet, err = elem.keypair.receive.Open(
- content[:0],
- nonce[:],
- content,
- nil,
- )
- if err != nil {
- elem.packet = nil
+ for elemsContainer := range device.queue.decryption.c {
+ for _, elem := range elemsContainer.elems {
+ // split message into fields
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // decrypt and release to consumer
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ // copy counter to nonce
+ binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.packet = nil
+ }
}
- elem.Unlock()
+ elemsContainer.Unlock()
}
}
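
The per-element mutex handshake is replaced by one mutex per batch: the receive loop locks the container before publishing it, a decryption worker unlocks it once every element is decrypted, and the sequential receiver locks it again, blocking until the batch is ready. A self-contained toy of that handshake, with hypothetical names:

package main

import (
	"fmt"
	"sync"
)

// batch stands in for QueueInboundElementsContainer: one mutex guards the
// whole slice instead of one mutex per element.
type batch struct {
	sync.Mutex
	elems []string
}

func main() {
	decrypt := make(chan *batch, 1)
	inbound := make(chan *batch, 1)

	b := &batch{elems: []string{"ciphertext1", "ciphertext2"}}
	b.Lock() // producer: consumers must wait until decryption finishes
	inbound <- b
	decrypt <- b

	go func() { // decryption worker
		for c := range decrypt {
			for i := range c.elems {
				c.elems[i] = "plaintext" // stand-in for the AEAD Open call
			}
			c.Unlock() // release the batch to the sequential consumer
		}
	}()

	c := <-inbound
	c.Lock() // blocks here until the worker has decrypted the whole batch
	fmt.Println(c.elems)
}
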
@@ -437,12 +440,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
bufs := make([][]byte, 0, maxBatchSize)
- for elems := range peer.queue.inbound.c {
- if elems == nil {
+ for elemsContainer := range peer.queue.inbound.c {
+ if elemsContainer == nil {
return
}
- for _, elem := range *elems {
- elem.Lock()
+ elemsContainer.Lock()
+ validTailPacket := -1
+ dataPacketReceived := false
+ rxBytesLen := uint64(0)
+ for i, elem := range elemsContainer.elems {
if elem.packet == nil {
// decryption failed
continue
@@ -452,21 +458,19 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
continue
}
- peer.SetEndpointFromPacket(elem.endpoint)
+ validTailPacket = i
if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.SetEndpointFromPacket(elem.endpoint)
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
- peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
+ rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
- peer.timersDataReceived()
+ dataPacketReceived = true
switch elem.packet[0] >> 4 {
case 4:
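
The switch dispatches on the IP version nibble: the first four bits of any IP packet hold its version, so 4 selects the IPv4 path and 6 the IPv6 path. A standalone restatement of that check, with a hypothetical helper name:

package main

import "fmt"

// ipVersion mirrors the switch above: the high nibble of an IP packet's
// first byte is the version field.
func ipVersion(pkt []byte) int {
	if len(pkt) == 0 {
		return 0 // empty payload (keepalive): there is no IP header
	}
	return int(pkt[0] >> 4) // 4 for IPv4, 6 for IPv6
}

func main() {
	fmt.Println(ipVersion([]byte{0x45})) // 4: common first byte of an IPv4 header
	fmt.Println(ipVersion([]byte{0x60})) // 6: IPv6
}
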
@@ -509,17 +513,28 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
+
+ peer.rxBytes.Add(rxBytesLen)
+ if validTailPacket >= 0 {
+ peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+ }
+ if dataPacketReceived {
+ peer.timersDataReceived()
+ }
if len(bufs) > 0 {
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
- for _, elem := range *elems {
+ for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
bufs = bufs[:0]
- device.PutInboundElementsSlice(elems)
+ device.PutInboundElementsContainer(elemsContainer)
}
}
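
Taken together, the sequential receiver now defers its side effects to once per batch: rxBytes is accumulated locally and added in one call, the endpoint is taken from the last packet that passed authentication (validTailPacket), and the data-received timer fires at most once. A condensed restatement of that bookkeeping, outside the diff context; it omits the counter-validation step shown above:

rxBytes, validTail, sawData := uint64(0), -1, false
for i, elem := range elemsContainer.elems {
	if elem.packet == nil {
		continue // decryption failed; the element is returned to its pool later
	}
	validTail = i
	rxBytes += uint64(len(elem.packet) + MinMessageSize)
	if len(elem.packet) > 0 {
		sawData = true // zero-length payloads are keepalives
	}
}
peer.rxBytes.Add(rxBytes) // one atomic add for the whole batch
if validTail >= 0 {
	peer.SetEndpointFromPacket(elemsContainer.elems[validTail].endpoint)
}
if sawData {
	peer.timersDataReceived()
}
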