diff options
Diffstat (limited to 'tunnel/defaultroutemonitor.go')
-rw-r--r-- | tunnel/defaultroutemonitor.go | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index 1ffce5fa..8dd3273c 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -10,32 +10,33 @@ import ( "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/winipcfg" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID uint64, lastLUID *uint64, lastIndex *uint32) error { - routes, err := winipcfg.GetRoutes(family) +func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error { + r, err := winipcfg.GetIPForwardTable2(family) if err != nil { return err } lowestMetric := ^uint32(0) - index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. - luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. - for _, route := range routes { - if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLUID == ourLUID { + index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. + luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. + for i := range r { + if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID { continue } - ifrow, err := winipcfg.GetIfRow(route.InterfaceLUID) + ifrow, err := r[i].InterfaceLUID.Interface() if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { - log.Printf("Found default route for interface %d, but not up, so skipping", route.InterfaceIndex) + log.Printf("Found default route for interface %d, but not up, so skipping", r[i].InterfaceIndex) continue } - if route.Metric < lowestMetric { - lowestMetric = route.Metric - index = route.InterfaceIndex - luid = route.InterfaceLUID + if r[i].Metric < lowestMetric { + lowestMetric = r[i].Metric + index = r[i].InterfaceIndex + luid = r[i].InterfaceLUID } } if luid == *lastLUID && index == *lastIndex { @@ -53,10 +54,10 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU return nil } -func getIPInterfaceRetry(luid uint64, family winipcfg.AddressFamily, retry bool) (ipi *winipcfg.IPInterface, err error) { +func getIPInterfaceRetry(luid winipcfg.LUID, family winipcfg.AddressFamily, retry bool) (ipi *winipcfg.MibIPInterfaceRow, err error) { const maxRetries = 100 for i := 0; i < maxRetries; i++ { - ipi, err = winipcfg.GetIPInterface(luid, family) + ipi, err = luid.IPInterface(family) if retry && i != maxRetries-1 && err == windows.ERROR_NOT_FOUND { time.Sleep(time.Millisecond * 50) continue @@ -67,9 +68,9 @@ func getIPInterfaceRetry(luid uint64, family winipcfg.AddressFamily, retry bool) } func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { - ourLUID := tun.LUID() - lastLUID4 := uint64(0) - lastLUID6 := uint64(0) + ourLUID := winipcfg.LUID(tun.LUID()) + lastLUID4 := winipcfg.LUID(0) + lastLUID6 := winipcfg.LUID(0) lastIndex4 := uint32(0) lastIndex6 := uint32(0) lastMTU := uint32(0) @@ -87,7 +88,7 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu } mtu := uint32(0) if lastLUID4 != 0 { - iface, err := winipcfg.InterfaceFromLUID(lastLUID4) + iface, err := lastLUID4.Interface() if err != nil { return err } @@ -96,7 +97,7 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu } } if lastLUID6 != 0 { - iface, err := winipcfg.InterfaceFromLUID(lastLUID6) + iface, err := lastLUID6.Interface() if err != nil { return err } @@ -138,8 +139,8 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu if err != nil { return nil, err } - cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) { - if route.DestinationPrefix.PrefixLength == 0 { + cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { + if route != nil && route.DestinationPrefix.PrefixLength == 0 { _ = doIt(false) } }) |