From ba6ad6622544e87adfa3a10805d9497812c5d864 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 24 Nov 2019 21:21:15 +0100 Subject: tunnel: add wintun ordered unit test Signed-off-by: Jason A. Donenfeld --- tunnel/wintun_test.go | 202 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 tunnel/wintun_test.go (limited to 'tunnel') diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go new file mode 100644 index 00000000..11d7ab2c --- /dev/null +++ b/tunnel/wintun_test.go @@ -0,0 +1,202 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tunnel_test + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "fmt" + "net" + "sync" + "testing" + "time" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/elevate" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func TestWintunOrdering(t *testing.T) { + var tunDevice tun.Device + err := elevate.DoAsSystem(func() error { + var err error + tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500) + return err + }) + if err != nil { + t.Fatal(err) + } + defer tunDevice.Close() + nativeTunDevice := tunDevice.(*tun.NativeTun) + luid := winipcfg.LUID(nativeTunDevice.LUID()) + ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24") + err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}}) + if err != nil { + t.Fatal(err) + } + err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}}) + if err != nil { + t.Fatal(err) + } + var token [32]byte + _, err = rand.Read(token[:]) + if err != nil { + t.Fatal(err) + } + var sockWrite net.Conn + for i := 0; i < 1000; i++ { + sockWrite, err = net.Dial("udp", "10.82.31.5:9999") + if err == nil { + defer sockWrite.Close() + break + } + time.Sleep(time.Millisecond * 100) + } + if err != nil { + t.Fatal(err) + } + var sockRead *net.UDPConn + for i := 0; i < 1000; i++ { + var listenAddress *net.UDPAddr + listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999") + if err != nil { + continue + } + sockRead, err = net.ListenUDP("udp", listenAddress) + if err == nil { + defer sockRead.Close() + break + } + time.Sleep(time.Millisecond * 100) + } + if err != nil { + t.Fatal(err) + } + var wait sync.WaitGroup + wait.Add(4) + doneSockWrite := false + doneTunWrite := false + fatalErrors := make(chan error, 2) + errors := make(chan error, 2) + go func() { + defer wait.Done() + buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0) + for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ { + binary.LittleEndian.PutUint64(buffer[32:], sendingIndex) + _, err := sockWrite.Write(buffer[:]) + if err != nil { + fatalErrors <- err + } + } + }() + go func() { + defer wait.Done() + packet := [20 + 8 + 32 + 8]byte{ + 0x45, 0, 0, 20 + 8 + 32 + 8, + 0, 0, 0, 0, + 0x80, 0x11, 0, 0, + 10, 82, 31, 5, + 10, 82, 31, 4, + 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0, + } + copy(packet[28:], token[:]) + for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ { + binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex)) + var checksum uint32 + for i := 0; i < 20; i += 2 { + if i != 10 { + checksum += uint32(binary.BigEndian.Uint16(packet[i:])) + } + } + binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff))) + binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex) + n, err := tunDevice.Write(packet[:], 0) + if err != nil { + fatalErrors <- err + } + if n == 0 { + time.Sleep(time.Millisecond * 300) + } + } + }() + const packetsPerTest = 1 << 21 + go func() { + defer func() { + doneSockWrite = true + wait.Done() + }() + var expectedIndex uint64 + for i := uint64(0); i < packetsPerTest; { + var buffer [(1 << 16) - 1]byte + bytesRead, err := tunDevice.Read(buffer[:], 0) + if err != nil { + fatalErrors <- err + } + if bytesRead < 0 || bytesRead > len(buffer) { + continue + } + packet := buffer[:bytesRead] + tokenPos := bytes.Index(packet, token[:]) + if tokenPos == -1 || tokenPos+32+8 > len(packet) { + continue + } + foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:]) + if foundIndex < expectedIndex { + errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex) + } + expectedIndex = foundIndex + 1 + i++ + } + }() + go func() { + defer func() { + doneTunWrite = true + wait.Done() + }() + var expectedIndex uint64 + for i := uint64(0); i < packetsPerTest; { + var buffer [(1 << 16) - 1]byte + bytesRead, err := sockRead.Read(buffer[:]) + if err != nil { + fatalErrors <- err + } + if bytesRead < 0 || bytesRead > len(buffer) { + continue + } + packet := buffer[:bytesRead] + if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) { + continue + } + foundIndex := binary.LittleEndian.Uint64(packet[32:]) + if foundIndex < expectedIndex { + errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex) + } + expectedIndex = foundIndex + 1 + i++ + } + }() + done := make(chan bool, 2) + doneFunc := func() { + wait.Wait() + done <- true + } + defer doneFunc() + go doneFunc() + for { + select { + case err := <-fatalErrors: + t.Fatal(err) + case err := <-errors: + t.Error(err) + case <-done: + return + } + } +} -- cgit v1.2.3-59-g8ed1b