diff options
Diffstat (limited to '')
-rw-r--r-- | tunnel/addressconfig.go | 5 | ||||
-rw-r--r-- | tunnel/defaultroutemonitor.go | 159 | ||||
-rw-r--r-- | tunnel/interfacewatcher.go | 79 | ||||
-rw-r--r-- | tunnel/ipcpermissions.go | 63 | ||||
-rw-r--r-- | tunnel/service.go | 164 | ||||
-rw-r--r-- | tunnel/winipcfg/winipcfg_test.go | 4 | ||||
-rw-r--r-- | tunnel/wintun_test.go | 202 |
7 files changed, 55 insertions, 621 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index c887f7b7..907cc546 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -76,7 +76,7 @@ func isDnsCacheDisabled() (bool, string) { return cfg.StartType == mgr.StartDisabled, cfg.DisplayName } -func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID, clamper mtuClamper) error { +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error { estimatedRouteCount := 0 for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) @@ -172,9 +172,6 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid w ipif.OtherStatefulConfigurationSupported = false if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) - if clamper != nil { - clamper.ForceMTU(int(ipif.NLMTU)) - } } if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) { ipif.UseAutomaticMetric = false diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go deleted file mode 100644 index 2a107fda..00000000 --- a/tunnel/defaultroutemonitor.go +++ /dev/null @@ -1,159 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package tunnel - -import ( - "log" - "sync" - "time" - - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error { - r, err := winipcfg.GetIPForwardTable2(family) - if err != nil { - return err - } - lowestMetric := ^uint32(0) - index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. - luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. - for i := range r { - if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID { - continue - } - ifrow, err := r[i].InterfaceLUID.Interface() - if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { - continue - } - - iface, err := r[i].InterfaceLUID.IPInterface(family) - if err != nil { - continue - } - - if r[i].Metric+iface.Metric < lowestMetric { - lowestMetric = r[i].Metric + iface.Metric - index = r[i].InterfaceIndex - luid = r[i].InterfaceLUID - } - } - if luid == *lastLUID && index == *lastIndex { - return nil - } - *lastLUID = luid - *lastIndex = index - blackhole := blackholeWhenLoop && index == 0 - if family == windows.AF_INET { - log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole) - return binder.BindSocketToInterface4(index, blackhole) - } else if family == windows.AF_INET6 { - log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole) - return binder.BindSocketToInterface6(index, blackhole) - } - return nil -} - -type mtuClamper interface { - ForceMTU(mtu int) -} - -func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, clamper mtuClamper, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) { - var minMTU uint32 - if family == windows.AF_INET { - minMTU = 576 - } else if family == windows.AF_INET6 { - minMTU = 1280 - } - lastLUID := winipcfg.LUID(0) - lastIndex := ^uint32(0) - lastMTU := uint32(0) - doIt := func() error { - err := bindSocketRoute(family, binder, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop) - if err != nil { - return err - } - if !autoMTU { - return nil - } - mtu := uint32(0) - if lastLUID != 0 { - iface, err := lastLUID.Interface() - if err != nil { - return err - } - if iface.MTU > 0 { - mtu = iface.MTU - } - } - if mtu > 0 && lastMTU != mtu { - iface, err := ourLUID.IPInterface(family) - if err != nil { - return err - } - iface.NLMTU = mtu - 80 - if iface.NLMTU < minMTU { - iface.NLMTU = minMTU - } - err = iface.Set() - if err != nil { - return err - } - - // Having one MTU for both v4 and v6 kind of breaks the Windows model, unfortunately. - clamper.ForceMTU(int(iface.NLMTU)) - lastMTU = mtu - } - return nil - } - err := doIt() - if err != nil { - return nil, err - } - - firstBurst := time.Time{} - burstMutex := sync.Mutex{} - burstTimer := time.AfterFunc(time.Hour*200, func() { - burstMutex.Lock() - firstBurst = time.Time{} - doIt() - burstMutex.Unlock() - }) - burstTimer.Stop() - bump := func() { - burstMutex.Lock() - burstTimer.Reset(time.Millisecond * 150) - if firstBurst.IsZero() { - firstBurst = time.Now() - } else if time.Since(firstBurst) > time.Second*2 { - firstBurst = time.Time{} - burstTimer.Stop() - doIt() - } - burstMutex.Unlock() - } - - cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { - if route != nil && route.DestinationPrefix.PrefixLength == 0 { - bump() - } - }) - if err != nil { - return nil, err - } - cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { - if notificationType == winipcfg.MibParameterNotification { - bump() - } - }) - if err != nil { - cbr.Unregister() - return nil, err - } - return []winipcfg.ChangeCallback{cbr, cbi}, nil -} diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index 4e28a6a2..5ca2c69d 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -11,10 +11,8 @@ import ( "sync" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/driver" - - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -31,8 +29,6 @@ type interfaceWatcherEvent struct { type interfaceWatcher struct { errors chan interfaceWatcherError - binder conn.BindSocketToInterface - clamper mtuClamper conf *conf.Config adapter *driver.Adapter luid winipcfg.LUID @@ -44,44 +40,6 @@ type interfaceWatcher struct { storedEvents []interfaceWatcherEvent } -func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool { - var ( - foundV401 bool - foundV41281 bool - foundV600001 bool - foundV680001 bool - foundV400 bool - foundV600 bool - v40 = [4]byte{} - v60 = [16]byte{} - v48 = [4]byte{0x80} - v68 = [16]byte{0x80} - ) - for _, peer := range peers { - for _, allowedip := range peer.AllowedIPs { - if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { - foundV680001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV401 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { - foundV41281 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV400 = true - } - } - } - if family == windows.AF_INET { - return foundV400 || (foundV401 && foundV41281) - } else if family == windows.AF_INET6 { - return foundV600 || (foundV600001 && foundV680001) - } - return false -} - func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string @@ -102,14 +60,7 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - if iw.binder != nil && iw.clamper != nil { - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return - } - } else if iw.conf.Interface.MTU == 0 { + if iw.conf.Interface.MTU == 0 { log.Printf("Monitoring MTU of default %s routes", ipversion) *changeCallbacks, err = monitorMTU(family, iw.luid) if err != nil { @@ -119,7 +70,7 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.luid, iw.clamper) + err = configureInterface(family, iw.conf, iw.luid) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return @@ -147,17 +98,15 @@ func watchInterface() (*interfaceWatcher, error) { } iw.setup(iface.Family) - if iw.adapter != nil { - if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown { - log.Println("Reinitializing adapter configuration") - err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration()) - if err != nil { - log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err)) - } - err = iw.adapter.SetAdapterState(driver.AdapterStateUp) - if err != nil { - log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err)) - } + if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown { + log.Println("Reinitializing adapter configuration") + err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration()) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err)) + } + err = iw.adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err)) } } }) @@ -167,11 +116,11 @@ func watchInterface() (*interfaceWatcher, error) { return iw, nil } -func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) { +func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() - iw.binder, iw.clamper, iw.adapter, iw.conf, iw.luid = binder, clamper, adapter, conf, luid + iw.adapter, iw.conf, iw.luid = adapter, conf, luid for _, event := range iw.storedEvents { if event.luid == luid { iw.setup(event.family) diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go deleted file mode 100644 index 3a676e4b..00000000 --- a/tunnel/ipcpermissions.go +++ /dev/null @@ -1,63 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package tunnel - -import ( - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc" - - "golang.zx2c4.com/wireguard/windows/conf" -) - -func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error { - if conf.PathIsEncrypted(filename) { - return nil - } - - fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - return err - } - fileOwner, _, err := fileSd.Owner() - if err != nil { - return err - } - if fileOwner.IsWellKnown(windows.WinLocalSystemSid) { - return nil - } - additionalEntries := []windows.EXPLICIT_ACCESS{{ - AccessPermissions: windows.GENERIC_ALL, - AccessMode: windows.GRANT_ACCESS, - Trustee: windows.TRUSTEE{ - TrusteeForm: windows.TRUSTEE_IS_SID, - TrusteeType: windows.TRUSTEE_IS_USER, - TrusteeValue: windows.TrusteeValueFromSID(fileOwner), - }, - }} - - sd, err := ipc.UAPISecurityDescriptor.ToAbsolute() - if err != nil { - return err - } - dacl, defaulted, _ := sd.DACL() - - newDacl, err := windows.ACLFromEntries(additionalEntries, dacl) - if err != nil { - return err - } - err = sd.SetDACL(newDacl, true, defaulted) - if err != nil { - return err - } - sd, err = sd.ToSelfRelative() - if err != nil { - return err - } - ipc.UAPISecurityDescriptor = sd - - return nil -} diff --git a/tunnel/service.go b/tunnel/service.go index 57d0ef91..013548df 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "log" - "net" "os" "runtime" "time" @@ -17,10 +16,7 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" @@ -37,11 +33,7 @@ type tunnelService struct { func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { changes <- svc.Status{State: svc.StartPending} - var dev *device.Device - var uapi net.Listener var watcher *interfaceWatcher - var nativeTun *tun.NativeTun - var wintun tun.Device var adapter *driver.Adapter var luid winipcfg.LUID var config *conf.Config @@ -88,22 +80,16 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } }() - if logErr == nil && (dev != nil || adapter != nil) && config != nil { + if logErr == nil && adapter != nil && config != nil { logErr = runScriptCommand(config.Interface.PreDown, config.Name) } if watcher != nil { watcher.Destroy() } - if uapi != nil { - uapi.Close() - } - if dev != nil { - dev.Close() - } if adapter != nil { adapter.Close() } - if logErr == nil && (dev != nil || adapter != nil) && config != nil { + if logErr == nil && adapter != nil && config != nil { _ = runScriptCommand(config.Interface.PostDown, config.Name) } stopIt <- true @@ -128,11 +114,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } config.DeduplicateNetworkEntries() - err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path) - if err != nil { - serviceError = services.ErrorLoadConfiguration - return - } log.SetPrefix(fmt.Sprintf("[%s] ", config.Name)) @@ -166,58 +147,33 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Creating network adapter") - if UseFixedGUIDInsteadOfDeterministic || !conf.AdminBool("UseUserspaceImplementation") { - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) - } - adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) - if err == nil || windows.DurationSinceBoot() > time.Minute*10 { - break - } - } - if err != nil { - err = fmt.Errorf("Error creating adapter: %w", err) - serviceError = services.ErrorCreateNetworkAdapter - return + for i := 0; i < 5; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) } - luid = adapter.LUID() - driverVersion, err := driver.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine driver version: %v", err) - } else { - log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) - } - err = adapter.SetLogging(driver.AdapterLogOn) - if err != nil { - err = fmt.Errorf("Error enabling adapter logging: %w", err) - serviceError = services.ErrorCreateNetworkAdapter - return + adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) + if err == nil || windows.DurationSinceBoot() > time.Minute*10 { + break } + } + if err != nil { + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return + } + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() + if err != nil { + log.Printf("Warning: unable to determine driver version: %v", err) } else { - for i := 0; i < 5; i++ { - if i > 0 { - time.Sleep(time.Second) - log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) - } - wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) - if err == nil || windows.DurationSinceBoot() > time.Minute*10 { - break - } - } - if err != nil { - serviceError = services.ErrorCreateNetworkAdapter - return - } - nativeTun = wintun.(*tun.NativeTun) - luid = winipcfg.LUID(nativeTun.LUID()) - driverVersion, err := nativeTun.RunningVersion() - if err != nil { - log.Printf("Warning: unable to determine driver version: %v", err) - } else { - log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) - } + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } err = runScriptCommand(config.Interface.PreUp, config.Name) @@ -239,54 +195,18 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - if nativeTun != nil { - log.Println("Creating interface instance") - bind := conn.NewDefaultBind() - dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) - - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(config.Name) - if err != nil { - serviceError = services.ErrorUAPIListen - return - } - err = dev.IpcSet(config.ToUAPI()) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } - - log.Println("Bringing peers up") - dev.Up() - - var clamper mtuClamper - clamper = nativeTun - watcher.Configure(bind.(conn.BindSocketToInterface), clamper, nil, config, luid) - - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue - } - go dev.IpcHandle(conn) - } - }() - } else { - log.Println("Setting interface configuration") - err = adapter.SetConfiguration(config.ToDriverConfiguration()) - if err != nil { - serviceError = services.ErrorDeviceSetConfig - return - } - err = adapter.SetAdapterState(driver.AdapterStateUp) - if err != nil { - serviceError = services.ErrorDeviceBringUp - return - } - watcher.Configure(nil, nil, adapter, config, luid) + log.Println("Setting interface configuration") + err = adapter.SetConfiguration(config.ToDriverConfiguration()) + if err != nil { + serviceError = services.ErrorDeviceSetConfig + return } + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + serviceError = services.ErrorDeviceBringUp + return + } + watcher.Configure(adapter, config, luid) err = runScriptCommand(config.Interface.PostUp, config.Name) if err != nil { @@ -297,12 +217,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} log.Println("Startup complete") - var devWaitChan chan struct{} - if dev != nil { - devWaitChan = dev.Wait() - } else { - devWaitChan = make(chan struct{}) - } for { select { case c := <-r: @@ -314,8 +228,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-devWaitChan: - return case e := <-watcher.errors: serviceError, err = e.serviceError, e.err return diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go index 7689a0c1..5d3bc276 100644 --- a/tunnel/winipcfg/winipcfg_test.go +++ b/tunnel/winipcfg/winipcfg_test.go @@ -8,8 +8,8 @@ Some tests in this file require: - A dedicated network adapter - Any network adapter will do. It may be virtual (Wintun etc.). The adapter name - must contain string "winipcfg_test". + Any network adapter will do. It may be virtual (WireGuardNT, Wintun, + etc.). The adapter name must contain string "winipcfg_test". Tests will add, remove, flush DNS servers, change adapter IP address, manipulate routes etc. The adapter will not be returned to previous state, so use an expendable one. diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go deleted file mode 100644 index 4e56ff65..00000000 --- a/tunnel/wintun_test.go +++ /dev/null @@ -1,202 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package tunnel_test - -import ( - "bytes" - "crypto/rand" - "encoding/binary" - "fmt" - "net" - "sync" - "testing" - "time" - - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun" - - "golang.zx2c4.com/wireguard/windows/elevate" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -func TestWintunOrdering(t *testing.T) { - var tunDevice tun.Device - err := elevate.DoAsSystem(func() error { - var err error - tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500) - return err - }) - if err != nil { - t.Fatal(err) - } - defer tunDevice.Close() - nativeTunDevice := tunDevice.(*tun.NativeTun) - luid := winipcfg.LUID(nativeTunDevice.LUID()) - ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24") - err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}}) - if err != nil { - t.Fatal(err) - } - err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}}) - if err != nil { - t.Fatal(err) - } - var token [32]byte - _, err = rand.Read(token[:]) - if err != nil { - t.Fatal(err) - } - var sockWrite net.Conn - for i := 0; i < 1000; i++ { - sockWrite, err = net.Dial("udp", "10.82.31.5:9999") - if err == nil { - defer sockWrite.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var sockRead *net.UDPConn - for i := 0; i < 1000; i++ { - var listenAddress *net.UDPAddr - listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999") - if err != nil { - continue - } - sockRead, err = net.ListenUDP("udp", listenAddress) - if err == nil { - defer sockRead.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var wait sync.WaitGroup - wait.Add(4) - doneSockWrite := false - doneTunWrite := false - fatalErrors := make(chan error, 2) - errors := make(chan error, 2) - go func() { - defer wait.Done() - buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0) - for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ { - binary.LittleEndian.PutUint64(buffer[32:], sendingIndex) - _, err := sockWrite.Write(buffer[:]) - if err != nil { - fatalErrors <- err - } - } - }() - go func() { - defer wait.Done() - packet := [20 + 8 + 32 + 8]byte{ - 0x45, 0, 0, 20 + 8 + 32 + 8, - 0, 0, 0, 0, - 0x80, 0x11, 0, 0, - 10, 82, 31, 5, - 10, 82, 31, 4, - 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0, - } - copy(packet[28:], token[:]) - for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ { - binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex)) - var checksum uint32 - for i := 0; i < 20; i += 2 { - if i != 10 { - checksum += uint32(binary.BigEndian.Uint16(packet[i:])) - } - } - binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff))) - binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex) - n, err := tunDevice.Write(packet[:], 0) - if err != nil { - fatalErrors <- err - } - if n == 0 { - time.Sleep(time.Millisecond * 300) - } - } - }() - const packetsPerTest = 1 << 21 - go func() { - defer func() { - doneSockWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := tunDevice.Read(buffer[:], 0) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - tokenPos := bytes.Index(packet, token[:]) - if tokenPos == -1 || tokenPos+32+8 > len(packet) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - go func() { - defer func() { - doneTunWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := sockRead.Read(buffer[:]) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - done := make(chan bool, 2) - doneFunc := func() { - wait.Wait() - done <- true - } - defer doneFunc() - go doneFunc() - for { - select { - case err := <-fatalErrors: - t.Fatal(err) - case err := <-errors: - t.Error(err) - case <-done: - return - } - } -} |