aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel/interfacewatcher.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/interfacewatcher.go')
-rw-r--r--tunnel/interfacewatcher.go119
1 files changed, 55 insertions, 64 deletions
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index 1f632725..a831d06e 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -1,20 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
+ "errors"
+ "fmt"
"log"
"sync"
+ "time"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -24,63 +24,30 @@ type interfaceWatcherError struct {
serviceError services.Error
err error
}
+
type interfaceWatcherEvent struct {
luid winipcfg.LUID
family winipcfg.AddressFamily
}
+
type interfaceWatcher struct {
- errors chan interfaceWatcherError
+ errors chan interfaceWatcherError
+ started chan winipcfg.AddressFamily
- device *device.Device
- conf *conf.Config
- tun *tun.NativeTun
+ conf *conf.Config
+ adapter *driver.Adapter
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
changeCallbacks4 []winipcfg.ChangeCallback
changeCallbacks6 []winipcfg.ChangeCallback
storedEvents []interfaceWatcherEvent
-}
-
-func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool {
- var (
- foundV401 bool
- foundV41281 bool
- foundV600001 bool
- foundV680001 bool
- foundV400 bool
- foundV600 bool
- v40 = [4]byte{}
- v60 = [16]byte{}
- v48 = [4]byte{0x80}
- v68 = [16]byte{0x80}
- )
- for _, peer := range peers {
- for _, allowedip := range peer.AllowedIPs {
- if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) {
- foundV680001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV401 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) {
- foundV41281 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV400 = true
- }
- }
- }
- if family == windows.AF_INET {
- return foundV400 || (foundV401 && foundV41281)
- } else if family == windows.AF_INET6 {
- return foundV600 || (foundV600001 && foundV680001)
- }
- return false
+ watchdog *time.Timer
}
func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
+ iw.watchdog.Stop()
var changeCallbacks *[]winipcfg.ChangeCallback
var ipversion string
if family == windows.AF_INET {
@@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
}
+ evaluateDynamicPitfalls(family, iw.conf, iw.luid)
+
+ iw.started <- family
}
func watchInterface() (*interfaceWatcher, error) {
iw := &interfaceWatcher{
- errors: make(chan interfaceWatcherError, 2),
+ errors: make(chan interfaceWatcherError, 2),
+ started: make(chan winipcfg.AddressFamily, 4),
}
+ iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() {
+ iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")}
+ })
+ iw.watchdog.Stop()
var err error
iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
iw.setupMutex.Lock()
@@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
+
+ if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
+ log.Println("Reinitializing adapter configuration")
+ err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
+ }
+ err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
+ }
+ }
})
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("unable to register interface change callback: %w", err)
}
return iw, nil
}
-func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
+ iw.watchdog.Reset(time.Minute)
- iw.device, iw.conf, iw.tun = device, conf, tun
+ iw.adapter, iw.conf, iw.luid = adapter, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config,
func (iw *interfaceWatcher) Destroy() {
iw.setupMutex.Lock()
+ iw.watchdog.Stop()
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,15 +177,15 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
+ luid.FlushDNS(windows.AF_INET)
luid.FlushRoutes(windows.AF_INET6)
luid.FlushIPAddresses(windows.AF_INET6)
- luid.FlushDNS()
+ luid.FlushDNS(windows.AF_INET6)
}
iw.setupMutex.Unlock()
}