diff options
Diffstat (limited to 'tunnel/interfacewatcher.go')
-rw-r--r-- | tunnel/interfacewatcher.go | 119 |
1 files changed, 55 insertions, 64 deletions
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index 1f632725..a831d06e 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -1,20 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( + "errors" + "fmt" "log" "sync" + "time" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -24,63 +24,30 @@ type interfaceWatcherError struct { serviceError services.Error err error } + type interfaceWatcherEvent struct { luid winipcfg.LUID family winipcfg.AddressFamily } + type interfaceWatcher struct { - errors chan interfaceWatcherError + errors chan interfaceWatcherError + started chan winipcfg.AddressFamily - device *device.Device - conf *conf.Config - tun *tun.NativeTun + conf *conf.Config + adapter *driver.Adapter + luid winipcfg.LUID setupMutex sync.Mutex interfaceChangeCallback winipcfg.ChangeCallback changeCallbacks4 []winipcfg.ChangeCallback changeCallbacks6 []winipcfg.ChangeCallback storedEvents []interfaceWatcherEvent -} - -func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool { - var ( - foundV401 bool - foundV41281 bool - foundV600001 bool - foundV680001 bool - foundV400 bool - foundV600 bool - v40 = [4]byte{} - v60 = [16]byte{} - v48 = [4]byte{0x80} - v68 = [16]byte{0x80} - ) - for _, peer := range peers { - for _, allowedip := range peer.AllowedIPs { - if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { - foundV680001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV401 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { - foundV41281 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV400 = true - } - } - } - if family == windows.AF_INET { - return foundV400 || (foundV401 && foundV41281) - } else if family == windows.AF_INET6 { - return foundV600 || (foundV600001 && foundV680001) - } - return false + watchdog *time.Timer } func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { + iw.watchdog.Stop() var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string if family == windows.AF_INET { @@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return + if iw.conf.Interface.MTU == 0 { + log.Printf("Monitoring MTU of default %s routes", ipversion) + *changeCallbacks, err = monitorMTU(family, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err} + return + } } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.tun) + err = configureInterface(family, iw.conf, iw.luid) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return } + evaluateDynamicPitfalls(family, iw.conf, iw.luid) + + iw.started <- family } func watchInterface() (*interfaceWatcher, error) { iw := &interfaceWatcher{ - errors: make(chan interfaceWatcherError, 2), + errors: make(chan interfaceWatcherError, 2), + started: make(chan winipcfg.AddressFamily, 4), } + iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() { + iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")} + }) + iw.watchdog.Stop() var err error iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { iw.setupMutex.Lock() @@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) { if notificationType != winipcfg.MibAddInstance { return } - if iw.tun == nil { + if iw.luid == 0 { iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) return } - if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + if iface.InterfaceLUID != iw.luid { return } iw.setup(iface.Family) + + if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown { + log.Println("Reinitializing adapter configuration") + err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration()) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err)) + } + err = iw.adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err)) + } + } }) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to register interface change callback: %w", err) } return iw, nil } -func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) { +func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() + iw.watchdog.Reset(time.Minute) - iw.device, iw.conf, iw.tun = device, conf, tun + iw.adapter, iw.conf, iw.luid = adapter, conf, luid for _, event := range iw.storedEvents { - if event.luid == winipcfg.LUID(iw.tun.LUID()) { + if event.luid == luid { iw.setup(event.family) } } @@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, func (iw *interfaceWatcher) Destroy() { iw.setupMutex.Lock() + iw.watchdog.Stop() changeCallbacks4 := iw.changeCallbacks4 changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback - tun := iw.tun + luid := iw.luid iw.setupMutex.Unlock() if interfaceChangeCallback != nil { @@ -186,15 +177,15 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() - if tun != nil && iw.tun == tun { + if luid != 0 && iw.luid == luid { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) + luid.FlushDNS(windows.AF_INET) luid.FlushRoutes(windows.AF_INET6) luid.FlushIPAddresses(windows.AF_INET6) - luid.FlushDNS() + luid.FlushDNS(windows.AF_INET6) } iw.setupMutex.Unlock() } |