diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-09-28 19:20:49 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-10-01 13:59:42 +0200 |
commit | bba001018f0a8f2fb1c88a6b97adb23690a6512b (patch) | |
tree | 9e9d8ad1787e72f9398fdf3b01f003d743178d2e /tunnel | |
parent | elevate: use fallback shellexecute when not EV-signed (diff) | |
download | wireguard-windows-bba001018f0a8f2fb1c88a6b97adb23690a6512b.tar.xz wireguard-windows-bba001018f0a8f2fb1c88a6b97adb23690a6512b.zip |
tunnel: windows does not always add/remove routes with up/down interface
On Linux, we're used to routes being added after an interface is up, and
routes being removed as a consequence of an interface going down. On
Windows, this isn't always the case, at least not from the perspective
of the route notifiers. In order to work around this and make a
multi-interface model coherent, we search for a new default route not
only whenever the routing table changes but also whenever any interface
link parameters change, such as up/down.
The practical consequence is that now WireGuard connects properly when
wifi is disconnected and then reconnected.
Reported-by: Nenad Kozul <me@nenadkozul.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tunnel')
-rw-r--r-- | tunnel/defaultroutemonitor.go | 22 | ||||
-rw-r--r-- | tunnel/interfacewatcher.go | 44 | ||||
-rw-r--r-- | tunnel/winipcfg/types.go | 4 |
3 files changed, 45 insertions, 25 deletions
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index c1722c45..f9c63e56 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -7,6 +7,7 @@ package tunnel import ( "log" + "sync" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" @@ -29,7 +30,6 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU } ifrow, err := r[i].InterfaceLUID.Interface() if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { - log.Printf("Found default route for interface %d, but not up, so skipping", r[i].InterfaceIndex) continue } if r[i].Metric < lowestMetric { @@ -53,7 +53,7 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU return nil } -func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { +func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) { var minMTU uint32 if family == windows.AF_INET { minMTU = 576 @@ -64,7 +64,10 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, lastLUID := winipcfg.LUID(0) lastIndex := uint32(0) lastMTU := uint32(0) + mutex := sync.Mutex{} doIt := func() error { + mutex.Lock() + defer mutex.Unlock() err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex) if err != nil { return err @@ -104,13 +107,22 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, if err != nil { return nil, err } - cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { if route != nil && route.DestinationPrefix.PrefixLength == 0 { - _ = doIt() + doIt() } }) if err != nil { return nil, err } - return cb, nil + cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + if notificationType == winipcfg.MibParameterNotification { + doIt() + } + }) + if err != nil { + cbr.Unregister() + return nil, err + } + return []winipcfg.ChangeCallback{cbr, cbi}, nil } diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index d74db0e9..92d08e90 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -35,32 +35,34 @@ type interfaceWatcher struct { tun *tun.NativeTun setupMutex sync.Mutex - routeChangeCallback4 *winipcfg.RouteChangeCallback - routeChangeCallback6 *winipcfg.RouteChangeCallback - interfaceChangeCallback *winipcfg.InterfaceChangeCallback + interfaceChangeCallback winipcfg.ChangeCallback + changeCallbacks4 []winipcfg.ChangeCallback + changeCallbacks6 []winipcfg.ChangeCallback storedEvents []interfaceWatcherEvent } func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { - var routeChangeCallback **winipcfg.RouteChangeCallback + var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string if family == windows.AF_INET { - routeChangeCallback = &iw.routeChangeCallback4 + changeCallbacks = &iw.changeCallbacks4 ipversion = "v4" } else if family == windows.AF_INET6 { - routeChangeCallback = &iw.routeChangeCallback6 + changeCallbacks = &iw.changeCallbacks6 ipversion = "v6" } else { return } - if *routeChangeCallback != nil { - (*routeChangeCallback).Unregister() - *routeChangeCallback = nil + if len(*changeCallbacks) != 0 { + for _, cb := range *changeCallbacks { + cb.Unregister() + } + *changeCallbacks = nil } var err error log.Printf("Monitoring default %s routes", ipversion) - *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) + *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} return @@ -116,8 +118,8 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, func (iw *interfaceWatcher) Destroy() { iw.setupMutex.Lock() - routeChangeCallback4 := iw.routeChangeCallback4 - routeChangeCallback6 := iw.routeChangeCallback6 + changeCallbacks4 := iw.changeCallbacks4 + changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback tun := iw.tun iw.setupMutex.Unlock() @@ -125,22 +127,24 @@ func (iw *interfaceWatcher) Destroy() { if interfaceChangeCallback != nil { interfaceChangeCallback.Unregister() } - if routeChangeCallback4 != nil { - routeChangeCallback4.Unregister() + for _, cb := range changeCallbacks4 { + cb.Unregister() } - if routeChangeCallback6 != nil { - routeChangeCallback6.Unregister() + for _, cb := range changeCallbacks6 { + cb.Unregister() } iw.setupMutex.Lock() if interfaceChangeCallback == iw.interfaceChangeCallback { iw.interfaceChangeCallback = nil } - if routeChangeCallback4 == iw.routeChangeCallback4 { - iw.routeChangeCallback4 = nil + for len(changeCallbacks4) > 0 && len(iw.changeCallbacks4) > 0 { + iw.changeCallbacks4 = iw.changeCallbacks4[1:] + changeCallbacks4 = changeCallbacks4[1:] } - if routeChangeCallback6 == iw.routeChangeCallback6 { - iw.routeChangeCallback6 = nil + for len(changeCallbacks6) > 0 && len(iw.changeCallbacks6) > 0 { + iw.changeCallbacks6 = iw.changeCallbacks6[1:] + changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() if tun != nil && iw.tun == tun { diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 684a6c77..81f9335d 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -510,6 +510,10 @@ const ( MibInitialNotification // Initial notification ) +type ChangeCallback interface { + Unregister() error +} + // TunnelType enumeration type defines the encapsulation method used by a tunnel, as described by the Internet Assigned Names Authority (IANA). // https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-tunnel_type type TunnelType uint32 |