aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel/defaultroutemonitor.go
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-05-20 14:18:01 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-05-20 14:18:01 +0200
commite493f911269a2dabab7b05ec28726cdaeffb660e (patch)
treedb88ec568dfc508da863e67164de909448c66742 /tunnel/defaultroutemonitor.go
parentservice: move route monitor and account for changing index (diff)
downloadwireguard-windows-e493f911269a2dabab7b05ec28726cdaeffb660e.tar.xz
wireguard-windows-e493f911269a2dabab7b05ec28726cdaeffb660e.zip
service: split into tunnel and manager
Diffstat (limited to 'tunnel/defaultroutemonitor.go')
-rw-r--r--tunnel/defaultroutemonitor.go150
1 files changed, 150 insertions, 0 deletions
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
new file mode 100644
index 00000000..1ffce5fa
--- /dev/null
+++ b/tunnel/defaultroutemonitor.go
@@ -0,0 +1,150 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "log"
+ "time"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/winipcfg"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID uint64, lastLUID *uint64, lastIndex *uint32) error {
+ routes, err := winipcfg.GetRoutes(family)
+ if err != nil {
+ return err
+ }
+ lowestMetric := ^uint32(0)
+ index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
+ luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
+ for _, route := range routes {
+ if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLUID == ourLUID {
+ continue
+ }
+ ifrow, err := winipcfg.GetIfRow(route.InterfaceLUID)
+ if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
+ log.Printf("Found default route for interface %d, but not up, so skipping", route.InterfaceIndex)
+ continue
+ }
+ if route.Metric < lowestMetric {
+ lowestMetric = route.Metric
+ index = route.InterfaceIndex
+ luid = route.InterfaceLUID
+ }
+ }
+ if luid == *lastLUID && index == *lastIndex {
+ return nil
+ }
+ *lastLUID = luid
+ *lastIndex = index
+ if family == windows.AF_INET {
+ log.Printf("Binding UDPv4 socket to interface %d", index)
+ return device.BindSocketToInterface4(index)
+ } else if family == windows.AF_INET6 {
+ log.Printf("Binding UDPv6 socket to interface %d", index)
+ return device.BindSocketToInterface6(index)
+ }
+ return nil
+}
+
+func getIPInterfaceRetry(luid uint64, family winipcfg.AddressFamily, retry bool) (ipi *winipcfg.IPInterface, err error) {
+ const maxRetries = 100
+ for i := 0; i < maxRetries; i++ {
+ ipi, err = winipcfg.GetIPInterface(luid, family)
+ if retry && i != maxRetries-1 && err == windows.ERROR_NOT_FOUND {
+ time.Sleep(time.Millisecond * 50)
+ continue
+ }
+ break
+ }
+ return
+}
+
+func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
+ ourLUID := tun.LUID()
+ lastLUID4 := uint64(0)
+ lastLUID6 := uint64(0)
+ lastIndex4 := uint32(0)
+ lastIndex6 := uint32(0)
+ lastMTU := uint32(0)
+ doIt := func(retry bool) error {
+ err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4, &lastIndex4)
+ if err != nil {
+ return err
+ }
+ err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6, &lastIndex6)
+ if err != nil {
+ return err
+ }
+ if !autoMTU {
+ return nil
+ }
+ mtu := uint32(0)
+ if lastLUID4 != 0 {
+ iface, err := winipcfg.InterfaceFromLUID(lastLUID4)
+ if err != nil {
+ return err
+ }
+ if iface.MTU > 0 {
+ mtu = iface.MTU
+ }
+ }
+ if lastLUID6 != 0 {
+ iface, err := winipcfg.InterfaceFromLUID(lastLUID6)
+ if err != nil {
+ return err
+ }
+ if iface.MTU > 0 && iface.MTU < mtu {
+ mtu = iface.MTU
+ }
+ }
+ if mtu > 0 && (lastMTU == 0 || lastMTU != mtu) {
+ iface, err := getIPInterfaceRetry(ourLUID, windows.AF_INET, retry)
+ if err != nil {
+ return err
+ }
+ iface.NLMTU = mtu - 80
+ if iface.NLMTU < 576 {
+ iface.NLMTU = 576
+ }
+ err = iface.Set()
+ if err != nil {
+ return err
+ }
+ tun.ForceMTU(int(iface.NLMTU)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
+ iface, err = getIPInterfaceRetry(ourLUID, windows.AF_INET6, retry)
+ if err != nil {
+ return err
+ }
+ iface.NLMTU = mtu - 80
+ if iface.NLMTU < 1280 {
+ iface.NLMTU = 1280
+ }
+ err = iface.Set()
+ if err != nil {
+ return err
+ }
+ lastMTU = mtu
+ }
+ return nil
+ }
+ err := doIt(true)
+ if err != nil {
+ return nil, err
+ }
+ cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) {
+ if route.DestinationPrefix.PrefixLength == 0 {
+ _ = doIt(false)
+ }
+ })
+ if err != nil {
+ return nil, err
+ }
+ return cb, nil
+}