diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-06-17 13:08:13 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-06-18 12:08:51 +0200 |
commit | b77c634b9eb0e1e732b60667a5974f8340a50253 (patch) | |
tree | 13a2d97bcb3ad8e44f288d3220fd9030d9eb968c /tunnel/addressconfig.go | |
parent | version: bump (diff) | |
download | wireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.tar.xz wireguard-windows-b77c634b9eb0e1e732b60667a5974f8340a50253.zip |
tunnel: wait for IP service to attach to wintun
This helps fix startup races without needing to poll, as well as
reconfiguring interfaces after wintun destroys and re-adds. It also
deals gracefully with IPv6 being disabled.
Diffstat (limited to 'tunnel/addressconfig.go')
-rw-r--r-- | tunnel/addressconfig.go | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go new file mode 100644 index 00000000..a1e5dc59 --- /dev/null +++ b/tunnel/addressconfig.go @@ -0,0 +1,201 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "bytes" + "log" + "net" + "sort" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/tun" + + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel/firewall" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { + if len(addresses) == 0 { + return + } + includedInAddresses := func(a net.IPNet) bool { + // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! + for _, addr := range addresses { + ip := addr.IP + if ip4 := ip.To4(); ip4 != nil { + ip = ip4 + } + mA, _ := addr.Mask.Size() + mB, _ := a.Mask.Size() + if bytes.Equal(ip, a.IP) && mA == mB { + return true + } + } + return false + } + interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) + if err != nil { + return + } + for _, iface := range interfaces { + if iface.OperStatus == winipcfg.IfOperStatusUp { + continue + } + for address := iface.FirstUnicastAddress; address != nil; address = address.Next { + ip := address.Address.IP() + ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} + if includedInAddresses(ipnet) { + log.Printf("Cleaning up stale address %s from interface '%s'", ipnet.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(ipnet) + } + } + } +} + +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { + luid := winipcfg.LUID(tun.LUID()) + + estimatedRouteCount := len(conf.Interface.Addresses) + for _, peer := range conf.Peers { + estimatedRouteCount += len(peer.AllowedIPs) + } + routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) + var firstGateway4 *net.IP + var firstGateway6 *net.IP + addresses := make([]net.IPNet, len(conf.Interface.Addresses)) + for i, addr := range conf.Interface.Addresses { + ipnet := addr.IPNet() + addresses[i] = ipnet + gateway := ipnet.IP.Mask(ipnet.Mask) + if addr.Bits() == 32 && firstGateway4 == nil { + firstGateway4 = &gateway + } else if addr.Bits() == 128 && firstGateway6 == nil { + firstGateway6 = &gateway + } + routes = append(routes, winipcfg.RouteData{ + Destination: net.IPNet{ + IP: gateway, + Mask: ipnet.Mask, + }, + NextHop: gateway, + Metric: 0, + }) + } + + foundDefault4 := false + foundDefault6 := false + for _, peer := range conf.Peers { + for _, allowedip := range peer.AllowedIPs { + if (allowedip.Bits() == 32 && firstGateway4 == nil) || (allowedip.Bits() == 128 && firstGateway6 == nil) { + continue + } + route := winipcfg.RouteData{ + Destination: allowedip.IPNet(), + Metric: 0, + } + if allowedip.Bits() == 32 { + if allowedip.Cidr == 0 { + foundDefault4 = true + } + route.NextHop = *firstGateway4 + } else if allowedip.Bits() == 128 { + if allowedip.Cidr == 0 { + foundDefault6 = true + } + route.NextHop = *firstGateway6 + } + routes = append(routes, route) + } + } + + err := luid.SetIPAddressesForFamily(family, addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(family, addresses) + err = luid.SetIPAddressesForFamily(family, addresses) + } + if err != nil { + return err + } + + deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) + sort.Slice(routes, func(i, j int) bool { + return routes[i].Metric < routes[j].Metric || + bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || + bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || + bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 + }) + for i := 0; i < len(routes); i++ { + if i > 0 && routes[i].Metric == routes[i-1].Metric && + bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && + bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && + bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { + continue + } + deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) + } + + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) + if err != nil { + return nil + } + + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + if conf.Interface.MTU > 0 { + ipif.NLMTU = uint32(conf.Interface.MTU) + tun.ForceMTU(int(ipif.NLMTU)) + } + if family == windows.AF_INET { + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + } else if family == windows.AF_INET6 { + if foundDefault6 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + ipif.DadTransmits = 0 + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + } + err = ipif.Set() + if err != nil { + return err + } + + err = luid.SetDNSForFamily(family, conf.Interface.DNS) + if err != nil { + return err + } + + return nil +} + +func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { + restrictAll := false + if len(conf.Peers) == 1 { + nextallowedip: + for _, allowedip := range conf.Peers[0].AllowedIPs { + if allowedip.Cidr == 0 { + for _, b := range allowedip.IP { + if b != 0 { + continue nextallowedip + } + } + restrictAll = true + break + } + } + } + if restrictAll && len(conf.Interface.DNS) == 0 { + log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") + } + return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) +} |