aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-06-17 13:08:13 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-06-18 12:08:51 +0200
commitb77c634b9eb0e1e732b60667a5974f8340a50253 (patch)
tree13a2d97bcb3ad8e44f288d3220fd9030d9eb968c
parentversion: bump (diff)
downloadwireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.tar.xz
wireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.zip
tunnel: wait for IP service to attach to wintun
This helps fix startup races without needing to poll, as well as reconfiguring interfaces after wintun destroys and re-adds. It also deals gracefully with IPv6 being disabled.
-rw-r--r--tunnel/addressconfig.go (renamed from tunnel/ifaceconfig.go)81
-rw-r--r--tunnel/defaultroutemonitor.go56
-rw-r--r--tunnel/interfacewatcher.go148
-rw-r--r--tunnel/service.go37
-rw-r--r--tunnel/winipcfg/luid.go67
-rw-r--r--tunnel/winipcfg/winipcfg.go4
6 files changed, 268 insertions, 125 deletions
diff --git a/tunnel/ifaceconfig.go b/tunnel/addressconfig.go
index a71b612e..a1e5dc59 100644
--- a/tunnel/ifaceconfig.go
+++ b/tunnel/addressconfig.go
@@ -10,7 +10,6 @@ import (
"log"
"net"
"sort"
- "time"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun"
@@ -20,7 +19,7 @@ import (
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) {
+func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
if len(addresses) == 0 {
return
}
@@ -39,7 +38,7 @@ func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) {
}
return false
}
- interfaces, err := winipcfg.GetAdaptersAddresses(windows.AF_UNSPEC, winipcfg.GAAFlagDefault)
+ interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
return
}
@@ -58,7 +57,7 @@ func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) {
}
}
-func configureInterface(conf *conf.Config, tun *tun.NativeTun) error {
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
luid := winipcfg.LUID(tun.LUID())
estimatedRouteCount := len(conf.Interface.Addresses)
@@ -114,10 +113,10 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error {
}
}
- err := luid.SetIPAddresses(addresses)
+ err := luid.SetIPAddressesForFamily(family, addresses)
if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
- cleanupAddressesOnDisconnectedInterfaces(addresses)
- err = luid.SetIPAddresses(addresses)
+ cleanupAddressesOnDisconnectedInterfaces(family, addresses)
+ err = luid.SetIPAddressesForFamily(family, addresses)
}
if err != nil {
return err
@@ -140,49 +139,38 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error {
deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
- err = luid.SetRoutes(deduplicatedRoutes)
+ err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
if err != nil {
return nil
}
- ipif, err := luid.IPInterface(windows.AF_INET)
+ ipif, err := luid.IPInterface(family)
if err != nil {
return err
}
- if foundDefault4 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
tun.ForceMTU(int(ipif.NLMTU))
}
- err = ipif.Set()
- if err != nil {
- return err
- }
-
- ipif, err = luid.IPInterface(windows.AF_INET6)
- if err != nil && firstGateway6 != nil {
- log.Printf("Is IPv6 disabled by Windows?")
- return err
- } else if err == nil { // People seem to like to disable IPv6, so we make this non-fatal.
- if foundDefault6 {
+ if family == windows.AF_INET {
+ if foundDefault4 {
ipif.UseAutomaticMetric = false
ipif.Metric = 0
}
- if conf.Interface.MTU > 0 {
- ipif.NLMTU = uint32(conf.Interface.MTU)
+ } else if family == windows.AF_INET6 {
+ if foundDefault6 {
+ ipif.UseAutomaticMetric = false
+ ipif.Metric = 0
}
ipif.DadTransmits = 0
ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
- err = ipif.Set()
- if err != nil {
- return err
- }
+ }
+ err = ipif.Set()
+ if err != nil {
+ return err
}
- err = luid.SetDNS(conf.Interface.DNS)
+ err = luid.SetDNSForFamily(family, conf.Interface.DNS)
if err != nil {
return err
}
@@ -190,19 +178,6 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error {
return nil
}
-func unconfigureInterface(tun *tun.NativeTun) {
- // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
- // routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
- luid.FlushRoutes(windows.AF_INET)
- luid.FlushIPAddresses(windows.AF_INET)
- luid.FlushRoutes(windows.AF_INET6)
- luid.FlushIPAddresses(windows.AF_INET6)
- luid.FlushDNS()
-
- firewall.DisableFirewall()
-}
-
func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
restrictAll := false
if len(conf.Peers) == 1 {
@@ -224,21 +199,3 @@ func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
}
return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll)
}
-
-func waitForFamilies(tun *tun.NativeTun) {
- // TODO: This whole thing is a disgusting hack that shouldn't be neccessary.
-
- f := func(luid winipcfg.LUID, family winipcfg.AddressFamily, maxRetries int) {
- for i := 0; i < maxRetries; i++ {
- _, err := luid.IPInterface(family)
- if i != maxRetries-1 && err == windows.ERROR_NOT_FOUND {
- time.Sleep(time.Millisecond * 50)
- continue
- }
- break
- }
- }
- luid := winipcfg.LUID(tun.LUID())
- f(luid, windows.AF_INET, 100)
- f(luid, windows.AF_INET6, 3)
-}
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go
index e9440710..c1722c45 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/defaultroutemonitor.go
@@ -44,28 +44,28 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU
*lastLUID = luid
*lastIndex = index
if family == windows.AF_INET {
- log.Printf("Binding UDPv4 socket to interface %d", index)
+ log.Printf("Binding v4 socket to interface %d", index)
return device.BindSocketToInterface4(index)
} else if family == windows.AF_INET6 {
- log.Printf("Binding UDPv6 socket to interface %d", index)
+ log.Printf("Binding v6 socket to interface %d", index)
return device.BindSocketToInterface6(index)
}
return nil
}
-func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
+func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) {
+ var minMTU uint32
+ if family == windows.AF_INET {
+ minMTU = 576
+ } else if family == windows.AF_INET6 {
+ minMTU = 1280
+ }
ourLUID := winipcfg.LUID(tun.LUID())
- lastLUID4 := winipcfg.LUID(0)
- lastLUID6 := winipcfg.LUID(0)
- lastIndex4 := uint32(0)
- lastIndex6 := uint32(0)
+ lastLUID := winipcfg.LUID(0)
+ lastIndex := uint32(0)
lastMTU := uint32(0)
doIt := func() error {
- err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4, &lastIndex4)
- if err != nil {
- return err
- }
- err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6, &lastIndex6)
+ err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex)
if err != nil {
return err
}
@@ -73,8 +73,8 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu
return nil
}
mtu := uint32(0)
- if lastLUID4 != 0 {
- iface, err := lastLUID4.Interface()
+ if lastLUID != 0 {
+ iface, err := lastLUID.Interface()
if err != nil {
return err
}
@@ -82,40 +82,20 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu
mtu = iface.MTU
}
}
- if lastLUID6 != 0 {
- iface, err := lastLUID6.Interface()
- if err != nil {
- return err
- }
- if iface.MTU > 0 && iface.MTU < mtu {
- mtu = iface.MTU
- }
- }
if mtu > 0 && lastMTU != mtu {
- iface, err := ourLUID.IPInterface(windows.AF_INET)
+ iface, err := ourLUID.IPInterface(family)
if err != nil {
return err
}
iface.NLMTU = mtu - 80
- if iface.NLMTU < 576 {
- iface.NLMTU = 576
+ if iface.NLMTU < minMTU {
+ iface.NLMTU = minMTU
}
err = iface.Set()
if err != nil {
return err
}
- tun.ForceMTU(int(iface.NLMTU)) // TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now.
- iface, err = ourLUID.IPInterface(windows.AF_INET6)
- if err == nil { // People seem to like to disable IPv6, so we make this non-fatal.
- iface.NLMTU = mtu - 80
- if iface.NLMTU < 1280 {
- iface.NLMTU = 1280
- }
- err = iface.Set()
- if err != nil {
- return err
- }
- }
+ tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad.
lastMTU = mtu
}
return nil
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
new file mode 100644
index 00000000..b7a07f77
--- /dev/null
+++ b/tunnel/interfacewatcher.go
@@ -0,0 +1,148 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "log"
+ "sync"
+
+ "golang.org/x/sys/windows"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
+ "golang.zx2c4.com/wireguard/windows/tunnel/firewall"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+type interfaceWatcherError struct {
+ serviceError services.Error
+ err error
+}
+type interfaceWatcherEvent struct {
+ luid winipcfg.LUID
+ family winipcfg.AddressFamily
+}
+type interfaceWatcher struct {
+ errors chan interfaceWatcherError
+
+ device *device.Device
+ conf *conf.Config
+ tun *tun.NativeTun
+
+ setupMutex sync.Mutex
+ routeChangeCallback4 *winipcfg.RouteChangeCallback
+ routeChangeCallback6 *winipcfg.RouteChangeCallback
+ interfaceChangeCallback *winipcfg.InterfaceChangeCallback
+ storedEvents []interfaceWatcherEvent
+}
+
+func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
+ var routeChangeCallback **winipcfg.RouteChangeCallback
+ var ipversion string
+ if family == windows.AF_INET {
+ routeChangeCallback = &iw.routeChangeCallback4
+ ipversion = "v4"
+ } else if family == windows.AF_INET6 {
+ routeChangeCallback = &iw.routeChangeCallback6
+ ipversion = "v6"
+ } else {
+ return
+ }
+ if *routeChangeCallback != nil {
+ (*routeChangeCallback).Unregister()
+ *routeChangeCallback = nil
+ }
+ var err error
+
+ log.Printf("Monitoring default %s routes", ipversion)
+ *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
+ return
+ }
+
+ log.Printf("Setting device %s addresses", ipversion)
+ err = configureInterface(family, iw.conf, iw.tun)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
+ return
+ }
+}
+
+func watchInterface() (*interfaceWatcher, error) {
+ iw := &interfaceWatcher{
+ errors: make(chan interfaceWatcherError, 2),
+ }
+ var err error
+ iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
+ iw.setupMutex.Lock()
+ defer iw.setupMutex.Unlock()
+
+ if notificationType != winipcfg.MibAddInstance {
+ return
+ }
+ if iw.tun == nil {
+ iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
+ return
+ }
+ if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ return
+ }
+ iw.setup(iface.Family)
+ })
+ if err != nil {
+ return nil, err
+ }
+ return iw, nil
+}
+
+func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) {
+ iw.setupMutex.Lock()
+ defer iw.setupMutex.Unlock()
+
+ iw.device, iw.conf, iw.tun = device, conf, tun
+ for _, event := range iw.storedEvents {
+ if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ iw.setup(event.family)
+ }
+ }
+ iw.storedEvents = nil
+}
+
+func (iw *interfaceWatcher) Destroy() {
+ iw.setupMutex.Lock()
+ defer iw.setupMutex.Unlock()
+
+ if iw.tun == nil {
+ return
+ }
+
+ if iw.routeChangeCallback4 != nil {
+ iw.routeChangeCallback4.Unregister()
+ iw.routeChangeCallback4 = nil
+ }
+ if iw.routeChangeCallback6 != nil {
+ iw.routeChangeCallback6.Unregister()
+ iw.routeChangeCallback6 = nil
+ }
+ if iw.interfaceChangeCallback != nil {
+ iw.interfaceChangeCallback.Unregister()
+ iw.interfaceChangeCallback = nil
+ }
+
+ firewall.DisableFirewall()
+
+ // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
+ // routes, so to be certain, just remove everything before destroying.
+ luid := winipcfg.LUID(iw.tun.LUID())
+ luid.FlushRoutes(windows.AF_INET)
+ luid.FlushIPAddresses(windows.AF_INET)
+ luid.FlushRoutes(windows.AF_INET6)
+ luid.FlushIPAddresses(windows.AF_INET6)
+ luid.FlushDNS()
+}
diff --git a/tunnel/service.go b/tunnel/service.go
index 1978cae0..c0ead084 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -26,7 +26,6 @@ import (
"golang.zx2c4.com/wireguard/windows/conf"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"golang.zx2c4.com/wireguard/windows/version"
)
@@ -39,7 +38,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang
var dev *device.Device
var uapi net.Listener
- var routeChangeCallback *winipcfg.RouteChangeCallback
+ var watcher *interfaceWatcher
var nativeTun *tun.NativeTun
var err error
serviceError := services.ErrorSuccess
@@ -84,11 +83,8 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang
}
}()
- if routeChangeCallback != nil {
- routeChangeCallback.Unregister()
- }
- if nativeTun != nil {
- unconfigureInterface(nativeTun)
+ if watcher != nil {
+ watcher.Destroy()
}
if uapi != nil {
uapi.Close()
@@ -140,6 +136,13 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang
m.Disconnect()
}
+ log.Println("Watching network interfaces")
+ watcher, err = watchInterface()
+ if err != nil {
+ serviceError = services.ErrorSetNetConfig
+ return
+ }
+
log.Println("Resolving DNS names")
uapiConf, err := conf.ToUAPI()
if err != nil {
@@ -197,22 +200,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang
log.Println("Bringing peers up")
dev.Up()
- log.Println("Waiting for TCP/IP to attach to interface")
- waitForFamilies(nativeTun) // TODO: move this sort of thing into tun/wintun/CreateInterface
-
- log.Println("Monitoring default routes")
- routeChangeCallback, err = monitorDefaultRoutes(dev, conf.Interface.MTU == 0, nativeTun)
- if err != nil {
- serviceError = services.ErrorBindSocketsToDefaultRoutes
- return
- }
-
- log.Println("Setting device address")
- err = configureInterface(conf, nativeTun)
- if err != nil {
- serviceError = services.ErrorSetNetConfig
- return
- }
+ watcher.Configure(dev, conf, nativeTun)
log.Println("Listening for UAPI requests")
go func() {
@@ -241,6 +229,9 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang
}
case <-dev.Wait():
return
+ case e := <-watcher.errors:
+ serviceError, err = e.serviceError, e.err
+ return
}
}
}
diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go
index ff7061d2..396fbbb2 100644
--- a/tunnel/winipcfg/luid.go
+++ b/tunnel/winipcfg/luid.go
@@ -116,6 +116,27 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
return luid.AddIPAddresses(addresses)
}
+// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
+func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
+ err := luid.FlushIPAddresses(family)
+ if err != nil {
+ return err
+ }
+ for i := range addresses {
+ asV4 := addresses[i].IP.To4()
+ if asV4 == nil && family == windows.AF_INET {
+ continue
+ } else if asV4 != nil && family == windows.AF_INET6 {
+ continue
+ }
+ err := luid.AddIPAddress(addresses[i])
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
func (luid LUID) DeleteIPAddress(address net.IPNet) error {
@@ -210,6 +231,27 @@ func (luid LUID) SetRoutes(routesData []*RouteData) error {
return luid.AddRoutes(routesData)
}
+// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface.
+func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error {
+ err := luid.FlushRoutes(family)
+ if err != nil {
+ return err
+ }
+ for _, rd := range routesData {
+ asV4 := rd.Destination.IP.To4()
+ if asV4 == nil && family == windows.AF_INET {
+ continue
+ } else if asV4 != nil && family == windows.AF_INET6 {
+ continue
+ }
+ err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
@@ -359,3 +401,28 @@ func (luid LUID) SetDNS(dnses []net.IP) error {
}
return runNetsh(cmds)
}
+
+// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family.
+func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error {
+ var templateFlush string
+ if family == windows.AF_INET {
+ templateFlush = netshCmdTemplateFlush4
+ } else if family == windows.AF_INET6 {
+ templateFlush = netshCmdTemplateFlush6
+ }
+
+ cmds := make([]string, 0, 1+len(dnses))
+ ipif, err := luid.IPInterface(family)
+ if err != nil {
+ return err
+ }
+ cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
+ for i := 0; i < len(dnses); i++ {
+ if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
+ } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
+ }
+ }
+ return runNetsh(cmds)
+}
diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go
index 5af9f1aa..2fc0c875 100644
--- a/tunnel/winipcfg/winipcfg.go
+++ b/tunnel/winipcfg/winipcfg.go
@@ -32,13 +32,13 @@ import (
// GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
-func GetAdaptersAddresses(family uint32, flags GAAFlags) ([]*IPAdapterAddresses, error) {
+func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) {
var b []byte
size := uint32(15000)
for {
b = make([]byte, size)
- err := windows.GetAdaptersAddresses(family, uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size)
+ err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size)
if err == nil {
break
}