aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tunnel/addressconfig.go5
-rw-r--r--tunnel/defaultroutemonitor.go159
-rw-r--r--tunnel/interfacewatcher.go79
-rw-r--r--tunnel/ipcpermissions.go63
-rw-r--r--tunnel/service.go164
-rw-r--r--tunnel/winipcfg/winipcfg_test.go4
-rw-r--r--tunnel/wintun_test.go202
7 files changed, 55 insertions, 621 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index c887f7b7..907cc546 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -76,7 +76,7 @@ func isDnsCacheDisabled() (bool, string) {
return cfg.StartType == mgr.StartDisabled, cfg.DisplayName
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID, clamper mtuClamper) error {
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error {
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
@@ -172,9 +172,6 @@ func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid w
ipif.OtherStatefulConfigurationSupported = false
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- if clamper != nil {
- clamper.ForceMTU(int(ipif.NLMTU))
- }
}
if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
ipif.UseAutomaticMetric = false
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
deleted file mode 100644
index 2a107fda..00000000
--- a/tunnel/defaultroutemonitor.go
+++ /dev/null
@@ -1,159 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel
-
-import (
- "log"
- "sync"
- "time"
-
- "golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error {
- r, err := winipcfg.GetIPForwardTable2(family)
- if err != nil {
- return err
- }
- lowestMetric := ^uint32(0)
- index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
- luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
- for i := range r {
- if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
- continue
- }
- ifrow, err := r[i].InterfaceLUID.Interface()
- if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
- continue
- }
-
- iface, err := r[i].InterfaceLUID.IPInterface(family)
- if err != nil {
- continue
- }
-
- if r[i].Metric+iface.Metric < lowestMetric {
- lowestMetric = r[i].Metric + iface.Metric
- index = r[i].InterfaceIndex
- luid = r[i].InterfaceLUID
- }
- }
- if luid == *lastLUID && index == *lastIndex {
- return nil
- }
- *lastLUID = luid
- *lastIndex = index
- blackhole := blackholeWhenLoop && index == 0
- if family == windows.AF_INET {
- log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole)
- return binder.BindSocketToInterface4(index, blackhole)
- } else if family == windows.AF_INET6 {
- log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole)
- return binder.BindSocketToInterface6(index, blackhole)
- }
- return nil
-}
-
-type mtuClamper interface {
- ForceMTU(mtu int)
-}
-
-func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, clamper mtuClamper, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
- var minMTU uint32
- if family == windows.AF_INET {
- minMTU = 576
- } else if family == windows.AF_INET6 {
- minMTU = 1280
- }
- lastLUID := winipcfg.LUID(0)
- lastIndex := ^uint32(0)
- lastMTU := uint32(0)
- doIt := func() error {
- err := bindSocketRoute(family, binder, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop)
- if err != nil {
- return err
- }
- if !autoMTU {
- return nil
- }
- mtu := uint32(0)
- if lastLUID != 0 {
- iface, err := lastLUID.Interface()
- if err != nil {
- return err
- }
- if iface.MTU > 0 {
- mtu = iface.MTU
- }
- }
- if mtu > 0 && lastMTU != mtu {
- iface, err := ourLUID.IPInterface(family)
- if err != nil {
- return err
- }
- iface.NLMTU = mtu - 80
- if iface.NLMTU < minMTU {
- iface.NLMTU = minMTU
- }
- err = iface.Set()
- if err != nil {
- return err
- }
-
- // Having one MTU for both v4 and v6 kind of breaks the Windows model, unfortunately.
- clamper.ForceMTU(int(iface.NLMTU))
- lastMTU = mtu
- }
- return nil
- }
- err := doIt()
- if err != nil {
- return nil, err
- }
-
- firstBurst := time.Time{}
- burstMutex := sync.Mutex{}
- burstTimer := time.AfterFunc(time.Hour*200, func() {
- burstMutex.Lock()
- firstBurst = time.Time{}
- doIt()
- burstMutex.Unlock()
- })
- burstTimer.Stop()
- bump := func() {
- burstMutex.Lock()
- burstTimer.Reset(time.Millisecond * 150)
- if firstBurst.IsZero() {
- firstBurst = time.Now()
- } else if time.Since(firstBurst) > time.Second*2 {
- firstBurst = time.Time{}
- burstTimer.Stop()
- doIt()
- }
- burstMutex.Unlock()
- }
-
- cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
- if route != nil && route.DestinationPrefix.PrefixLength == 0 {
- bump()
- }
- })
- if err != nil {
- return nil, err
- }
- cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
- if notificationType == winipcfg.MibParameterNotification {
- bump()
- }
- })
- if err != nil {
- cbr.Unregister()
- return nil, err
- }
- return []winipcfg.ChangeCallback{cbr, cbi}, nil
-}
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index 4e28a6a2..5ca2c69d 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -11,10 +11,8 @@ import (
"sync"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/windows/driver"
-
- "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -31,8 +29,6 @@ type interfaceWatcherEvent struct {
type interfaceWatcher struct {
errors chan interfaceWatcherError
- binder conn.BindSocketToInterface
- clamper mtuClamper
conf *conf.Config
adapter *driver.Adapter
luid winipcfg.LUID
@@ -44,44 +40,6 @@ type interfaceWatcher struct {
storedEvents []interfaceWatcherEvent
}
-func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool {
- var (
- foundV401 bool
- foundV41281 bool
- foundV600001 bool
- foundV680001 bool
- foundV400 bool
- foundV600 bool
- v40 = [4]byte{}
- v60 = [16]byte{}
- v48 = [4]byte{0x80}
- v68 = [16]byte{0x80}
- )
- for _, peer := range peers {
- for _, allowedip := range peer.AllowedIPs {
- if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) {
- foundV680001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV401 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) {
- foundV41281 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV400 = true
- }
- }
- }
- if family == windows.AF_INET {
- return foundV400 || (foundV401 && foundV41281)
- } else if family == windows.AF_INET6 {
- return foundV600 || (foundV600001 && foundV680001)
- }
- return false
-}
-
func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
var changeCallbacks *[]winipcfg.ChangeCallback
var ipversion string
@@ -102,14 +60,7 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- if iw.binder != nil && iw.clamper != nil {
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.clamper, iw.luid)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
- }
- } else if iw.conf.Interface.MTU == 0 {
+ if iw.conf.Interface.MTU == 0 {
log.Printf("Monitoring MTU of default %s routes", ipversion)
*changeCallbacks, err = monitorMTU(family, iw.luid)
if err != nil {
@@ -119,7 +70,7 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.luid, iw.clamper)
+ err = configureInterface(family, iw.conf, iw.luid)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
@@ -147,17 +98,15 @@ func watchInterface() (*interfaceWatcher, error) {
}
iw.setup(iface.Family)
- if iw.adapter != nil {
- if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
- log.Println("Reinitializing adapter configuration")
- err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
- if err != nil {
- log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
- }
- err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
- if err != nil {
- log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
- }
+ if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
+ log.Println("Reinitializing adapter configuration")
+ err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
+ }
+ err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
}
}
})
@@ -167,11 +116,11 @@ func watchInterface() (*interfaceWatcher, error) {
return iw, nil
}
-func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, clamper mtuClamper, adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
+func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
- iw.binder, iw.clamper, iw.adapter, iw.conf, iw.luid = binder, clamper, adapter, conf, luid
+ iw.adapter, iw.conf, iw.luid = adapter, conf, luid
for _, event := range iw.storedEvents {
if event.luid == luid {
iw.setup(event.family)
diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go
deleted file mode 100644
index 3a676e4b..00000000
--- a/tunnel/ipcpermissions.go
+++ /dev/null
@@ -1,63 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel
-
-import (
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/ipc"
-
- "golang.zx2c4.com/wireguard/windows/conf"
-)
-
-func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error {
- if conf.PathIsEncrypted(filename) {
- return nil
- }
-
- fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
- if err != nil {
- return err
- }
- fileOwner, _, err := fileSd.Owner()
- if err != nil {
- return err
- }
- if fileOwner.IsWellKnown(windows.WinLocalSystemSid) {
- return nil
- }
- additionalEntries := []windows.EXPLICIT_ACCESS{{
- AccessPermissions: windows.GENERIC_ALL,
- AccessMode: windows.GRANT_ACCESS,
- Trustee: windows.TRUSTEE{
- TrusteeForm: windows.TRUSTEE_IS_SID,
- TrusteeType: windows.TRUSTEE_IS_USER,
- TrusteeValue: windows.TrusteeValueFromSID(fileOwner),
- },
- }}
-
- sd, err := ipc.UAPISecurityDescriptor.ToAbsolute()
- if err != nil {
- return err
- }
- dacl, defaulted, _ := sd.DACL()
-
- newDacl, err := windows.ACLFromEntries(additionalEntries, dacl)
- if err != nil {
- return err
- }
- err = sd.SetDACL(newDacl, true, defaulted)
- if err != nil {
- return err
- }
- sd, err = sd.ToSelfRelative()
- if err != nil {
- return err
- }
- ipc.UAPISecurityDescriptor = sd
-
- return nil
-}
diff --git a/tunnel/service.go b/tunnel/service.go
index 57d0ef91..013548df 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -9,7 +9,6 @@ import (
"bytes"
"fmt"
"log"
- "net"
"os"
"runtime"
"time"
@@ -17,10 +16,7 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/ipc"
- "golang.zx2c4.com/wireguard/tun"
+
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
@@ -37,11 +33,7 @@ type tunnelService struct {
func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
changes <- svc.Status{State: svc.StartPending}
- var dev *device.Device
- var uapi net.Listener
var watcher *interfaceWatcher
- var nativeTun *tun.NativeTun
- var wintun tun.Device
var adapter *driver.Adapter
var luid winipcfg.LUID
var config *conf.Config
@@ -88,22 +80,16 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
}()
- if logErr == nil && (dev != nil || adapter != nil) && config != nil {
+ if logErr == nil && adapter != nil && config != nil {
logErr = runScriptCommand(config.Interface.PreDown, config.Name)
}
if watcher != nil {
watcher.Destroy()
}
- if uapi != nil {
- uapi.Close()
- }
- if dev != nil {
- dev.Close()
- }
if adapter != nil {
adapter.Close()
}
- if logErr == nil && (dev != nil || adapter != nil) && config != nil {
+ if logErr == nil && adapter != nil && config != nil {
_ = runScriptCommand(config.Interface.PostDown, config.Name)
}
stopIt <- true
@@ -128,11 +114,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
config.DeduplicateNetworkEntries()
- err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path)
- if err != nil {
- serviceError = services.ErrorLoadConfiguration
- return
- }
log.SetPrefix(fmt.Sprintf("[%s] ", config.Name))
@@ -166,58 +147,33 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
log.Println("Creating network adapter")
- if UseFixedGUIDInsteadOfDeterministic || !conf.AdminBool("UseUserspaceImplementation") {
- for i := 0; i < 5; i++ {
- if i > 0 {
- time.Sleep(time.Second)
- log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
- }
- adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config))
- if err == nil || windows.DurationSinceBoot() > time.Minute*10 {
- break
- }
- }
- if err != nil {
- err = fmt.Errorf("Error creating adapter: %w", err)
- serviceError = services.ErrorCreateNetworkAdapter
- return
+ for i := 0; i < 5; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
}
- luid = adapter.LUID()
- driverVersion, err := driver.RunningVersion()
- if err != nil {
- log.Printf("Warning: unable to determine driver version: %v", err)
- } else {
- log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
- }
- err = adapter.SetLogging(driver.AdapterLogOn)
- if err != nil {
- err = fmt.Errorf("Error enabling adapter logging: %w", err)
- serviceError = services.ErrorCreateNetworkAdapter
- return
+ adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config))
+ if err == nil || windows.DurationSinceBoot() > time.Minute*10 {
+ break
}
+ }
+ if err != nil {
+ err = fmt.Errorf("Error creating adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
+ }
+ luid = adapter.LUID()
+ driverVersion, err := driver.RunningVersion()
+ if err != nil {
+ log.Printf("Warning: unable to determine driver version: %v", err)
} else {
- for i := 0; i < 5; i++ {
- if i > 0 {
- time.Sleep(time.Second)
- log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
- }
- wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
- if err == nil || windows.DurationSinceBoot() > time.Minute*10 {
- break
- }
- }
- if err != nil {
- serviceError = services.ErrorCreateNetworkAdapter
- return
- }
- nativeTun = wintun.(*tun.NativeTun)
- luid = winipcfg.LUID(nativeTun.LUID())
- driverVersion, err := nativeTun.RunningVersion()
- if err != nil {
- log.Printf("Warning: unable to determine driver version: %v", err)
- } else {
- log.Printf("Using Wintun/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
- }
+ log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
+ err = adapter.SetLogging(driver.AdapterLogOn)
+ if err != nil {
+ err = fmt.Errorf("Error enabling adapter logging: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
}
err = runScriptCommand(config.Interface.PreUp, config.Name)
@@ -239,54 +195,18 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- if nativeTun != nil {
- log.Println("Creating interface instance")
- bind := conn.NewDefaultBind()
- dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
-
- log.Println("Setting interface configuration")
- uapi, err = ipc.UAPIListen(config.Name)
- if err != nil {
- serviceError = services.ErrorUAPIListen
- return
- }
- err = dev.IpcSet(config.ToUAPI())
- if err != nil {
- serviceError = services.ErrorDeviceSetConfig
- return
- }
-
- log.Println("Bringing peers up")
- dev.Up()
-
- var clamper mtuClamper
- clamper = nativeTun
- watcher.Configure(bind.(conn.BindSocketToInterface), clamper, nil, config, luid)
-
- log.Println("Listening for UAPI requests")
- go func() {
- for {
- conn, err := uapi.Accept()
- if err != nil {
- continue
- }
- go dev.IpcHandle(conn)
- }
- }()
- } else {
- log.Println("Setting interface configuration")
- err = adapter.SetConfiguration(config.ToDriverConfiguration())
- if err != nil {
- serviceError = services.ErrorDeviceSetConfig
- return
- }
- err = adapter.SetAdapterState(driver.AdapterStateUp)
- if err != nil {
- serviceError = services.ErrorDeviceBringUp
- return
- }
- watcher.Configure(nil, nil, adapter, config, luid)
+ log.Println("Setting interface configuration")
+ err = adapter.SetConfiguration(config.ToDriverConfiguration())
+ if err != nil {
+ serviceError = services.ErrorDeviceSetConfig
+ return
}
+ err = adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ serviceError = services.ErrorDeviceBringUp
+ return
+ }
+ watcher.Configure(adapter, config, luid)
err = runScriptCommand(config.Interface.PostUp, config.Name)
if err != nil {
@@ -297,12 +217,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
log.Println("Startup complete")
- var devWaitChan chan struct{}
- if dev != nil {
- devWaitChan = dev.Wait()
- } else {
- devWaitChan = make(chan struct{})
- }
for {
select {
case c := <-r:
@@ -314,8 +228,6 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
default:
log.Printf("Unexpected service control request #%d\n", c)
}
- case <-devWaitChan:
- return
case e := <-watcher.errors:
serviceError, err = e.serviceError, e.err
return
diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go
index 7689a0c1..5d3bc276 100644
--- a/tunnel/winipcfg/winipcfg_test.go
+++ b/tunnel/winipcfg/winipcfg_test.go
@@ -8,8 +8,8 @@
Some tests in this file require:
- A dedicated network adapter
- Any network adapter will do. It may be virtual (Wintun etc.). The adapter name
- must contain string "winipcfg_test".
+ Any network adapter will do. It may be virtual (WireGuardNT, Wintun,
+ etc.). The adapter name must contain string "winipcfg_test".
Tests will add, remove, flush DNS servers, change adapter IP address, manipulate
routes etc.
The adapter will not be returned to previous state, so use an expendable one.
diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go
deleted file mode 100644
index 4e56ff65..00000000
--- a/tunnel/wintun_test.go
+++ /dev/null
@@ -1,202 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel_test
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "testing"
- "time"
-
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun"
-
- "golang.zx2c4.com/wireguard/windows/elevate"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-func TestWintunOrdering(t *testing.T) {
- var tunDevice tun.Device
- err := elevate.DoAsSystem(func() error {
- var err error
- tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500)
- return err
- })
- if err != nil {
- t.Fatal(err)
- }
- defer tunDevice.Close()
- nativeTunDevice := tunDevice.(*tun.NativeTun)
- luid := winipcfg.LUID(nativeTunDevice.LUID())
- ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24")
- err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}})
- if err != nil {
- t.Fatal(err)
- }
- err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}})
- if err != nil {
- t.Fatal(err)
- }
- var token [32]byte
- _, err = rand.Read(token[:])
- if err != nil {
- t.Fatal(err)
- }
- var sockWrite net.Conn
- for i := 0; i < 1000; i++ {
- sockWrite, err = net.Dial("udp", "10.82.31.5:9999")
- if err == nil {
- defer sockWrite.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var sockRead *net.UDPConn
- for i := 0; i < 1000; i++ {
- var listenAddress *net.UDPAddr
- listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999")
- if err != nil {
- continue
- }
- sockRead, err = net.ListenUDP("udp", listenAddress)
- if err == nil {
- defer sockRead.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var wait sync.WaitGroup
- wait.Add(4)
- doneSockWrite := false
- doneTunWrite := false
- fatalErrors := make(chan error, 2)
- errors := make(chan error, 2)
- go func() {
- defer wait.Done()
- buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0)
- for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ {
- binary.LittleEndian.PutUint64(buffer[32:], sendingIndex)
- _, err := sockWrite.Write(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- }
- }()
- go func() {
- defer wait.Done()
- packet := [20 + 8 + 32 + 8]byte{
- 0x45, 0, 0, 20 + 8 + 32 + 8,
- 0, 0, 0, 0,
- 0x80, 0x11, 0, 0,
- 10, 82, 31, 5,
- 10, 82, 31, 4,
- 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0,
- }
- copy(packet[28:], token[:])
- for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ {
- binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex))
- var checksum uint32
- for i := 0; i < 20; i += 2 {
- if i != 10 {
- checksum += uint32(binary.BigEndian.Uint16(packet[i:]))
- }
- }
- binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff)))
- binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex)
- n, err := tunDevice.Write(packet[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if n == 0 {
- time.Sleep(time.Millisecond * 300)
- }
- }
- }()
- const packetsPerTest = 1 << 21
- go func() {
- defer func() {
- doneSockWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := tunDevice.Read(buffer[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- tokenPos := bytes.Index(packet, token[:])
- if tokenPos == -1 || tokenPos+32+8 > len(packet) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- go func() {
- defer func() {
- doneTunWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := sockRead.Read(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- done := make(chan bool, 2)
- doneFunc := func() {
- wait.Wait()
- done <- true
- }
- defer doneFunc()
- go doneFunc()
- for {
- select {
- case err := <-fatalErrors:
- t.Fatal(err)
- case err := <-errors:
- t.Error(err)
- case <-done:
- return
- }
- }
-}