diff options
Diffstat (limited to 'tunnel/addressconfig.go')
-rw-r--r-- | tunnel/addressconfig.go | 186 |
1 files changed, 77 insertions, 109 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 777c96cd..a3ce6295 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -1,42 +1,30 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "bytes" + "fmt" "log" - "net" - "sort" + "net/netip" + "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) { if len(addresses) == 0 { return } - includedInAddresses := func(a net.IPNet) bool { - // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! - for _, addr := range addresses { - ip := addr.IP - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } - mA, _ := addr.Mask.Size() - mB, _ := a.Mask.Size() - if bytes.Equal(ip, a.IP) && mA == mB { - return true - } - } - return false + addrHash := make(map[netip.Addr]bool, len(addresses)) + for i := range addresses { + addrHash[addresses[i].Addr()] = true } interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { @@ -47,144 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add continue } for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - if includedInAddresses(ipnet) { - log.Printf("Cleaning up stale address %s from interface ā%sā", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) + if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] { + prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) + log.Printf("Cleaning up stale address %s from interface ā%sā", prefix.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(prefix) } } } } -func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { - luid := winipcfg.LUID(tun.LUID()) +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error { + retryOnFailure := services.StartedAtBoot() + tryTimes := 0 +startOver: + var err error + if tryTimes > 0 { + log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + time.Sleep(time.Second) + retryOnFailure = retryOnFailure && tryTimes < 15 + } + tryTimes++ estimatedRouteCount := 0 for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) } - routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) - addresses := make([]net.IPNet, len(conf.Interface.Addresses)) - var haveV4Address, haveV6Address bool - for i, addr := range conf.Interface.Addresses { - addresses[i] = addr.IPNet() - if addr.Bits() == 32 { - haveV4Address = true - } else if addr.Bits() == 128 { - haveV6Address = true - } - } + routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount) foundDefault4 := false foundDefault6 := false for _, peer := range conf.Peers { for _, allowedip := range peer.AllowedIPs { - if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) { - continue - } route := winipcfg.RouteData{ - Destination: allowedip.IPNet(), + Destination: allowedip.Masked(), Metric: 0, } - if allowedip.Bits() == 32 { - if allowedip.Cidr == 0 { + if allowedip.Addr().Is4() { + if allowedip.Bits() == 0 { foundDefault4 = true } - route.NextHop = net.IPv4zero - } else if allowedip.Bits() == 128 { - if allowedip.Cidr == 0 { + route.NextHop = netip.IPv4Unspecified() + } else if allowedip.Addr().Is6() { + if allowedip.Bits() == 0 { foundDefault6 = true } - route.NextHop = net.IPv6zero + route.NextHop = netip.IPv6Unspecified() } - routes = append(routes, route) + routes[route] = true } } - err := luid.SetIPAddressesForFamily(family, addresses) - if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(family, addresses) - err = luid.SetIPAddressesForFamily(family, addresses) - } - if err != nil { - return err + deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) + for route := range routes { + r := route + deduplicatedRoutes = append(deduplicatedRoutes, &r) } - deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) - sort.Slice(routes, func(i, j int) bool { - return routes[i].Metric < routes[j].Metric || - bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || - bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || - bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 - }) - for i := 0; i < len(routes); i++ { - if i > 0 && routes[i].Metric == routes[i-1].Metric && - bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && - bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && - bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { - continue + if !conf.Interface.TableOff { + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set routes: %w", err) } - deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) } - err = luid.SetRoutesForFamily(family, deduplicatedRoutes) - if err != nil { - return nil + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses) + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + } + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set ips: %w", err) } - ipif, err := luid.IPInterface(family) + var ipif *winipcfg.MibIPInterfaceRow + ipif, err = luid.IPInterface(family) if err != nil { return err } + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + ipif.DadTransmits = 0 + ipif.ManagedAddressConfigurationSupported = false + ipif.OtherStatefulConfigurationSupported = false if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) - tun.ForceMTU(int(ipif.NLMTU)) } - if family == windows.AF_INET { - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - } else if family == windows.AF_INET6 { - if foundDefault6 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - ipif.DadTransmits = 0 - ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 } err = ipif.Set() - if err != nil { - return err + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set metric and MTU: %w", err) } - err = luid.SetDNSForFamily(family, conf.Interface.DNS) - if err != nil { - return err + err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set DNS: %w", err) } - return nil } -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { - restrictAll := false - if len(conf.Peers) == 1 { - nextallowedip: +func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error { + doNotRestrict := true + if len(conf.Peers) == 1 && !conf.Interface.TableOff { for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } - } - restrictAll = true + if allowedip.Bits() == 0 && allowedip == allowedip.Masked() { + doNotRestrict = false break } } } - if restrictAll && len(conf.Interface.DNS) == 0 { - log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") - } - return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) + log.Println("Enabling firewall rules") + return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS) } |