From 6440f010eec82abb9c999771a8f493af44c6b937 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 21 Mar 2019 14:43:04 -0600 Subject: receive: implement flush semantics --- device/receive.go | 204 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 115 insertions(+), 89 deletions(-) (limited to 'device/receive.go') diff --git a/device/receive.go b/device/receive.go index 09fae59..747a188 100644 --- a/device/receive.go +++ b/device/receive.go @@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() { } } +func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) { + if !*shouldFlush { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.inbound: + return + } + } else { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.inbound: + return + default: + *shouldFlush = false + err := peer.device.tun.device.Flush() + if err != nil { + peer.device.log.Error.Printf("Unable to flush packets: %v", err) + } + return peer.elementStopOrFlush(shouldFlush) + } + } +} + func (peer *Peer) RoutineSequentialReceiver() { device := peer.device @@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() { var elem *QueueInboundElement var ok bool + var stop bool + + shouldFlush := false defer func() { logDebug.Println(peer, "- Routine: sequential receiver - stopped") @@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() { elem = nil } - select { - - case <-peer.routines.stop: + stop, ok, elem = peer.elementStopOrFlush(&shouldFlush) + if stop || !ok { return + } - case elem, ok = <-peer.queue.inbound: - - if !ok { - return - } - - // wait for decryption + // wait for decryption - elem.Lock() + elem.Lock() - if elem.IsDropped() { - continue - } + if elem.IsDropped() { + continue + } - // check for replay + // check for replay - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: } + } - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - - // check for keepalive + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue - } - peer.timersDataReceived() + // check for keepalive - // verify source and strip padding + if len(elem.packet) == 0 { + logDebug.Println(peer, "- Receiving keepalive packet") + continue + } + peer.timersDataReceived() - switch elem.packet[0] >> 4 { - case ipv4.Version: + // verify source and strip padding - // strip padding + switch elem.packet[0] >> 4 { + case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue - } + // strip padding - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue - } + if len(elem.packet) < ipv4.HeaderLen { + continue + } - elem.packet = elem.packet[:length] + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } - // verify IPv4 source + elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue - } + // verify IPv4 source - case ipv6.Version: + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer, + ) + continue + } - // strip padding + case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - continue - } + // strip padding - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } + if len(elem.packet) < ipv6.HeaderLen { + continue + } - elem.packet = elem.packet[:length] + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } - // verify IPv6 source + elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - peer, - "sent packet with disallowed IPv6 source", - ) - continue - } + // verify IPv6 source - default: - logInfo.Println("Packet with invalid IP version from", peer) + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.LookupIPv6(src) != peer { + logInfo.Println( + peer, + "sent packet with disallowed IPv6 source", + ) continue } - // write to tun device + default: + logInfo.Println("Packet with invalid IP version from", peer) + continue + } + + // write to tun device - offset := MessageTransportOffsetContent - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) - } + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + if err == nil { + shouldFlush = true + } + if err != nil && !device.isClosed.Get() { + logError.Println("Failed to write packet to TUN device:", err) } } } -- cgit v1.2.3-59-g8ed1b