diff options
-rw-r--r-- | tunnel/defaultroutemonitor.go | 22 | ||||
-rw-r--r-- | tunnel/interfacewatcher.go | 44 | ||||
-rw-r--r-- | tunnel/winipcfg/types.go | 4 |
3 files changed, 45 insertions, 25 deletions
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index c1722c45..f9c63e56 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -7,6 +7,7 @@ package tunnel import ( "log" + "sync" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/device" @@ -29,7 +30,6 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU } ifrow, err := r[i].InterfaceLUID.Interface() if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { - log.Printf("Found default route for interface %d, but not up, so skipping", r[i].InterfaceIndex) continue } if r[i].Metric < lowestMetric { @@ -53,7 +53,7 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU return nil } -func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { +func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) { var minMTU uint32 if family == windows.AF_INET { minMTU = 576 @@ -64,7 +64,10 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, lastLUID := winipcfg.LUID(0) lastIndex := uint32(0) lastMTU := uint32(0) + mutex := sync.Mutex{} doIt := func() error { + mutex.Lock() + defer mutex.Unlock() err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex) if err != nil { return err @@ -104,13 +107,22 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, if err != nil { return nil, err } - cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { if route != nil && route.DestinationPrefix.PrefixLength == 0 { - _ = doIt() + doIt() } }) if err != nil { return nil, err } - return cb, nil + cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + if notificationType == winipcfg.MibParameterNotification { + doIt() + } + }) + if err != nil { + cbr.Unregister() + return nil, err + } + return []winipcfg.ChangeCallback{cbr, cbi}, nil } diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index d74db0e9..92d08e90 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -35,32 +35,34 @@ type interfaceWatcher struct { tun *tun.NativeTun setupMutex sync.Mutex - routeChangeCallback4 *winipcfg.RouteChangeCallback - routeChangeCallback6 *winipcfg.RouteChangeCallback - interfaceChangeCallback *winipcfg.InterfaceChangeCallback + interfaceChangeCallback winipcfg.ChangeCallback + changeCallbacks4 []winipcfg.ChangeCallback + changeCallbacks6 []winipcfg.ChangeCallback storedEvents []interfaceWatcherEvent } func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { - var routeChangeCallback **winipcfg.RouteChangeCallback + var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string if family == windows.AF_INET { - routeChangeCallback = &iw.routeChangeCallback4 + changeCallbacks = &iw.changeCallbacks4 ipversion = "v4" } else if family == windows.AF_INET6 { - routeChangeCallback = &iw.routeChangeCallback6 + changeCallbacks = &iw.changeCallbacks6 ipversion = "v6" } else { return } - if *routeChangeCallback != nil { - (*routeChangeCallback).Unregister() - *routeChangeCallback = nil + if len(*changeCallbacks) != 0 { + for _, cb := range *changeCallbacks { + cb.Unregister() + } + *changeCallbacks = nil } var err error log.Printf("Monitoring default %s routes", ipversion) - *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) + *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} return @@ -116,8 +118,8 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, func (iw *interfaceWatcher) Destroy() { iw.setupMutex.Lock() - routeChangeCallback4 := iw.routeChangeCallback4 - routeChangeCallback6 := iw.routeChangeCallback6 + changeCallbacks4 := iw.changeCallbacks4 + changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback tun := iw.tun iw.setupMutex.Unlock() @@ -125,22 +127,24 @@ func (iw *interfaceWatcher) Destroy() { if interfaceChangeCallback != nil { interfaceChangeCallback.Unregister() } - if routeChangeCallback4 != nil { - routeChangeCallback4.Unregister() + for _, cb := range changeCallbacks4 { + cb.Unregister() } - if routeChangeCallback6 != nil { - routeChangeCallback6.Unregister() + for _, cb := range changeCallbacks6 { + cb.Unregister() } iw.setupMutex.Lock() if interfaceChangeCallback == iw.interfaceChangeCallback { iw.interfaceChangeCallback = nil } - if routeChangeCallback4 == iw.routeChangeCallback4 { - iw.routeChangeCallback4 = nil + for len(changeCallbacks4) > 0 && len(iw.changeCallbacks4) > 0 { + iw.changeCallbacks4 = iw.changeCallbacks4[1:] + changeCallbacks4 = changeCallbacks4[1:] } - if routeChangeCallback6 == iw.routeChangeCallback6 { - iw.routeChangeCallback6 = nil + for len(changeCallbacks6) > 0 && len(iw.changeCallbacks6) > 0 { + iw.changeCallbacks6 = iw.changeCallbacks6[1:] + changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() if tun != nil && iw.tun == tun { diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 684a6c77..81f9335d 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -510,6 +510,10 @@ const ( MibInitialNotification // Initial notification ) +type ChangeCallback interface { + Unregister() error +} + // TunnelType enumeration type defines the encapsulation method used by a tunnel, as described by the Internet Assigned Names Authority (IANA). // https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-tunnel_type type TunnelType uint32 |