From b77c634b9eb0e1e732b60667a5974f8340a50253 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 17 Jun 2019 13:08:13 +0200 Subject: tunnel: wait for IP service to attach to wintun This helps fix startup races without needing to poll, as well as reconfiguring interfaces after wintun destroys and re-adds. It also deals gracefully with IPv6 being disabled. --- tunnel/addressconfig.go | 201 ++++++++++++++++++++++++++++++++++ tunnel/defaultroutemonitor.go | 56 ++++------ tunnel/ifaceconfig.go | 244 ------------------------------------------ tunnel/interfacewatcher.go | 148 +++++++++++++++++++++++++ tunnel/service.go | 37 +++---- tunnel/winipcfg/luid.go | 67 ++++++++++++ tunnel/winipcfg/winipcfg.go | 4 +- 7 files changed, 450 insertions(+), 307 deletions(-) create mode 100644 tunnel/addressconfig.go delete mode 100644 tunnel/ifaceconfig.go create mode 100644 tunnel/interfacewatcher.go diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go new file mode 100644 index 00000000..a1e5dc59 --- /dev/null +++ b/tunnel/addressconfig.go @@ -0,0 +1,201 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "bytes" + "log" + "net" + "sort" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel/firewall" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { + if len(addresses) == 0 { + return + } + includedInAddresses := func(a net.IPNet) bool { + // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! + for _, addr := range addresses { + ip := addr.IP + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + mA, _ := addr.Mask.Size() + mB, _ := a.Mask.Size() + if bytes.Equal(ip, a.IP) && mA == mB { + return true + } + } + return false + } + interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) + if err != nil { + return + } + for _, iface := range interfaces { + if iface.OperStatus == winipcfg.IfOperStatusUp { + continue + } + for address := iface.FirstUnicastAddress; address != nil; address = address.Next { + ip := address.Address.IP() + ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} + if includedInAddresses(ipnet) { + log.Printf("Cleaning up stale address %s from interface '%s'", ipnet.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(ipnet) + } + } + } +} + +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { + luid := winipcfg.LUID(tun.LUID()) + + estimatedRouteCount := len(conf.Interface.Addresses) + for _, peer := range conf.Peers { + estimatedRouteCount += len(peer.AllowedIPs) + } + routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) + var firstGateway4 *net.IP + var firstGateway6 *net.IP + addresses := make([]net.IPNet, len(conf.Interface.Addresses)) + for i, addr := range conf.Interface.Addresses { + ipnet := addr.IPNet() + addresses[i] = ipnet + gateway := ipnet.IP.Mask(ipnet.Mask) + if addr.Bits() == 32 && firstGateway4 == nil { + firstGateway4 = &gateway + } else if addr.Bits() == 128 && firstGateway6 == nil { + firstGateway6 = &gateway + } + routes = append(routes, winipcfg.RouteData{ + Destination: net.IPNet{ + IP: gateway, + Mask: ipnet.Mask, + }, + NextHop: gateway, + Metric: 0, + }) + } + + foundDefault4 := false + foundDefault6 := false + for _, peer := range conf.Peers { + for _, allowedip := range peer.AllowedIPs { + if (allowedip.Bits() == 32 && firstGateway4 == nil) || (allowedip.Bits() == 128 && firstGateway6 == nil) { + continue + } + route := winipcfg.RouteData{ + Destination: allowedip.IPNet(), + Metric: 0, + } + if allowedip.Bits() == 32 { + if allowedip.Cidr == 0 { + foundDefault4 = true + } + route.NextHop = *firstGateway4 + } else if allowedip.Bits() == 128 { + if allowedip.Cidr == 0 { + foundDefault6 = true + } + route.NextHop = *firstGateway6 + } + routes = append(routes, route) + } + } + + err := luid.SetIPAddressesForFamily(family, addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(family, addresses) + err = luid.SetIPAddressesForFamily(family, addresses) + } + if err != nil { + return err + } + + deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) + sort.Slice(routes, func(i, j int) bool { + return routes[i].Metric < routes[j].Metric || + bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || + bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || + bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 + }) + for i := 0; i < len(routes); i++ { + if i > 0 && routes[i].Metric == routes[i-1].Metric && + bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && + bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && + bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { + continue + } + deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) + } + + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) + if err != nil { + return nil + } + + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + if conf.Interface.MTU > 0 { + ipif.NLMTU = uint32(conf.Interface.MTU) + tun.ForceMTU(int(ipif.NLMTU)) + } + if family == windows.AF_INET { + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + } else if family == windows.AF_INET6 { + if foundDefault6 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + ipif.DadTransmits = 0 + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + } + err = ipif.Set() + if err != nil { + return err + } + + err = luid.SetDNSForFamily(family, conf.Interface.DNS) + if err != nil { + return err + } + + return nil +} + +func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { + restrictAll := false + if len(conf.Peers) == 1 { + nextallowedip: + for _, allowedip := range conf.Peers[0].AllowedIPs { + if allowedip.Cidr == 0 { + for _, b := range allowedip.IP { + if b != 0 { + continue nextallowedip + } + } + restrictAll = true + break + } + } + } + if restrictAll && len(conf.Interface.DNS) == 0 { + log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") + } + return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) +} diff --git a/tunnel/defaultroutemonitor.go b/tunnel/defaultroutemonitor.go index e9440710..c1722c45 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/defaultroutemonitor.go @@ -44,28 +44,28 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU *lastLUID = luid *lastIndex = index if family == windows.AF_INET { - log.Printf("Binding UDPv4 socket to interface %d", index) + log.Printf("Binding v4 socket to interface %d", index) return device.BindSocketToInterface4(index) } else if family == windows.AF_INET6 { - log.Printf("Binding UDPv6 socket to interface %d", index) + log.Printf("Binding v6 socket to interface %d", index) return device.BindSocketToInterface6(index) } return nil } -func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { +func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { + var minMTU uint32 + if family == windows.AF_INET { + minMTU = 576 + } else if family == windows.AF_INET6 { + minMTU = 1280 + } ourLUID := winipcfg.LUID(tun.LUID()) - lastLUID4 := winipcfg.LUID(0) - lastLUID6 := winipcfg.LUID(0) - lastIndex4 := uint32(0) - lastIndex6 := uint32(0) + lastLUID := winipcfg.LUID(0) + lastIndex := uint32(0) lastMTU := uint32(0) doIt := func() error { - err := bindSocketRoute(windows.AF_INET, device, ourLUID, &lastLUID4, &lastIndex4) - if err != nil { - return err - } - err = bindSocketRoute(windows.AF_INET6, device, ourLUID, &lastLUID6, &lastIndex6) + err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex) if err != nil { return err } @@ -73,8 +73,8 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu return nil } mtu := uint32(0) - if lastLUID4 != 0 { - iface, err := lastLUID4.Interface() + if lastLUID != 0 { + iface, err := lastLUID.Interface() if err != nil { return err } @@ -82,40 +82,20 @@ func monitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTu mtu = iface.MTU } } - if lastLUID6 != 0 { - iface, err := lastLUID6.Interface() - if err != nil { - return err - } - if iface.MTU > 0 && iface.MTU < mtu { - mtu = iface.MTU - } - } if mtu > 0 && lastMTU != mtu { - iface, err := ourLUID.IPInterface(windows.AF_INET) + iface, err := ourLUID.IPInterface(family) if err != nil { return err } iface.NLMTU = mtu - 80 - if iface.NLMTU < 576 { - iface.NLMTU = 576 + if iface.NLMTU < minMTU { + iface.NLMTU = minMTU } err = iface.Set() if err != nil { return err } - tun.ForceMTU(int(iface.NLMTU)) // TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. - iface, err = ourLUID.IPInterface(windows.AF_INET6) - if err == nil { // People seem to like to disable IPv6, so we make this non-fatal. - iface.NLMTU = mtu - 80 - if iface.NLMTU < 1280 { - iface.NLMTU = 1280 - } - err = iface.Set() - if err != nil { - return err - } - } + tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad. lastMTU = mtu } return nil diff --git a/tunnel/ifaceconfig.go b/tunnel/ifaceconfig.go deleted file mode 100644 index a71b612e..00000000 --- a/tunnel/ifaceconfig.go +++ /dev/null @@ -1,244 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package tunnel - -import ( - "bytes" - "log" - "net" - "sort" - "time" - - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" - - "golang.zx2c4.com/wireguard/windows/conf" - "golang.zx2c4.com/wireguard/windows/tunnel/firewall" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -func cleanupAddressesOnDisconnectedInterfaces(addresses []net.IPNet) { - if len(addresses) == 0 { - return - } - includedInAddresses := func(a net.IPNet) bool { - // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! - for _, addr := range addresses { - ip := addr.IP - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } - mA, _ := addr.Mask.Size() - mB, _ := a.Mask.Size() - if bytes.Equal(ip, a.IP) && mA == mB { - return true - } - } - return false - } - interfaces, err := winipcfg.GetAdaptersAddresses(windows.AF_UNSPEC, winipcfg.GAAFlagDefault) - if err != nil { - return - } - for _, iface := range interfaces { - if iface.OperStatus == winipcfg.IfOperStatusUp { - continue - } - for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - if includedInAddresses(ipnet) { - log.Printf("Cleaning up stale address %s from interface '%s'", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) - } - } - } -} - -func configureInterface(conf *conf.Config, tun *tun.NativeTun) error { - luid := winipcfg.LUID(tun.LUID()) - - estimatedRouteCount := len(conf.Interface.Addresses) - for _, peer := range conf.Peers { - estimatedRouteCount += len(peer.AllowedIPs) - } - routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) - var firstGateway4 *net.IP - var firstGateway6 *net.IP - addresses := make([]net.IPNet, len(conf.Interface.Addresses)) - for i, addr := range conf.Interface.Addresses { - ipnet := addr.IPNet() - addresses[i] = ipnet - gateway := ipnet.IP.Mask(ipnet.Mask) - if addr.Bits() == 32 && firstGateway4 == nil { - firstGateway4 = &gateway - } else if addr.Bits() == 128 && firstGateway6 == nil { - firstGateway6 = &gateway - } - routes = append(routes, winipcfg.RouteData{ - Destination: net.IPNet{ - IP: gateway, - Mask: ipnet.Mask, - }, - NextHop: gateway, - Metric: 0, - }) - } - - foundDefault4 := false - foundDefault6 := false - for _, peer := range conf.Peers { - for _, allowedip := range peer.AllowedIPs { - if (allowedip.Bits() == 32 && firstGateway4 == nil) || (allowedip.Bits() == 128 && firstGateway6 == nil) { - continue - } - route := winipcfg.RouteData{ - Destination: allowedip.IPNet(), - Metric: 0, - } - if allowedip.Bits() == 32 { - if allowedip.Cidr == 0 { - foundDefault4 = true - } - route.NextHop = *firstGateway4 - } else if allowedip.Bits() == 128 { - if allowedip.Cidr == 0 { - foundDefault6 = true - } - route.NextHop = *firstGateway6 - } - routes = append(routes, route) - } - } - - err := luid.SetIPAddresses(addresses) - if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(addresses) - err = luid.SetIPAddresses(addresses) - } - if err != nil { - return err - } - - deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) - sort.Slice(routes, func(i, j int) bool { - return routes[i].Metric < routes[j].Metric || - bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || - bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || - bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 - }) - for i := 0; i < len(routes); i++ { - if i > 0 && routes[i].Metric == routes[i-1].Metric && - bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && - bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && - bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { - continue - } - deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) - } - - err = luid.SetRoutes(deduplicatedRoutes) - if err != nil { - return nil - } - - ipif, err := luid.IPInterface(windows.AF_INET) - if err != nil { - return err - } - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - if conf.Interface.MTU > 0 { - ipif.NLMTU = uint32(conf.Interface.MTU) - tun.ForceMTU(int(ipif.NLMTU)) - } - err = ipif.Set() - if err != nil { - return err - } - - ipif, err = luid.IPInterface(windows.AF_INET6) - if err != nil && firstGateway6 != nil { - log.Printf("Is IPv6 disabled by Windows?") - return err - } else if err == nil { // People seem to like to disable IPv6, so we make this non-fatal. - if foundDefault6 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - if conf.Interface.MTU > 0 { - ipif.NLMTU = uint32(conf.Interface.MTU) - } - ipif.DadTransmits = 0 - ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled - err = ipif.Set() - if err != nil { - return err - } - } - - err = luid.SetDNS(conf.Interface.DNS) - if err != nil { - return err - } - - return nil -} - -func unconfigureInterface(tun *tun.NativeTun) { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active - // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) - luid.FlushRoutes(windows.AF_INET) - luid.FlushIPAddresses(windows.AF_INET) - luid.FlushRoutes(windows.AF_INET6) - luid.FlushIPAddresses(windows.AF_INET6) - luid.FlushDNS() - - firewall.DisableFirewall() -} - -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { - restrictAll := false - if len(conf.Peers) == 1 { - nextallowedip: - for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } - } - restrictAll = true - break - } - } - } - if restrictAll && len(conf.Interface.DNS) == 0 { - log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") - } - return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) -} - -func waitForFamilies(tun *tun.NativeTun) { - // TODO: This whole thing is a disgusting hack that shouldn't be neccessary. - - f := func(luid winipcfg.LUID, family winipcfg.AddressFamily, maxRetries int) { - for i := 0; i < maxRetries; i++ { - _, err := luid.IPInterface(family) - if i != maxRetries-1 && err == windows.ERROR_NOT_FOUND { - time.Sleep(time.Millisecond * 50) - continue - } - break - } - } - luid := winipcfg.LUID(tun.LUID()) - f(luid, windows.AF_INET, 100) - f(luid, windows.AF_INET6, 3) -} diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go new file mode 100644 index 00000000..b7a07f77 --- /dev/null +++ b/tunnel/interfacewatcher.go @@ -0,0 +1,148 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "log" + "sync" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" + "golang.zx2c4.com/wireguard/windows/tunnel/firewall" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +type interfaceWatcherError struct { + serviceError services.Error + err error +} +type interfaceWatcherEvent struct { + luid winipcfg.LUID + family winipcfg.AddressFamily +} +type interfaceWatcher struct { + errors chan interfaceWatcherError + + device *device.Device + conf *conf.Config + tun *tun.NativeTun + + setupMutex sync.Mutex + routeChangeCallback4 *winipcfg.RouteChangeCallback + routeChangeCallback6 *winipcfg.RouteChangeCallback + interfaceChangeCallback *winipcfg.InterfaceChangeCallback + storedEvents []interfaceWatcherEvent +} + +func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { + var routeChangeCallback **winipcfg.RouteChangeCallback + var ipversion string + if family == windows.AF_INET { + routeChangeCallback = &iw.routeChangeCallback4 + ipversion = "v4" + } else if family == windows.AF_INET6 { + routeChangeCallback = &iw.routeChangeCallback6 + ipversion = "v6" + } else { + return + } + if *routeChangeCallback != nil { + (*routeChangeCallback).Unregister() + *routeChangeCallback = nil + } + var err error + + log.Printf("Monitoring default %s routes", ipversion) + *routeChangeCallback, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, iw.tun) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} + return + } + + log.Printf("Setting device %s addresses", ipversion) + err = configureInterface(family, iw.conf, iw.tun) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} + return + } +} + +func watchInterface() (*interfaceWatcher, error) { + iw := &interfaceWatcher{ + errors: make(chan interfaceWatcherError, 2), + } + var err error + iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + if notificationType != winipcfg.MibAddInstance { + return + } + if iw.tun == nil { + iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) + return + } + if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + return + } + iw.setup(iface.Family) + }) + if err != nil { + return nil, err + } + return iw, nil +} + +func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + iw.device, iw.conf, iw.tun = device, conf, tun + for _, event := range iw.storedEvents { + if event.luid == winipcfg.LUID(iw.tun.LUID()) { + iw.setup(event.family) + } + } + iw.storedEvents = nil +} + +func (iw *interfaceWatcher) Destroy() { + iw.setupMutex.Lock() + defer iw.setupMutex.Unlock() + + if iw.tun == nil { + return + } + + if iw.routeChangeCallback4 != nil { + iw.routeChangeCallback4.Unregister() + iw.routeChangeCallback4 = nil + } + if iw.routeChangeCallback6 != nil { + iw.routeChangeCallback6.Unregister() + iw.routeChangeCallback6 = nil + } + if iw.interfaceChangeCallback != nil { + iw.interfaceChangeCallback.Unregister() + iw.interfaceChangeCallback = nil + } + + firewall.DisableFirewall() + + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active + // routes, so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(iw.tun.LUID()) + luid.FlushRoutes(windows.AF_INET) + luid.FlushIPAddresses(windows.AF_INET) + luid.FlushRoutes(windows.AF_INET6) + luid.FlushIPAddresses(windows.AF_INET6) + luid.FlushDNS() +} diff --git a/tunnel/service.go b/tunnel/service.go index 1978cae0..c0ead084 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -26,7 +26,6 @@ import ( "golang.zx2c4.com/wireguard/windows/conf" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/version" ) @@ -39,7 +38,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang var dev *device.Device var uapi net.Listener - var routeChangeCallback *winipcfg.RouteChangeCallback + var watcher *interfaceWatcher var nativeTun *tun.NativeTun var err error serviceError := services.ErrorSuccess @@ -84,11 +83,8 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang } }() - if routeChangeCallback != nil { - routeChangeCallback.Unregister() - } - if nativeTun != nil { - unconfigureInterface(nativeTun) + if watcher != nil { + watcher.Destroy() } if uapi != nil { uapi.Close() @@ -140,6 +136,13 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang m.Disconnect() } + log.Println("Watching network interfaces") + watcher, err = watchInterface() + if err != nil { + serviceError = services.ErrorSetNetConfig + return + } + log.Println("Resolving DNS names") uapiConf, err := conf.ToUAPI() if err != nil { @@ -197,22 +200,7 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang log.Println("Bringing peers up") dev.Up() - log.Println("Waiting for TCP/IP to attach to interface") - waitForFamilies(nativeTun) // TODO: move this sort of thing into tun/wintun/CreateInterface - - log.Println("Monitoring default routes") - routeChangeCallback, err = monitorDefaultRoutes(dev, conf.Interface.MTU == 0, nativeTun) - if err != nil { - serviceError = services.ErrorBindSocketsToDefaultRoutes - return - } - - log.Println("Setting device address") - err = configureInterface(conf, nativeTun) - if err != nil { - serviceError = services.ErrorSetNetConfig - return - } + watcher.Configure(dev, conf, nativeTun) log.Println("Listening for UAPI requests") go func() { @@ -241,6 +229,9 @@ func (service *Service) Execute(args []string, r <-chan svc.ChangeRequest, chang } case <-dev.Wait(): return + case e := <-watcher.errors: + serviceError, err = e.serviceError, e.err + return } } } diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index ff7061d2..396fbbb2 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -116,6 +116,27 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { return luid.AddIPAddresses(addresses) } +// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error { + err := luid.FlushIPAddresses(family) + if err != nil { + return err + } + for i := range addresses { + asV4 := addresses[i].IP.To4() + if asV4 == nil && family == windows.AF_INET { + continue + } else if asV4 != nil && family == windows.AF_INET6 { + continue + } + err := luid.AddIPAddress(addresses[i]) + if err != nil { + return err + } + } + return nil +} + // DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). func (luid LUID) DeleteIPAddress(address net.IPNet) error { @@ -210,6 +231,27 @@ func (luid LUID) SetRoutes(routesData []*RouteData) error { return luid.AddRoutes(routesData) } +// SetRoutesForFamily method sets (flush than add) multiple routes for a specific family to the interface. +func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteData) error { + err := luid.FlushRoutes(family) + if err != nil { + return err + } + for _, rd := range routesData { + asV4 := rd.Destination.IP.To4() + if asV4 == nil && family == windows.AF_INET { + continue + } else if asV4 != nil && family == windows.AF_INET6 { + continue + } + err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) + if err != nil { + return err + } + } + return nil +} + // DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error { @@ -359,3 +401,28 @@ func (luid LUID) SetDNS(dnses []net.IP) error { } return runNetsh(cmds) } + +// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family. +func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error { + var templateFlush string + if family == windows.AF_INET { + templateFlush = netshCmdTemplateFlush4 + } else if family == windows.AF_INET6 { + templateFlush = netshCmdTemplateFlush6 + } + + cmds := make([]string, 0, 1+len(dnses)) + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) + for i := 0; i < len(dnses); i++ { + if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) + } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) + } + } + return runNetsh(cmds) +} diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go index 5af9f1aa..2fc0c875 100644 --- a/tunnel/winipcfg/winipcfg.go +++ b/tunnel/winipcfg/winipcfg.go @@ -32,13 +32,13 @@ import ( // GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer. // https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses -func GetAdaptersAddresses(family uint32, flags GAAFlags) ([]*IPAdapterAddresses, error) { +func GetAdaptersAddresses(family AddressFamily, flags GAAFlags) ([]*IPAdapterAddresses, error) { var b []byte size := uint32(15000) for { b = make([]byte, size) - err := windows.GetAdaptersAddresses(family, uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size) + err := windows.GetAdaptersAddresses(uint32(family), uint32(flags), 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &size) if err == nil { break } -- cgit v1.2.3-59-g8ed1b