diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-06-17 13:08:13 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-06-18 12:08:51 +0200 |
commit | b77c634b9eb0e1e732b60667a5974f8340a50253 (patch) | |
tree | 13a2d97bcb3ad8e44f288d3220fd9030d9eb968c | |
parent | version: bump (diff) | |
download | wireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.tar.xz wireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.zip |
tunnel: wait for IP service to attach to wintun
This helps fix startup races without needing to poll, as well as
reconfiguring interfaces after wintun destroys and re-adds. It also
deals gracefully with IPv6 being disabled.
Diffstat (limited to '')
-rw-r--r-- | tunnel/addressconfig.go (renamed from tunnel/ifaceconfig.go) | 81 | ||||
-rw-r--r-- | tunnel/defaultroutemonitor.go | 56 | ||||
-rw-r--r-- | tunnel/interfacewatcher.go | 148 | ||||
-rw-r--r-- | tunnel/service.go | 37 | ||||
-rw-r--r-- | tunnel/winipcfg/luid.go | 67 | ||||
-rw-r--r-- | tunnel/winipcfg/winipcfg.go | 4 |
6 files changed, 268 insertions, 125 deletions
diff --git a/tunnel/ifaceconfig.go b/tunnel/addressconfig.go index a71b612e..a1e5dc59 100644 --- a/tunnel/ifaceconfig.go +++ b/tunnel/addressconfig.go @@ -10,7 +10,6 @@ import ( "log" "net" "sort" - "time" "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/tun" @@ -20,7 +19,7 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { if len(addresses) == 0 { return } @@ -39,7 +38,7 @@ func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) { } return false } - interfaces, err := winipcfg.GetAdaptersAddresses(windows.AF_UNSPEC, winipcfg.GAAFlagDefault) + interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { return } @@ -58,7 +57,7 @@ func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) { } } -func configureInterface(conf *conf.Config, tun *tun.NativeTun) error { +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { luid := winipcfg.LUID(tun.LUID()) estimatedRouteCount := len(conf.Interface.Addresses) @@ -114,10 +113,10 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error { } } - err := luid.SetIPAddresses(addresses) + err := luid.SetIPAddressesForFamily(family, addresses) if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(addresses) - err = luid.SetIPAddresses(addresses) + cleanupAddressesOnDisconnectedInterfaces(family, addresses) + err = luid.SetIPAddressesForFamily(family, addresses) } if err != nil { return err @@ -140,49 +139,38 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error { deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) } - err = luid.SetRoutes(deduplicatedRoutes) + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) if err != nil { return nil } - ipif, err := luid.IPInterface(windows.AF_INET) + ipif, err := luid.IPInterface(family) if err != nil { return err } - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) tun.ForceMTU(int(ipif.NLMTU)) } - err = ipif.Set() - if err != nil { - return err - } - - ipif, err = luid.IPInterface(windows.AF_INET6) - if err != nil && firstGateway6 != nil { - log.Printf("Is IPv6 disabled by Windows?") - return err - } else if err == nil { // People seem to like to disable IPv6, so we make this non-fatal. - if foundDefault6 { + if family == windows.AF_INET { + if foundDefault4 { ipif.UseAutomaticMetric = false ipif.Metric = 0 } - if conf.Interface.MTU > 0 { - ipif.NLMTU = uint32(conf.Interface.MTU) + } else if family == windows.AF_INET6 { + if foundDefault6 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 } ipif.DadTransmits = 0 ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled - err = ipif.Set() - if err != nil { - return err - } + } + err = ipif.Set() + if err != nil { + return err } - err = luid.SetDNS(conf.Interface.DNS) + err = luid.SetDNSForFamily(family, conf.Interface.DNS) if err != nil { return err } @@ -190,19 +178,6 @@ func configureInterface(conf *conf.Config, tun *tun.NativeTun) error { return nil } -func unconfigureInterface(tun *tun.NativeTun) { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active - // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) - luid.FlushRoutes(windows.AF_INET) - luid.FlushIPAddresses(windows.AF_INET) - luid.FlushRoutes(windows.AF_INET6) - luid.FlushIPAddresses(windows.AF_INET6) - luid.FlushDNS() - - firewall.DisableFirewall() -} - func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { restrictAll := false if len(conf.Peers) == 1 { @@ -224,21 +199,3 @@ func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { } return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) } - -func waitForFamilies(tun *tun.NativeTun) { - // TODO: This whole thing is a disgusting hack that shouldn't be neccessary. - - f := func(luid winipcfg.LUID, family winipcfg.AddressFamily, maxRetries int) { - for i := 0; i < maxRetries; i++ { - _, err := luid.IPInterface(family) - if i != maxRetries-1 && err == windows.ERROR_NOT_FOUND { - time.Sleep(time.Millisecond * 50) - continue - } - break - } - } - luid := winipcfg.LUID(tun.LUID()) - f(luid, windows.AF_INET, 100) - f(luid, windows.AF_INET6, 3) -} diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index e9440710..c1722c45 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -44,28 +44,28 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU *lastLUID = luid *lastIndex = index if family == windows.AF_INET { - log.Printf("Binding UDPv4 socket to interface %d", index) + log.Printf("Binding v4 socket to interface %d", index) return device.BindSocketToInterface4(index) } else if family == windows.AF_INET6 { - log.Printf("Binding UDPv6 socket to interface %d", index) + log.Printf("Binding v6 socket to interface %d", index) return device.BindSocketToInterface6(index) } return nil } -func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { +func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { + var minMTU uint32 + if family == windows.AF_INET { + minMTU = 576 + } else if family == windows.AF_INET6 { + minMTU = 1280 + } ourLUID := winipcfg.LUID(tun.LUID()) - lastLUID4 := winipcfg.LUID(0) - lastLUID6 := winipcfg.LUID(0) - lastIndex4 := uint32(0) - lastIndex6 := uint32(0) + lastLUID := winipcfg.LUID(0) + lastIndex := uint32(0) lastMTU := uint32(0) doIt := func() error { - err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4, &lastIndex4) - if err != nil { - return err - } - err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6, &lastIndex6) + err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex) if err != nil { return err } @@ -73,8 +73,8 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu return nil } mtu := uint32(0) - if lastLUID4 != 0 { - iface, err := lastLUID4.Interface() + if lastLUID != 0 { + iface, err := lastLUID.Interface() if err != nil { return err } @@ -82,40 +82,20 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu mtu = iface.MTU } } - if lastLUID6 != 0 { - iface, err := lastLUID6.Interface() - if err != nil { - return err - } - if iface.MTU > 0 && iface.MTU < mtu { - mtu = iface.MTU - } - } if mtu > 0 && lastMTU != mtu { - iface, err := ourLUID.IPInterface(windows.AF_INET) + iface, err := ourLUID.IPInterface(family) if err != nil { return err } iface.NLMTU = mtu - 80 - if iface.NLMTU < 576 { - iface.NLMTU = 576 + if iface.NLMTU < minMTU { + iface.NLMTU = minMTU } err = iface.Set() if err != nil { return err } - tun.ForceMTU(int(iface.NLMTU)) // TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. - iface, err = ourLUID.IPInterface(windows.AF_INET6) - if err == nil { // People seem to like to disable IPv6, so we make this non-fatal. - iface.NLMTU = mtu - 80 - if iface.NLMTU < 1280 { - iface.NLMTU = 1280 - } - err = iface.Set() - if err != nil { - return err - } - } + tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad. lastMTU = mtu } return nil diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go new file mode 100644 index 00000000..b7a07f77 --- /dev/null +++ b/tunnel/interfacewatcher.go @@ -0,0 +1,148 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "log" + "sync" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/tunnel/firewall" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +type interfaceWatcherError struct { + serviceError services.Error + err error +} +type interfaceWatcherEvent struct { + luid winipcfg.LUID + family winipcfg.AddressFamily +} +type interfaceWatcher struct { + errors chan interfaceWatcherError + + device *device.Device + conf *conf.Config + tun *tun.NativeTun + + setupMutex sync.Mutex + routeChangeCallback4 *winipcfg.RouteChangeCallback + routeChangeCallback6 *winipcfg.RouteChangeCallback + interfaceChangeCallback *winipcfg.InterfaceChangeCallback + storedEvents []interfaceWatcherEvent +} + +func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { + var routeChangeCallback **winipcfg.RouteChangeCallback + var ipversion string + if family == windows.AF_INET { + routeChangeCallback = &iw.routeChangeCallback4 + ipversion = "v4" + } else if family == windows.AF_INET6 { + routeChangeCallback = &iw.routeChangeCallback6 + ipversion = "v6" + } else { + return + } + if *routeChangeCallback != nil { + (*routeChangeCallback).Unregister() + *routeChangeCallback = nil + } + var err error + + log.Printf("Monitoring default %s routes", ipversion) + *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} + return + } + + log.Printf("Setting device %s addresses", ipversion) + err = configureInterface(family, iw.conf, iw.tun) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} + return + } +} + +func watchInterface() (*interfaceWatcher, error) { + iw := &interfaceWatcher{ + errors: make(chan interfaceWatcherError, 2), + } + var err error + iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + if notificationType != winipcfg.MibAddInstance { + return + } + if iw.tun == nil { + iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) + return + } + if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + return + } + iw.setup(iface.Family) + }) + if err != nil { + return nil, err + } + return iw, nil +} + +func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + iw.device, iw.conf, iw.tun = device, conf, tun + for _, event := range iw.storedEvents { + if event.luid == winipcfg.LUID(iw.tun.LUID()) { + iw.setup(event.family) + } + } + iw.storedEvents = nil +} + +func (iw *interfaceWatcher) Destroy() { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + if iw.tun == nil { + return + } + + if iw.routeChangeCallback4 != nil { + iw.routeChangeCallback4.Unregister() + iw.routeChangeCallback4 = nil + } + if iw.routeChangeCallback6 != nil { + iw.routeChangeCallback6.Unregister() + iw.routeChangeCallback6 = nil + } + if iw.interfaceChangeCallback != nil { + iw.interfaceChangeCallback.Unregister() + iw.interfaceChangeCallback = nil + } + + firewall.DisableFirewall() + + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active + // routes, so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(iw.tun.LUID()) + luid.FlushRoutes(windows.AF_INET) + luid.FlushIPAddresses(windows.AF_INET) + luid.FlushRoutes(windows.AF_INET6) + luid.FlushIPAddresses(windows.AF_INET6) + luid.FlushDNS() +} diff --git a/tunnel/service.go b/tunnel/service.go index 1978cae0..c0ead084 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -26,7 +26,6 @@ import ( "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/version" ) @@ -39,7 +38,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang var dev *device.Device var uapi net.Listener - var routeChangeCallback *winipcfg.RouteChangeCallback + var watcher *interfaceWatcher var nativeTun *tun.NativeTun var err error serviceError := services.ErrorSuccess @@ -84,11 +83,8 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang } }() - if routeChangeCallback != nil { - routeChangeCallback.Unregister() - } - if nativeTun != nil { - unconfigureInterface(nativeTun) + if watcher != nil { + watcher.Destroy() } if uapi != nil { uapi.Close() @@ -140,6 +136,13 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang m.Disconnect() } + log.Println("Watching network interfaces") + watcher, err = watchInterface() + if err != nil { + serviceError = services.ErrorSetNetConfig + return + } + log.Println("Resolving DNS names") uapiConf, err := conf.ToUAPI() if err != nil { @@ -197,22 +200,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang log.Println("Bringing peers up") dev.Up() - log.Println("Waiting for TCP/IP to attach to interface") - waitForFamilies(nativeTun) // TODO: move this sort of thing into tun/wintun/CreateInterface - - log.Println("Monitoring default routes") - routeChangeCallback, err = monitorDefaultRoutes(dev, conf.Interface.MTU == 0, nativeTun) - if err != nil { - serviceError = services.ErrorBindSocketsToDefaultRoutes - return - } - - log.Println("Setting device address") - err = configureInterface(conf, nativeTun) - if err != nil { - serviceError = services.ErrorSetNetConfig - return - } + watcher.Configure(dev, conf, nativeTun) log.Println("Listening for UAPI requests") go func() { @@ -241,6 +229,9 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang } case <-dev.Wait(): return + case e := <-watcher.errors: + serviceError, err = e.serviceError, e.err + return } } } diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index ff7061d2..396fbbb2 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -116,6 +116,27 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { return luid.AddIPAddresses(addresses) } +// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error { + err := luid.FlushIPAddresses(family) + if err != nil { + return err + } + for i := range addresses { + asV4 := addresses[i].IP.To4() + if asV4 == nil && family == windows.AF_INET { + continue + } else if asV4 != nil && family == windows.AF_INET6 { + continue + } + err := luid.AddIPAddress(addresses[i]) + if err != nil { + return err + } + } + return nil +} + // DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). func (luid LUID) DeleteIPAddress(address net.IPNet) error { @@ -210,6 +231,27 @@ func (luid LUID) SetRoutes(routesData []*RouteData) error { return luid.AddRoutes(routesData) } +// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface. +func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error { + err := luid.FlushRoutes(family) + if err != nil { + return err + } + for _, rd := range routesData { + asV4 := rd.Destination.IP.To4() + if asV4 == nil && family == windows.AF_INET { + continue + } else if asV4 != nil && family == windows.AF_INET6 { + continue + } + err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) + if err != nil { + return err + } + } + return nil +} + // DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error { @@ -359,3 +401,28 @@ func (luid LUID) SetDNS(dnses []net.IP) error { } return runNetsh(cmds) } + +// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family. +func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error { + var templateFlush string + if family == windows.AF_INET { + templateFlush = netshCmdTemplateFlush4 + } else if family == windows.AF_INET6 { + templateFlush = netshCmdTemplateFlush6 + } + + cmds := make([]string, 0, 1+len(dnses)) + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) + for i := 0; i < len(dnses); i++ { + if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) + } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) + } + } + return runNetsh(cmds) +} diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go index 5af9f1aa..2fc0c875 100644 --- a/tunnel/winipcfg/winipcfg.go +++ b/tunnel/winipcfg/winipcfg.go @@ -32,13 +32,13 @@ import ( // GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer. // https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses -func GetAdaptersAddresses(family uint32, flags GAAFlags) ([]*IPAdapterAddresses, error) { +func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) { var b []byte size := uint32(15000) for { b = make([]byte, size) - err := windows.GetAdaptersAddresses(family, uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size) + err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size) if err == nil { break } |