aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tunnel/wintun_test.go202
1 files changed, 202 insertions, 0 deletions
diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go
new file mode 100644
index 00000000..11d7ab2c
--- /dev/null
+++ b/tunnel/wintun_test.go
@@ -0,0 +1,202 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel_test
+
+import (
+ "bytes"
+ "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/tun"
+
+ "golang.zx2c4.com/wireguard/windows/elevate"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+func TestWintunOrdering(t *testing.T) {
+ var tunDevice tun.Device
+ err := elevate.DoAsSystem(func() error {
+ var err error
+ tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500)
+ return err
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tunDevice.Close()
+ nativeTunDevice := tunDevice.(*tun.NativeTun)
+ luid := winipcfg.LUID(nativeTunDevice.LUID())
+ ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24")
+ err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ var token [32]byte
+ _, err = rand.Read(token[:])
+ if err != nil {
+ t.Fatal(err)
+ }
+ var sockWrite net.Conn
+ for i := 0; i < 1000; i++ {
+ sockWrite, err = net.Dial("udp", "10.82.31.5:9999")
+ if err == nil {
+ defer sockWrite.Close()
+ break
+ }
+ time.Sleep(time.Millisecond * 100)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ var sockRead *net.UDPConn
+ for i := 0; i < 1000; i++ {
+ var listenAddress *net.UDPAddr
+ listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999")
+ if err != nil {
+ continue
+ }
+ sockRead, err = net.ListenUDP("udp", listenAddress)
+ if err == nil {
+ defer sockRead.Close()
+ break
+ }
+ time.Sleep(time.Millisecond * 100)
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ var wait sync.WaitGroup
+ wait.Add(4)
+ doneSockWrite := false
+ doneTunWrite := false
+ fatalErrors := make(chan error, 2)
+ errors := make(chan error, 2)
+ go func() {
+ defer wait.Done()
+ buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0)
+ for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ {
+ binary.LittleEndian.PutUint64(buffer[32:], sendingIndex)
+ _, err := sockWrite.Write(buffer[:])
+ if err != nil {
+ fatalErrors <- err
+ }
+ }
+ }()
+ go func() {
+ defer wait.Done()
+ packet := [20 + 8 + 32 + 8]byte{
+ 0x45, 0, 0, 20 + 8 + 32 + 8,
+ 0, 0, 0, 0,
+ 0x80, 0x11, 0, 0,
+ 10, 82, 31, 5,
+ 10, 82, 31, 4,
+ 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0,
+ }
+ copy(packet[28:], token[:])
+ for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ {
+ binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex))
+ var checksum uint32
+ for i := 0; i < 20; i += 2 {
+ if i != 10 {
+ checksum += uint32(binary.BigEndian.Uint16(packet[i:]))
+ }
+ }
+ binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff)))
+ binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex)
+ n, err := tunDevice.Write(packet[:], 0)
+ if err != nil {
+ fatalErrors <- err
+ }
+ if n == 0 {
+ time.Sleep(time.Millisecond * 300)
+ }
+ }
+ }()
+ const packetsPerTest = 1 << 21
+ go func() {
+ defer func() {
+ doneSockWrite = true
+ wait.Done()
+ }()
+ var expectedIndex uint64
+ for i := uint64(0); i < packetsPerTest; {
+ var buffer [(1 << 16) - 1]byte
+ bytesRead, err := tunDevice.Read(buffer[:], 0)
+ if err != nil {
+ fatalErrors <- err
+ }
+ if bytesRead < 0 || bytesRead > len(buffer) {
+ continue
+ }
+ packet := buffer[:bytesRead]
+ tokenPos := bytes.Index(packet, token[:])
+ if tokenPos == -1 || tokenPos+32+8 > len(packet) {
+ continue
+ }
+ foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:])
+ if foundIndex < expectedIndex {
+ errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex)
+ }
+ expectedIndex = foundIndex + 1
+ i++
+ }
+ }()
+ go func() {
+ defer func() {
+ doneTunWrite = true
+ wait.Done()
+ }()
+ var expectedIndex uint64
+ for i := uint64(0); i < packetsPerTest; {
+ var buffer [(1 << 16) - 1]byte
+ bytesRead, err := sockRead.Read(buffer[:])
+ if err != nil {
+ fatalErrors <- err
+ }
+ if bytesRead < 0 || bytesRead > len(buffer) {
+ continue
+ }
+ packet := buffer[:bytesRead]
+ if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) {
+ continue
+ }
+ foundIndex := binary.LittleEndian.Uint64(packet[32:])
+ if foundIndex < expectedIndex {
+ errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex)
+ }
+ expectedIndex = foundIndex + 1
+ i++
+ }
+ }()
+ done := make(chan bool, 2)
+ doneFunc := func() {
+ wait.Wait()
+ done <- true
+ }
+ defer doneFunc()
+ go doneFunc()
+ for {
+ select {
+ case err := <-fatalErrors:
+ t.Fatal(err)
+ case err := <-errors:
+ t.Error(err)
+ case <-done:
+ return
+ }
+ }
+}