summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-03-21 14:43:04 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2019-03-21 14:45:41 -0600
commit6440f010eec82abb9c999771a8f493af44c6b937 (patch)
tree187207b768cb2bb477879a1abfad8b732bbece45
parenttun: windows: add dummy overlapped events back (diff)
downloadwireguard-go-6440f010eec82abb9c999771a8f493af44c6b937.tar.xz
wireguard-go-6440f010eec82abb9c999771a8f493af44c6b937.zip
receive: implement flush semantics
-rw-r--r--device/boundif_darwin.go2
-rw-r--r--device/boundif_windows.go2
-rw-r--r--device/conn.go2
-rw-r--r--device/queueconstants_android.go2
-rw-r--r--device/receive.go204
-rw-r--r--tun/operateonfd.go (renamed from tun/tun_default.go)0
-rw-r--r--tun/tun.go1
-rw-r--r--tun/tun_darwin.go5
-rw-r--r--tun/tun_freebsd.go5
-rw-r--r--tun/tun_linux.go5
-rw-r--r--tun/tun_openbsd.go5
-rw-r--r--tun/tun_windows.go12
12 files changed, 147 insertions, 98 deletions
diff --git a/device/boundif_darwin.go b/device/boundif_darwin.go
index b3d10ba..a93441c 100644
--- a/device/boundif_darwin.go
+++ b/device/boundif_darwin.go
@@ -41,4 +41,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
return err
}
return nil
-} \ No newline at end of file
+}
diff --git a/device/boundif_windows.go b/device/boundif_windows.go
index 00631cb..97381ad 100644
--- a/device/boundif_windows.go
+++ b/device/boundif_windows.go
@@ -53,4 +53,4 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
return err
}
return nil
-} \ No newline at end of file
+}
diff --git a/device/conn.go b/device/conn.go
index 2594680..3c2aa04 100644
--- a/device/conn.go
+++ b/device/conn.go
@@ -177,4 +177,4 @@ func (device *Device) BindClose() error {
err := unsafeCloseBind(device)
device.net.Unlock()
return err
-} \ No newline at end of file
+}
diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go
index 8d051ad..f5c042d 100644
--- a/device/queueconstants_android.go
+++ b/device/queueconstants_android.go
@@ -13,4 +13,4 @@ const (
QueueHandshakeSize = 1024
MaxSegmentSize = 2200
PreallocatedBuffersPerPool = 4096
-) \ No newline at end of file
+)
diff --git a/device/receive.go b/device/receive.go
index 09fae59..747a188 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -482,6 +482,33 @@ func (device *Device) RoutineHandshake() {
}
}
+func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
+ if !*shouldFlush {
+ select {
+ case <-peer.routines.stop:
+ stop = true
+ return
+ case elem, elemOk = <-peer.queue.inbound:
+ return
+ }
+ } else {
+ select {
+ case <-peer.routines.stop:
+ stop = true
+ return
+ case elem, elemOk = <-peer.queue.inbound:
+ return
+ default:
+ *shouldFlush = false
+ err := peer.device.tun.device.Flush()
+ if err != nil {
+ peer.device.log.Error.Printf("Unable to flush packets: %v", err)
+ }
+ return peer.elementStopOrFlush(shouldFlush)
+ }
+ }
+}
+
func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
@@ -491,6 +518,9 @@ func (peer *Peer) RoutineSequentialReceiver() {
var elem *QueueInboundElement
var ok bool
+ var stop bool
+
+ shouldFlush := false
defer func() {
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
@@ -516,126 +546,122 @@ func (peer *Peer) RoutineSequentialReceiver() {
elem = nil
}
- select {
-
- case <-peer.routines.stop:
+ stop, ok, elem = peer.elementStopOrFlush(&shouldFlush)
+ if stop || !ok {
return
+ }
- case elem, ok = <-peer.queue.inbound:
-
- if !ok {
- return
- }
-
- // wait for decryption
+ // wait for decryption
- elem.Lock()
+ elem.Lock()
- if elem.IsDropped() {
- continue
- }
+ if elem.IsDropped() {
+ continue
+ }
- // check for replay
+ // check for replay
- if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
- continue
- }
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
- // update endpoint
- peer.SetEndpointFromPacket(elem.endpoint)
+ // update endpoint
+ peer.SetEndpointFromPacket(elem.endpoint)
- // check if using new keypair
- if peer.ReceivedWithKeypair(elem.keypair) {
- peer.timersHandshakeComplete()
- select {
- case peer.signals.newKeypairArrived <- struct{}{}:
- default:
- }
+ // check if using new keypair
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
}
+ }
- peer.keepKeyFreshReceiving()
- peer.timersAnyAuthenticatedPacketTraversal()
- peer.timersAnyAuthenticatedPacketReceived()
-
- // check for keepalive
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
- if len(elem.packet) == 0 {
- logDebug.Println(peer, "- Receiving keepalive packet")
- continue
- }
- peer.timersDataReceived()
+ // check for keepalive
- // verify source and strip padding
+ if len(elem.packet) == 0 {
+ logDebug.Println(peer, "- Receiving keepalive packet")
+ continue
+ }
+ peer.timersDataReceived()
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
+ // verify source and strip padding
- // strip padding
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
- if len(elem.packet) < ipv4.HeaderLen {
- continue
- }
+ // strip padding
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
- continue
- }
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
- elem.packet = elem.packet[:length]
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
- // verify IPv4 source
+ elem.packet = elem.packet[:length]
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.allowedips.LookupIPv4(src) != peer {
- logInfo.Println(
- "IPv4 packet with disallowed source address from",
- peer,
- )
- continue
- }
+ // verify IPv4 source
- case ipv6.Version:
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.LookupIPv4(src) != peer {
+ logInfo.Println(
+ "IPv4 packet with disallowed source address from",
+ peer,
+ )
+ continue
+ }
- // strip padding
+ case ipv6.Version:
- if len(elem.packet) < ipv6.HeaderLen {
- continue
- }
+ // strip padding
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- if int(length) > len(elem.packet) {
- continue
- }
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
- elem.packet = elem.packet[:length]
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
- // verify IPv6 source
+ elem.packet = elem.packet[:length]
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.allowedips.LookupIPv6(src) != peer {
- logInfo.Println(
- peer,
- "sent packet with disallowed IPv6 source",
- )
- continue
- }
+ // verify IPv6 source
- default:
- logInfo.Println("Packet with invalid IP version from", peer)
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.LookupIPv6(src) != peer {
+ logInfo.Println(
+ peer,
+ "sent packet with disallowed IPv6 source",
+ )
continue
}
- // write to tun device
+ default:
+ logInfo.Println("Packet with invalid IP version from", peer)
+ continue
+ }
+
+ // write to tun device
- offset := MessageTransportOffsetContent
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
- _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
- if err != nil && !device.isClosed.Get() {
- logError.Println("Failed to write packet to TUN device:", err)
- }
+ offset := MessageTransportOffsetContent
+ atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
+ if err == nil {
+ shouldFlush = true
+ }
+ if err != nil && !device.isClosed.Get() {
+ logError.Println("Failed to write packet to TUN device:", err)
}
}
}
diff --git a/tun/tun_default.go b/tun/operateonfd.go
index 31747a2..31747a2 100644
--- a/tun/tun_default.go
+++ b/tun/operateonfd.go
diff --git a/tun/tun.go b/tun/tun.go
index c4b6cac..12febb8 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -21,6 +21,7 @@ type TUNDevice interface {
File() *os.File // returns the file descriptor of the device
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
+ Flush() error // flush all previous writes to the device
MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 3b39982..2077de3 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -281,6 +281,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff)
}
+func (tun *NativeTun) Flush() error {
+ //TODO: can flushing be implemented by buffering and using sendmmsg?
+ return nil
+}
+
func (tun *NativeTun) Close() error {
var err2 error
err1 := tun.tunFile.Close()
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index 3a60725..01a4348 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -406,6 +406,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff)
}
+func (tun *NativeTun) Flush() error {
+ //TODO: can flushing be implemented by buffering and using sendmmsg?
+ return nil
+}
+
func (tun *NativeTun) Close() error {
var err3 error
err1 := tun.tunFile.Close()
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index b7c429c..784cb9f 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -318,6 +318,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff)
}
+func (tun *NativeTun) Flush() error {
+ //TODO: can flushing be implemented by buffering and using sendmmsg?
+ return nil
+}
+
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
select {
case err := <-tun.errors:
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index 57edcb4..645bcca 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -237,6 +237,11 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.tunFile.Write(buff)
}
+func (tun *NativeTun) Flush() error {
+ //TODO: can flushing be implemented by buffering and using sendmmsg?
+ return nil
+}
+
func (tun *NativeTun) Close() error {
var err2 error
err1 := tun.tunFile.Close()
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index dcb414a..fffd802 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -281,7 +281,11 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) flush() error {
+func (tun *NativeTun) Flush() error {
+ if tun.wrBuff.offset == 0 {
+ return nil
+ }
+
// Get TUN data pipe.
file, err := tun.getTUN()
if err != nil {
@@ -322,7 +326,7 @@ func (tun *NativeTun) putTunPacket(buff []byte) error {
if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
// Exchange buffer is full -> flush first.
- err := tun.flush()
+ err := tun.Flush()
if err != nil {
return err
}
@@ -345,9 +349,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
if err != nil {
return 0, err
}
-
- // Flush write buffer.
- return len(buff) - offset, tun.flush()
+ return len(buff) - offset, nil
}
//