aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel/addressconfig.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/addressconfig.go')
-rw-r--r--tunnel/addressconfig.go186
1 files changed, 77 insertions, 109 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index 777c96cd..a3ce6295 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -1,42 +1,30 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "bytes"
+ "fmt"
"log"
- "net"
- "sort"
+ "net/netip"
+ "time"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
+func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) {
if len(addresses) == 0 {
return
}
- includedInAddresses := func(a net.IPNet) bool {
- // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer!
- for _, addr := range addresses {
- ip := addr.IP
- if ip4 := ip.To4(); ip4 != nil {
- ip = ip4
- }
- mA, _ := addr.Mask.Size()
- mB, _ := a.Mask.Size()
- if bytes.Equal(ip, a.IP) && mA == mB {
- return true
- }
- }
- return false
+ addrHash := make(map[netip.Addr]bool, len(addresses))
+ for i := range addresses {
+ addrHash[addresses[i].Addr()] = true
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
@@ -47,144 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
- ip := address.Address.IP()
- ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
- if includedInAddresses(ipnet) {
- log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName())
- iface.LUID.DeleteIPAddress(ipnet)
+ if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] {
+ prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
+ log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", prefix.String(), iface.FriendlyName())
+ iface.LUID.DeleteIPAddress(prefix)
}
}
}
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
- luid := winipcfg.LUID(tun.LUID())
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error {
+ retryOnFailure := services.StartedAtBoot()
+ tryTimes := 0
+startOver:
+ var err error
+ if tryTimes > 0 {
+ log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ time.Sleep(time.Second)
+ retryOnFailure = retryOnFailure && tryTimes < 15
+ }
+ tryTimes++
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
}
- routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
- addresses := make([]net.IPNet, len(conf.Interface.Addresses))
- var haveV4Address, haveV6Address bool
- for i, addr := range conf.Interface.Addresses {
- addresses[i] = addr.IPNet()
- if addr.Bits() == 32 {
- haveV4Address = true
- } else if addr.Bits() == 128 {
- haveV6Address = true
- }
- }
+ routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount)
foundDefault4 := false
foundDefault6 := false
for _, peer := range conf.Peers {
for _, allowedip := range peer.AllowedIPs {
- if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
- continue
- }
route := winipcfg.RouteData{
- Destination: allowedip.IPNet(),
+ Destination: allowedip.Masked(),
Metric: 0,
}
- if allowedip.Bits() == 32 {
- if allowedip.Cidr == 0 {
+ if allowedip.Addr().Is4() {
+ if allowedip.Bits() == 0 {
foundDefault4 = true
}
- route.NextHop = net.IPv4zero
- } else if allowedip.Bits() == 128 {
- if allowedip.Cidr == 0 {
+ route.NextHop = netip.IPv4Unspecified()
+ } else if allowedip.Addr().Is6() {
+ if allowedip.Bits() == 0 {
foundDefault6 = true
}
- route.NextHop = net.IPv6zero
+ route.NextHop = netip.IPv6Unspecified()
}
- routes = append(routes, route)
+ routes[route] = true
}
}
- err := luid.SetIPAddressesForFamily(family, addresses)
- if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
- cleanupAddressesOnDisconnectedInterfaces(family, addresses)
- err = luid.SetIPAddressesForFamily(family, addresses)
- }
- if err != nil {
- return err
+ deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
+ for route := range routes {
+ r := route
+ deduplicatedRoutes = append(deduplicatedRoutes, &r)
}
- deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
- sort.Slice(routes, func(i, j int) bool {
- return routes[i].Metric < routes[j].Metric ||
- bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
- bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
- bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1
- })
- for i := 0; i < len(routes); i++ {
- if i > 0 && routes[i].Metric == routes[i-1].Metric &&
- bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
- bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
- bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
- continue
+ if !conf.Interface.TableOff {
+ err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set routes: %w", err)
}
- deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
- err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
- if err != nil {
- return nil
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
+ cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses)
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ }
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set ips: %w", err)
}
- ipif, err := luid.IPInterface(family)
+ var ipif *winipcfg.MibIPInterfaceRow
+ ipif, err = luid.IPInterface(family)
if err != nil {
return err
}
+ ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ ipif.DadTransmits = 0
+ ipif.ManagedAddressConfigurationSupported = false
+ ipif.OtherStatefulConfigurationSupported = false
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- tun.ForceMTU(int(ipif.NLMTU))
}
- if family == windows.AF_INET {
- if foundDefault4 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- } else if family == windows.AF_INET6 {
- if foundDefault6 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- ipif.DadTransmits = 0
- ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
+ ipif.UseAutomaticMetric = false
+ ipif.Metric = 0
}
err = ipif.Set()
- if err != nil {
- return err
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set metric and MTU: %w", err)
}
- err = luid.SetDNSForFamily(family, conf.Interface.DNS)
- if err != nil {
- return err
+ err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set DNS: %w", err)
}
-
return nil
}
-func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
- restrictAll := false
- if len(conf.Peers) == 1 {
- nextallowedip:
+func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
+ doNotRestrict := true
+ if len(conf.Peers) == 1 && !conf.Interface.TableOff {
for _, allowedip := range conf.Peers[0].AllowedIPs {
- if allowedip.Cidr == 0 {
- for _, b := range allowedip.IP {
- if b != 0 {
- continue nextallowedip
- }
- }
- restrictAll = true
+ if allowedip.Bits() == 0 && allowedip == allowedip.Masked() {
+ doNotRestrict = false
break
}
}
}
- if restrictAll && len(conf.Interface.DNS) == 0 {
- log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.")
- }
- return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll)
+ log.Println("Enabling firewall rules")
+ return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
}