From c8fc9df2d766e9a27c1d027c080697cb65d10590 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 20 May 2019 14:01:38 +0200 Subject: service: move route monitor and account for changing index --- service/defaultroutemonitor.go | 150 +++++++++++++++++++++++++++++++++++++++++ service/ifaceconfig.go | 133 ------------------------------------ 2 files changed, 150 insertions(+), 133 deletions(-) create mode 100644 service/defaultroutemonitor.go diff --git a/service/defaultroutemonitor.go b/service/defaultroutemonitor.go new file mode 100644 index 00000000..d4105447 --- /dev/null +++ b/service/defaultroutemonitor.go @@ -0,0 +1,150 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package service + +import ( + "log" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/winipcfg" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" +) + +func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID uint64, lastLUID *uint64, lastIndex *uint32) error { + routes, err := winipcfg.GetRoutes(family) + if err != nil { + return err + } + lowestMetric := ^uint32(0) + index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. + luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. + for _, route := range routes { + if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLUID == ourLUID { + continue + } + ifrow, err := winipcfg.GetIfRow(route.InterfaceLUID) + if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { + log.Printf("Found default route for interface %d, but not up, so skipping", route.InterfaceIndex) + continue + } + if route.Metric < lowestMetric { + lowestMetric = route.Metric + index = route.InterfaceIndex + luid = route.InterfaceLUID + } + } + if luid == *lastLUID && index == *lastIndex { + return nil + } + *lastLUID = luid + *lastIndex = index + if family == windows.AF_INET { + log.Printf("Binding UDPv4 socket to interface %d", index) + return device.BindSocketToInterface4(index) + } else if family == windows.AF_INET6 { + log.Printf("Binding UDPv6 socket to interface %d", index) + return device.BindSocketToInterface6(index) + } + return nil +} + +func getIPInterfaceRetry(luid uint64, family winipcfg.AddressFamily, retry bool) (ipi *winipcfg.IPInterface, err error) { + const maxRetries = 100 + for i := 0; i < maxRetries; i++ { + ipi, err = winipcfg.GetIPInterface(luid, family) + if retry && i != maxRetries-1 && err == windows.ERROR_NOT_FOUND { + time.Sleep(time.Millisecond * 50) + continue + } + break + } + return +} + +func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { + ourLUID := tun.LUID() + lastLUID4 := uint64(0) + lastLUID6 := uint64(0) + lastIndex4 := uint32(0) + lastIndex6 := uint32(0) + lastMTU := uint32(0) + doIt := func(retry bool) error { + err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4, &lastIndex4) + if err != nil { + return err + } + err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6, &lastIndex6) + if err != nil { + return err + } + if !autoMTU { + return nil + } + mtu := uint32(0) + if lastLUID4 != 0 { + iface, err := winipcfg.InterfaceFromLUID(lastLUID4) + if err != nil { + return err + } + if iface.MTU > 0 { + mtu = iface.MTU + } + } + if lastLUID6 != 0 { + iface, err := winipcfg.InterfaceFromLUID(lastLUID6) + if err != nil { + return err + } + if iface.MTU > 0 && iface.MTU < mtu { + mtu = iface.MTU + } + } + if mtu > 0 && (lastMTU == 0 || lastMTU != mtu) { + iface, err := getIPInterfaceRetry(ourLUID, windows.AF_INET, retry) + if err != nil { + return err + } + iface.NLMTU = mtu - 80 + if iface.NLMTU < 576 { + iface.NLMTU = 576 + } + err = iface.Set() + if err != nil { + return err + } + tun.ForceMTU(int(iface.NLMTU)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. + iface, err = getIPInterfaceRetry(ourLUID, windows.AF_INET6, retry) + if err != nil { + return err + } + iface.NLMTU = mtu - 80 + if iface.NLMTU < 1280 { + iface.NLMTU = 1280 + } + err = iface.Set() + if err != nil { + return err + } + lastMTU = mtu + } + return nil + } + err := doIt(true) + if err != nil { + return nil, err + } + cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) { + if route.DestinationPrefix.PrefixLength == 0 { + _ = doIt(false) + } + }) + if err != nil { + return nil, err + } + return cb, nil +} diff --git a/service/ifaceconfig.go b/service/ifaceconfig.go index ea28ecdc..7d8b2f76 100644 --- a/service/ifaceconfig.go +++ b/service/ifaceconfig.go @@ -10,148 +10,15 @@ import ( "log" "net" "sort" - "time" "golang.org/x/sys/windows" "golang.zx2c4.com/winipcfg" - "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/service/firewall" ) -func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID uint64, lastLUID *uint64) error { - routes, err := winipcfg.GetRoutes(family) - if err != nil { - return err - } - lowestMetric := ^uint32(0) - index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. - luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. - for _, route := range routes { - if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLUID == ourLUID { - continue - } - ifrow, err := winipcfg.GetIfRow(route.InterfaceLUID) - if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp { - log.Printf("Found default route for interface %d, but not up, so skipping", route.InterfaceIndex) - continue - } - if route.Metric < lowestMetric { - lowestMetric = route.Metric - index = route.InterfaceIndex - luid = route.InterfaceLUID - } - } - if luid == *lastLUID { - return nil - } - *lastLUID = luid - if family == windows.AF_INET { - log.Printf("Binding UDPv4 socket to interface %d", index) - return device.BindSocketToInterface4(index) - } else if family == windows.AF_INET6 { - log.Printf("Binding UDPv6 socket to interface %d", index) - return device.BindSocketToInterface6(index) - } - return nil -} - -func getIPInterfaceRetry(luid uint64, family winipcfg.AddressFamily, retry bool) (ipi *winipcfg.IPInterface, err error) { - const maxRetries = 100 - for i := 0; i < maxRetries; i++ { - ipi, err = winipcfg.GetIPInterface(luid, family) - if retry && i != maxRetries-1 && err == windows.ERROR_NOT_FOUND { - time.Sleep(time.Millisecond * 50) - continue - } - break - } - return -} - -func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { - ourLUID := tun.LUID() - lastLUID4 := uint64(0) - lastLUID6 := uint64(0) - lastMTU := uint32(0) - doIt := func(retry bool) error { - err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4) - if err != nil { - return err - } - err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6) - if err != nil { - return err - } - if !autoMTU { - return nil - } - mtu := uint32(0) - if lastLUID4 != 0 { - iface, err := winipcfg.InterfaceFromLUID(lastLUID4) - if err != nil { - return err - } - if iface.MTU > 0 { - mtu = iface.MTU - } - } - if lastLUID6 != 0 { - iface, err := winipcfg.InterfaceFromLUID(lastLUID6) - if err != nil { - return err - } - if iface.MTU > 0 && iface.MTU < mtu { - mtu = iface.MTU - } - } - if mtu > 0 && (lastMTU == 0 || lastMTU != mtu) { - iface, err := getIPInterfaceRetry(ourLUID, windows.AF_INET, retry) - if err != nil { - return err - } - iface.NLMTU = mtu - 80 - if iface.NLMTU < 576 { - iface.NLMTU = 576 - } - err = iface.Set() - if err != nil { - return err - } - tun.ForceMTU(int(iface.NLMTU)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. - iface, err = getIPInterfaceRetry(ourLUID, windows.AF_INET6, retry) - if err != nil { - return err - } - iface.NLMTU = mtu - 80 - if iface.NLMTU < 1280 { - iface.NLMTU = 1280 - } - err = iface.Set() - if err != nil { - return err - } - lastMTU = mtu - } - return nil - } - err := doIt(true) - if err != nil { - return nil, err - } - cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) { - if route.DestinationPrefix.PrefixLength == 0 { - _ = doIt(false) - } - }) - if err != nil { - return nil, err - } - return cb, nil -} - func cleanupAddressesOnDisconnectedInterfaces(addresses []*net.IPNet) { if len(addresses) == 0 { return -- cgit v1.2.3-59-g8ed1b