aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--tunnel/defaultroutemonitor.go22
-rw-r--r--tunnel/interfacewatcher.go44
-rw-r--r--tunnel/winipcfg/types.go4
3 files changed, 45 insertions, 25 deletions
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
index c1722c45..f9c63e56 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/defaultroutemonitor.go
@@ -7,6 +7,7 @@ package tunnel
import (
"log"
+ "sync"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device"
@@ -29,7 +30,6 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU
}
ifrow, err := r[i].InterfaceLUID.Interface()
if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
- log.Printf("Found default route for interface %d, but not up, so skipping", r[i].InterfaceIndex)
continue
}
if r[i].Metric < lowestMetric {
@@ -53,7 +53,7 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU
return nil
}
-func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
+func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) {
var minMTU uint32
if family == windows.AF_INET {
minMTU = 576
@@ -64,7 +64,10 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device,
lastLUID := winipcfg.LUID(0)
lastIndex := uint32(0)
lastMTU := uint32(0)
+ mutex := sync.Mutex{}
doIt := func() error {
+ mutex.Lock()
+ defer mutex.Unlock()
err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex)
if err != nil {
return err
@@ -104,13 +107,22 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device,
if err != nil {
return nil, err
}
- cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
+ cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
if route != nil && route.DestinationPrefix.PrefixLength == 0 {
- _ = doIt()
+ doIt()
}
})
if err != nil {
return nil, err
}
- return cb, nil
+ cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ if notificationType == winipcfg.MibParameterNotification {
+ doIt()
+ }
+ })
+ if err != nil {
+ cbr.Unregister()
+ return nil, err
+ }
+ return []winipcfg.ChangeCallback{cbr, cbi}, nil
}
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index d74db0e9..92d08e90 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -35,32 +35,34 @@ type interfaceWatcher struct {
tun *tun.NativeTun
setupMutex sync.Mutex
- routeChangeCallback4 *winipcfg.RouteChangeCallback
- routeChangeCallback6 *winipcfg.RouteChangeCallback
- interfaceChangeCallback *winipcfg.InterfaceChangeCallback
+ interfaceChangeCallback winipcfg.ChangeCallback
+ changeCallbacks4 []winipcfg.ChangeCallback
+ changeCallbacks6 []winipcfg.ChangeCallback
storedEvents []interfaceWatcherEvent
}
func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
- var routeChangeCallback **winipcfg.RouteChangeCallback
+ var changeCallbacks *[]winipcfg.ChangeCallback
var ipversion string
if family == windows.AF_INET {
- routeChangeCallback = &iw.routeChangeCallback4
+ changeCallbacks = &iw.changeCallbacks4
ipversion = "v4"
} else if family == windows.AF_INET6 {
- routeChangeCallback = &iw.routeChangeCallback6
+ changeCallbacks = &iw.changeCallbacks6
ipversion = "v6"
} else {
return
}
- if *routeChangeCallback != nil {
- (*routeChangeCallback).Unregister()
- *routeChangeCallback = nil
+ if len(*changeCallbacks) != 0 {
+ for _, cb := range *changeCallbacks {
+ cb.Unregister()
+ }
+ *changeCallbacks = nil
}
var err error
log.Printf("Monitoring default %s routes", ipversion)
- *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun)
+ *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
return
@@ -116,8 +118,8 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config,
func (iw *interfaceWatcher) Destroy() {
iw.setupMutex.Lock()
- routeChangeCallback4 := iw.routeChangeCallback4
- routeChangeCallback6 := iw.routeChangeCallback6
+ changeCallbacks4 := iw.changeCallbacks4
+ changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
tun := iw.tun
iw.setupMutex.Unlock()
@@ -125,22 +127,24 @@ func (iw *interfaceWatcher) Destroy() {
if interfaceChangeCallback != nil {
interfaceChangeCallback.Unregister()
}
- if routeChangeCallback4 != nil {
- routeChangeCallback4.Unregister()
+ for _, cb := range changeCallbacks4 {
+ cb.Unregister()
}
- if routeChangeCallback6 != nil {
- routeChangeCallback6.Unregister()
+ for _, cb := range changeCallbacks6 {
+ cb.Unregister()
}
iw.setupMutex.Lock()
if interfaceChangeCallback == iw.interfaceChangeCallback {
iw.interfaceChangeCallback = nil
}
- if routeChangeCallback4 == iw.routeChangeCallback4 {
- iw.routeChangeCallback4 = nil
+ for len(changeCallbacks4) > 0 && len(iw.changeCallbacks4) > 0 {
+ iw.changeCallbacks4 = iw.changeCallbacks4[1:]
+ changeCallbacks4 = changeCallbacks4[1:]
}
- if routeChangeCallback6 == iw.routeChangeCallback6 {
- iw.routeChangeCallback6 = nil
+ for len(changeCallbacks6) > 0 && len(iw.changeCallbacks6) > 0 {
+ iw.changeCallbacks6 = iw.changeCallbacks6[1:]
+ changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
if tun != nil && iw.tun == tun {
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 684a6c77..81f9335d 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -510,6 +510,10 @@ const (
MibInitialNotification // Initial notification
)
+type ChangeCallback interface {
+ Unregister() error
+}
+
// TunnelType enumeration type defines the encapsulation method used by a tunnel, as described by the Internet Assigned Names Authority (IANA).
// https://docs.microsoft.com/en-us/windows/desktop/api/ifdef/ne-ifdef-tunnel_type
type TunnelType uint32