diff options
Diffstat (limited to 'tunnel')
33 files changed, 638 insertions, 839 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 44bfd8ae..a3ce6295 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -1,42 +1,30 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "bytes" + "fmt" "log" - "net" - "sort" + "net/netip" + "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) { if len(addresses) == 0 { return } - includedInAddresses := func(a net.IPNet) bool { - // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! - for _, addr := range addresses { - ip := addr.IP - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } - mA, _ := addr.Mask.Size() - mB, _ := a.Mask.Size() - if bytes.Equal(ip, a.IP) && mA == mB { - return true - } - } - return false + addrHash := make(map[netip.Addr]bool, len(addresses)) + for i := range addresses { + addrHash[addresses[i].Addr()] = true } interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { @@ -47,147 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add continue } for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - if includedInAddresses(ipnet) { - log.Printf("Cleaning up stale address %s from interface ā%sā", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) + if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] { + prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) + log.Printf("Cleaning up stale address %s from interface ā%sā", prefix.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(prefix) } } } } -func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { - luid := winipcfg.LUID(tun.LUID()) +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error { + retryOnFailure := services.StartedAtBoot() + tryTimes := 0 +startOver: + var err error + if tryTimes > 0 { + log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + time.Sleep(time.Second) + retryOnFailure = retryOnFailure && tryTimes < 15 + } + tryTimes++ estimatedRouteCount := 0 for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) } - routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) - addresses := make([]net.IPNet, len(conf.Interface.Addresses)) - var haveV4Address, haveV6Address bool - for i, addr := range conf.Interface.Addresses { - addresses[i] = addr.IPNet() - if addr.Bits() == 32 { - haveV4Address = true - } else if addr.Bits() == 128 { - haveV6Address = true - } - } + routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount) foundDefault4 := false foundDefault6 := false for _, peer := range conf.Peers { for _, allowedip := range peer.AllowedIPs { - allowedip.MaskSelf() - if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) { - continue - } route := winipcfg.RouteData{ - Destination: allowedip.IPNet(), + Destination: allowedip.Masked(), Metric: 0, } - if allowedip.Bits() == 32 { - if allowedip.Cidr == 0 { + if allowedip.Addr().Is4() { + if allowedip.Bits() == 0 { foundDefault4 = true } - route.NextHop = net.IPv4zero - } else if allowedip.Bits() == 128 { - if allowedip.Cidr == 0 { + route.NextHop = netip.IPv4Unspecified() + } else if allowedip.Addr().Is6() { + if allowedip.Bits() == 0 { foundDefault6 = true } - route.NextHop = net.IPv6zero + route.NextHop = netip.IPv6Unspecified() } - routes = append(routes, route) + routes[route] = true } } - err := luid.SetIPAddressesForFamily(family, addresses) - if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(family, addresses) - err = luid.SetIPAddressesForFamily(family, addresses) - } - if err != nil { - return err + deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) + for route := range routes { + r := route + deduplicatedRoutes = append(deduplicatedRoutes, &r) } - deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) - sort.Slice(routes, func(i, j int) bool { - if routes[i].Metric != routes[j].Metric { - return routes[i].Metric < routes[j].Metric - } - if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 { - return c < 0 + if !conf.Interface.TableOff { + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set routes: %w", err) } - if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 { - return c < 0 - } - if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 { - return c < 0 - } - return false - }) - for i := 0; i < len(routes); i++ { - if i > 0 && routes[i].Metric == routes[i-1].Metric && - bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && - bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && - bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { - continue - } - deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) } - err = luid.SetRoutesForFamily(family, deduplicatedRoutes) - if err != nil { - return err + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses) + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + } + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set ips: %w", err) } - ipif, err := luid.IPInterface(family) + var ipif *winipcfg.MibIPInterfaceRow + ipif, err = luid.IPInterface(family) if err != nil { return err } + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + ipif.DadTransmits = 0 + ipif.ManagedAddressConfigurationSupported = false + ipif.OtherStatefulConfigurationSupported = false if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) - tun.ForceMTU(int(ipif.NLMTU)) } - if family == windows.AF_INET { - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - } else if family == windows.AF_INET6 { - if foundDefault6 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - ipif.DadTransmits = 0 - ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 } err = ipif.Set() - if err != nil { - return err + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set metric and MTU: %w", err) } - return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) + err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set DNS: %w", err) + } + return nil } -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { +func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error { doNotRestrict := true - if len(conf.Peers) == 1 { - nextallowedip: + if len(conf.Peers) == 1 && !conf.Interface.TableOff { for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } - } + if allowedip.Bits() == 0 && allowedip == allowedip.Masked() { doNotRestrict = false break } } } log.Println("Enabling firewall rules") - return firewall.EnableFirewall(tun.LUID(), doNotRestrict, conf.Interface.DNS) + return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS) } diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go index 455deaeb..405d33a3 100644 --- a/tunnel/deterministicguid.go +++ b/tunnel/deterministicguid.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel @@ -18,8 +18,10 @@ import ( "golang.zx2c4.com/wireguard/windows/conf" ) -const deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com" -const fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com" +const ( + deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com" + fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com" +) // Escape hatch for external consumers, not us. var UseFixedGUIDInsteadOfDeterministic = false @@ -80,13 +82,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID { b2Number(len(peer.AllowedIPs)) sortedAllowedIPs := peer.AllowedIPs sort.Slice(sortedAllowedIPs, func(i, j int) bool { - if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj { + if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj { return bi < bj } - if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr { - return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr + if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() { + return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits() } - return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0 + return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0 }) for _, allowedip := range sortedAllowedIPs { b2String(allowedip.String()) diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go index 2cb2c7f2..8a4967ba 100644 --- a/tunnel/firewall/blocker.go +++ b/tunnel/firewall/blocker.go @@ -1,13 +1,13 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall import ( "errors" - "net" + "net/netip" "unsafe" "golang.org/x/sys/windows" @@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) { return bo, nil } -func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []net.IP) error { +func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error { if wfpSession != 0 { return errors.New("The firewall has already been enabled") } diff --git a/tunnel/firewall/helpers.go b/tunnel/firewall/helpers.go index e3e4eac6..46e43aa5 100644 --- a/tunnel/firewall/helpers.go +++ b/tunnel/firewall/helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/mksyscall.go b/tunnel/firewall/mksyscall.go index d5ff98aa..fc108007 100644 --- a/tunnel/firewall/mksyscall.go +++ b/tunnel/firewall/mksyscall.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go index c4488a31..41632f98 100644 --- a/tunnel/firewall/rules.go +++ b/tunnel/firewall/rules.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -8,7 +8,7 @@ package firewall import ( "encoding/binary" "errors" - "net" + "net/netip" "runtime" "unsafe" @@ -582,7 +582,6 @@ func permitDHCPIPv6(session uintptr, baseObjects *baseObjects, weight uint8) err } func permitNdp(session uintptr, baseObjects *baseObjects, weight uint8) error { - /* TODO: actually handle the hop limit somehow! The rules should vaguely be: * - icmpv6 133: must be outgoing, dst must be FF02::2/128, hop limit must be 255 * - icmpv6 134: must be incoming, src must be FE80::/10, hop limit must be 255 @@ -985,7 +984,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error { } // Block all DNS traffic except towards specified DNS servers. -func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error { +func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow, weightDeny uint8) error { if weightDeny >= weightAllow { return errors.New("The allow weight must be greater than the deny weight") } @@ -1106,8 +1105,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV4 = append(allowConditionsV4, denyConditions...) for _, ip := range except { - ip4 := ip.To4() - if ip4 == nil { + if !ip.Is4() { continue } allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{ @@ -1115,7 +1113,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight matchType: cFWP_MATCH_EQUAL, conditionValue: wtFwpConditionValue0{ _type: cFWP_UINT32, - value: uintptr(binary.BigEndian.Uint32(ip4)), + value: uintptr(binary.BigEndian.Uint32(ip.AsSlice())), }, }) } @@ -1124,11 +1122,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV6 = append(allowConditionsV6, denyConditions...) for _, ip := range except { - if ip.To4() != nil { + if !ip.Is6() { continue } - var address wtFwpByteArray16 - copy(address.byteArray16[:], ip) + address := wtFwpByteArray16{byteArray16: ip.As16()} allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{ fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS, matchType: cFWP_MATCH_EQUAL, diff --git a/tunnel/firewall/syscall_windows.go b/tunnel/firewall/syscall_windows.go index 661527d9..4d8eea42 100644 --- a/tunnel/firewall/syscall_windows.go +++ b/tunnel/firewall/syscall_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows.go b/tunnel/firewall/types_windows.go index 075daae4..54e2aad7 100644 --- a/tunnel/firewall/types_windows.go +++ b/tunnel/firewall/types_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -148,8 +148,10 @@ var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{ Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57}, } -var cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT -var cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT +var ( + cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT + cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT +) // 7bc43cbf-37ba-45f1-b74a-82ff518eeb10 var cFWPM_CONDITION_L2_FLAGS = windows.GUID{ diff --git a/tunnel/firewall/types_windows_32.go b/tunnel/firewall/types_windows_32.go index 83aced6e..29ae553a 100644 --- a/tunnel/firewall/types_windows_32.go +++ b/tunnel/firewall/types_windows_32.go @@ -1,8 +1,8 @@ -// +build 386 arm +//go:build 386 || arm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows_64.go b/tunnel/firewall/types_windows_64.go index 6e60aa5b..a476a745 100644 --- a/tunnel/firewall/types_windows_64.go +++ b/tunnel/firewall/types_windows_64.go @@ -1,8 +1,8 @@ -// +build amd64 arm64 +//go:build amd64 || arm64 /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows_test.go b/tunnel/firewall/types_windows_test.go index 683e52ea..afa1988f 100644 --- a/tunnel/firewall/types_windows_test.go +++ b/tunnel/firewall/types_windows_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -11,7 +11,6 @@ import ( ) func TestWtFwpByteBlobSize(t *testing.T) { - const actualWtFwpByteBlobSize = unsafe.Sizeof(wtFwpByteBlob{}) if actualWtFwpByteBlobSize != wtFwpByteBlob_Size { @@ -21,7 +20,6 @@ func TestWtFwpByteBlobSize(t *testing.T) { } func TestWtFwpByteBlobOffsets(t *testing.T) { - s := wtFwpByteBlob{} sp := uintptr(unsafe.Pointer(&s)) @@ -34,7 +32,6 @@ func TestWtFwpByteBlobOffsets(t *testing.T) { } func TestWtFwpmAction0Size(t *testing.T) { - const actualWtFwpmAction0Size = unsafe.Sizeof(wtFwpmAction0{}) if actualWtFwpmAction0Size != wtFwpmAction0_Size { @@ -44,7 +41,6 @@ func TestWtFwpmAction0Size(t *testing.T) { } func TestWtFwpmAction0Offsets(t *testing.T) { - s := wtFwpmAction0{} sp := uintptr(unsafe.Pointer(&s)) @@ -58,7 +54,6 @@ func TestWtFwpmAction0Offsets(t *testing.T) { } func TestWtFwpBitmapArray64Size(t *testing.T) { - const actualWtFwpBitmapArray64Size = unsafe.Sizeof(wtFwpBitmapArray64{}) if actualWtFwpBitmapArray64Size != wtFwpBitmapArray64_Size { @@ -68,7 +63,6 @@ func TestWtFwpBitmapArray64Size(t *testing.T) { } func TestWtFwpByteArray6Size(t *testing.T) { - const actualWtFwpByteArray6Size = unsafe.Sizeof(wtFwpByteArray6{}) if actualWtFwpByteArray6Size != wtFwpByteArray6_Size { @@ -78,7 +72,6 @@ func TestWtFwpByteArray6Size(t *testing.T) { } func TestWtFwpByteArray16Size(t *testing.T) { - const actualWtFwpByteArray16Size = unsafe.Sizeof(wtFwpByteArray16{}) if actualWtFwpByteArray16Size != wtFwpByteArray16_Size { @@ -88,7 +81,6 @@ func TestWtFwpByteArray16Size(t *testing.T) { } func TestWtFwpConditionValue0Size(t *testing.T) { - const actualWtFwpConditionValue0Size = unsafe.Sizeof(wtFwpConditionValue0{}) if actualWtFwpConditionValue0Size != wtFwpConditionValue0_Size { @@ -98,7 +90,6 @@ func TestWtFwpConditionValue0Size(t *testing.T) { } func TestWtFwpConditionValue0Offsets(t *testing.T) { - s := wtFwpConditionValue0{} sp := uintptr(unsafe.Pointer(&s)) @@ -111,7 +102,6 @@ func TestWtFwpConditionValue0Offsets(t *testing.T) { } func TestWtFwpV4AddrAndMaskSize(t *testing.T) { - const actualWtFwpV4AddrAndMaskSize = unsafe.Sizeof(wtFwpV4AddrAndMask{}) if actualWtFwpV4AddrAndMaskSize != wtFwpV4AddrAndMask_Size { @@ -121,7 +111,6 @@ func TestWtFwpV4AddrAndMaskSize(t *testing.T) { } func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) { - s := wtFwpV4AddrAndMask{} sp := uintptr(unsafe.Pointer(&s)) @@ -135,7 +124,6 @@ func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) { } func TestWtFwpV6AddrAndMaskSize(t *testing.T) { - const actualWtFwpV6AddrAndMaskSize = unsafe.Sizeof(wtFwpV6AddrAndMask{}) if actualWtFwpV6AddrAndMaskSize != wtFwpV6AddrAndMask_Size { @@ -145,7 +133,6 @@ func TestWtFwpV6AddrAndMaskSize(t *testing.T) { } func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) { - s := wtFwpV6AddrAndMask{} sp := uintptr(unsafe.Pointer(&s)) @@ -159,7 +146,6 @@ func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) { } func TestWtFwpValue0Size(t *testing.T) { - const actualWtFwpValue0Size = unsafe.Sizeof(wtFwpValue0{}) if actualWtFwpValue0Size != wtFwpValue0_Size { @@ -168,7 +154,6 @@ func TestWtFwpValue0Size(t *testing.T) { } func TestWtFwpValue0Offsets(t *testing.T) { - s := wtFwpValue0{} sp := uintptr(unsafe.Pointer(&s)) @@ -181,7 +166,6 @@ func TestWtFwpValue0Offsets(t *testing.T) { } func TestWtFwpmDisplayData0Size(t *testing.T) { - const actualWtFwpmDisplayData0Size = unsafe.Sizeof(wtFwpmDisplayData0{}) if actualWtFwpmDisplayData0Size != wtFwpmDisplayData0_Size { @@ -191,7 +175,6 @@ func TestWtFwpmDisplayData0Size(t *testing.T) { } func TestWtFwpmDisplayData0Offsets(t *testing.T) { - s := wtFwpmDisplayData0{} sp := uintptr(unsafe.Pointer(&s)) @@ -205,7 +188,6 @@ func TestWtFwpmDisplayData0Offsets(t *testing.T) { } func TestWtFwpmFilterCondition0Size(t *testing.T) { - const actualWtFwpmFilterCondition0Size = unsafe.Sizeof(wtFwpmFilterCondition0{}) if actualWtFwpmFilterCondition0Size != wtFwpmFilterCondition0_Size { @@ -215,7 +197,6 @@ func TestWtFwpmFilterCondition0Size(t *testing.T) { } func TestWtFwpmFilterCondition0Offsets(t *testing.T) { - s := wtFwpmFilterCondition0{} sp := uintptr(unsafe.Pointer(&s)) @@ -237,7 +218,6 @@ func TestWtFwpmFilterCondition0Offsets(t *testing.T) { } func TestWtFwpmFilter0Size(t *testing.T) { - const actualWtFwpmFilter0Size = unsafe.Sizeof(wtFwpmFilter0{}) if actualWtFwpmFilter0Size != wtFwpmFilter0_Size { @@ -247,7 +227,6 @@ func TestWtFwpmFilter0Size(t *testing.T) { } func TestWtFwpmFilter0Offsets(t *testing.T) { - s := wtFwpmFilter0{} sp := uintptr(unsafe.Pointer(&s)) @@ -364,7 +343,6 @@ func TestWtFwpmFilter0Offsets(t *testing.T) { } func TestWtFwpProvider0Size(t *testing.T) { - const actualWtFwpProvider0Size = unsafe.Sizeof(wtFwpProvider0{}) if actualWtFwpProvider0Size != wtFwpProvider0_Size { @@ -374,7 +352,6 @@ func TestWtFwpProvider0Size(t *testing.T) { } func TestWtFwpProvider0Offsets(t *testing.T) { - s := wtFwpProvider0{} sp := uintptr(unsafe.Pointer(&s)) @@ -412,7 +389,6 @@ func TestWtFwpProvider0Offsets(t *testing.T) { } func TestWtFwpmSession0Size(t *testing.T) { - const actualWtFwpmSession0Size = unsafe.Sizeof(wtFwpmSession0{}) if actualWtFwpmSession0Size != wtFwpmSession0_Size { @@ -422,7 +398,6 @@ func TestWtFwpmSession0Size(t *testing.T) { } func TestWtFwpmSession0Offsets(t *testing.T) { - s := wtFwpmSession0{} sp := uintptr(unsafe.Pointer(&s)) @@ -482,7 +457,6 @@ func TestWtFwpmSession0Offsets(t *testing.T) { } func TestWtFwpmSublayer0Size(t *testing.T) { - const actualWtFwpmSublayer0Size = unsafe.Sizeof(wtFwpmSublayer0{}) if actualWtFwpmSublayer0Size != wtFwpmSublayer0_Size { @@ -492,7 +466,6 @@ func TestWtFwpmSublayer0Size(t *testing.T) { } func TestWtFwpmSublayer0Offsets(t *testing.T) { - s := wtFwpmSublayer0{} sp := uintptr(unsafe.Pointer(&s)) diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index e12e5929..a831d06e 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -1,20 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( + "errors" + "fmt" "log" "sync" + "time" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -24,63 +24,30 @@ type interfaceWatcherError struct { serviceError services.Error err error } + type interfaceWatcherEvent struct { luid winipcfg.LUID family winipcfg.AddressFamily } + type interfaceWatcher struct { - errors chan interfaceWatcherError + errors chan interfaceWatcherError + started chan winipcfg.AddressFamily - binder conn.BindSocketToInterface - conf *conf.Config - tun *tun.NativeTun + conf *conf.Config + adapter *driver.Adapter + luid winipcfg.LUID setupMutex sync.Mutex interfaceChangeCallback winipcfg.ChangeCallback changeCallbacks4 []winipcfg.ChangeCallback changeCallbacks6 []winipcfg.ChangeCallback storedEvents []interfaceWatcherEvent -} - -func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool { - var ( - foundV401 bool - foundV41281 bool - foundV600001 bool - foundV680001 bool - foundV400 bool - foundV600 bool - v40 = [4]byte{} - v60 = [16]byte{} - v48 = [4]byte{0x80} - v68 = [16]byte{0x80} - ) - for _, peer := range peers { - for _, allowedip := range peer.AllowedIPs { - if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { - foundV680001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV401 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { - foundV41281 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV400 = true - } - } - } - if family == windows.AF_INET { - return foundV400 || (foundV401 && foundV41281) - } else if family == windows.AF_INET6 { - return foundV600 || (foundV600001 && foundV680001) - } - return false + watchdog *time.Timer } func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { + iw.watchdog.Stop() var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string if family == windows.AF_INET { @@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return + if iw.conf.Interface.MTU == 0 { + log.Printf("Monitoring MTU of default %s routes", ipversion) + *changeCallbacks, err = monitorMTU(family, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err} + return + } } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.tun) + err = configureInterface(family, iw.conf, iw.luid) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return } + evaluateDynamicPitfalls(family, iw.conf, iw.luid) + + iw.started <- family } func watchInterface() (*interfaceWatcher, error) { iw := &interfaceWatcher{ - errors: make(chan interfaceWatcherError, 2), + errors: make(chan interfaceWatcherError, 2), + started: make(chan winipcfg.AddressFamily, 4), } + iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() { + iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")} + }) + iw.watchdog.Stop() var err error iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { iw.setupMutex.Lock() @@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) { if notificationType != winipcfg.MibAddInstance { return } - if iw.tun == nil { + if iw.luid == 0 { iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) return } - if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + if iface.InterfaceLUID != iw.luid { return } iw.setup(iface.Family) + + if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown { + log.Println("Reinitializing adapter configuration") + err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration()) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err)) + } + err = iw.adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err)) + } + } }) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to register interface change callback: %w", err) } return iw, nil } -func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) { +func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() + iw.watchdog.Reset(time.Minute) - iw.binder, iw.conf, iw.tun = binder, conf, tun + iw.adapter, iw.conf, iw.luid = adapter, conf, luid for _, event := range iw.storedEvents { - if event.luid == winipcfg.LUID(iw.tun.LUID()) { + if event.luid == luid { iw.setup(event.family) } } @@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *c func (iw *interfaceWatcher) Destroy() { iw.setupMutex.Lock() + iw.watchdog.Stop() changeCallbacks4 := iw.changeCallbacks4 changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback - tun := iw.tun + luid := iw.luid iw.setupMutex.Unlock() if interfaceChangeCallback != nil { @@ -186,10 +177,9 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() - if tun != nil && iw.tun == tun { + if luid != 0 && iw.luid == luid { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) luid.FlushDNS(windows.AF_INET) diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go deleted file mode 100644 index 3a676e4b..00000000 --- a/tunnel/ipcpermissions.go +++ /dev/null @@ -1,63 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package tunnel - -import ( - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc" - - "golang.zx2c4.com/wireguard/windows/conf" -) - -func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error { - if conf.PathIsEncrypted(filename) { - return nil - } - - fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - return err - } - fileOwner, _, err := fileSd.Owner() - if err != nil { - return err - } - if fileOwner.IsWellKnown(windows.WinLocalSystemSid) { - return nil - } - additionalEntries := []windows.EXPLICIT_ACCESS{{ - AccessPermissions: windows.GENERIC_ALL, - AccessMode: windows.GRANT_ACCESS, - Trustee: windows.TRUSTEE{ - TrusteeForm: windows.TRUSTEE_IS_SID, - TrusteeType: windows.TRUSTEE_IS_USER, - TrusteeValue: windows.TrusteeValueFromSID(fileOwner), - }, - }} - - sd, err := ipc.UAPISecurityDescriptor.ToAbsolute() - if err != nil { - return err - } - dacl, defaulted, _ := sd.DACL() - - newDacl, err := windows.ACLFromEntries(additionalEntries, dacl) - if err != nil { - return err - } - err = sd.SetDACL(newDacl, true, defaulted) - if err != nil { - return err - } - sd, err = sd.ToSelfRelative() - if err != nil { - return err - } - ipc.UAPISecurityDescriptor = sd - - return nil -} diff --git a/tunnel/defaultroutemonitor.go b/tunnel/mtumonitor.go index aa0db675..c07823a2 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/mtumonitor.go @@ -1,30 +1,23 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "log" - "sync" - "time" - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error { +func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error { r, err := winipcfg.GetIPForwardTable2(family) if err != nil { return err } lowestMetric := ^uint32(0) - index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. - luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. + index := uint32(0) + luid := winipcfg.LUID(0) for i := range r { if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID { continue @@ -50,36 +43,24 @@ func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInte } *lastLUID = luid *lastIndex = index - blackhole := blackholeWhenLoop && index == 0 - if family == windows.AF_INET { - log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole) - return binder.BindSocketToInterface4(index, blackhole) - } else if family == windows.AF_INET6 { - log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole) - return binder.BindSocketToInterface6(index, blackhole) - } return nil } -func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) { +func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) { var minMTU uint32 if family == windows.AF_INET { minMTU = 576 } else if family == windows.AF_INET6 { minMTU = 1280 } - ourLUID := winipcfg.LUID(tun.LUID()) lastLUID := winipcfg.LUID(0) lastIndex := ^uint32(0) lastMTU := uint32(0) doIt := func() error { - err := bindSocketRoute(family, binder, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop) + err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex) if err != nil { return err } - if !autoMTU { - return nil - } mtu := uint32(0) if lastLUID != 0 { iface, err := lastLUID.Interface() @@ -103,7 +84,6 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT if err != nil { return err } - tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad. lastMTU = mtu } return nil @@ -112,32 +92,9 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT if err != nil { return nil, err } - - firstBurst := time.Time{} - burstMutex := sync.Mutex{} - burstTimer := time.AfterFunc(time.Hour*200, func() { - burstMutex.Lock() - firstBurst = time.Time{} - doIt() - burstMutex.Unlock() - }) - burstTimer.Stop() - bump := func() { - burstMutex.Lock() - burstTimer.Reset(time.Millisecond * 150) - if firstBurst.IsZero() { - firstBurst = time.Now() - } else if time.Since(firstBurst) > time.Second*2 { - firstBurst = time.Time{} - burstTimer.Stop() - doIt() - } - burstMutex.Unlock() - } - cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { if route != nil && route.DestinationPrefix.PrefixLength == 0 { - bump() + doIt() } }) if err != nil { @@ -145,7 +102,7 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT } cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { if notificationType == winipcfg.MibParameterNotification { - bump() + doIt() } }) if err != nil { diff --git a/tunnel/pitfalls.go b/tunnel/pitfalls.go new file mode 100644 index 00000000..fdef6eb2 --- /dev/null +++ b/tunnel/pitfalls.go @@ -0,0 +1,177 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "log" + "net/netip" + "strings" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func evaluateStaticPitfalls() { + go func() { + pitfallDnsCacheDisabled() + pitfallVirtioNetworkDriver() + }() +} + +func evaluateDynamicPitfalls(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) { + go func() { + pitfallWeakHostSend(family, conf, luid) + }() +} + +func pitfallDnsCacheDisabled() { + scm, err := mgr.Connect() + if err != nil { + return + } + defer scm.Disconnect() + svc := mgr.Service{Name: "dnscache"} + svc.Handle, err = windows.OpenService(scm.Handle, windows.StringToUTF16Ptr(svc.Name), windows.SERVICE_QUERY_CONFIG) + if err != nil { + return + } + defer svc.Close() + cfg, err := svc.Config() + if err != nil { + return + } + if cfg.StartType != mgr.StartDisabled { + return + } + + log.Printf("Warning: the %q (dnscache) service is disabled; please re-enable it", cfg.DisplayName) +} + +func pitfallVirtioNetworkDriver() { + var modules []windows.RTL_PROCESS_MODULE_INFORMATION + for bufferSize := uint32(128 * 1024); ; { + moduleBuffer := make([]byte, bufferSize) + err := windows.NtQuerySystemInformation(windows.SystemModuleInformation, unsafe.Pointer(&moduleBuffer[0]), bufferSize, &bufferSize) + switch err { + case windows.STATUS_INFO_LENGTH_MISMATCH: + continue + case nil: + break + default: + return + } + mods := (*windows.RTL_PROCESS_MODULES)(unsafe.Pointer(&moduleBuffer[0])) + modules = unsafe.Slice(&mods.Modules[0], mods.NumberOfModules) + break + } + for i := range modules { + if !strings.EqualFold(windows.ByteSliceToString(modules[i].FullPathName[modules[i].OffsetToFileName:]), "netkvm.sys") { + continue + } + driverPath := `\\?\GLOBALROOT` + windows.ByteSliceToString(modules[i].FullPathName[:]) + var zero windows.Handle + infoSize, err := windows.GetFileVersionInfoSize(driverPath, &zero) + if err != nil { + return + } + versionInfo := make([]byte, infoSize) + err = windows.GetFileVersionInfo(driverPath, 0, infoSize, unsafe.Pointer(&versionInfo[0])) + if err != nil { + return + } + var fixedInfo *windows.VS_FIXEDFILEINFO + fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo)) + err = windows.VerQueryValue(unsafe.Pointer(&versionInfo[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen) + if err != nil { + return + } + const minimumPlausibleVersion = 40 << 48 + const minimumGoodVersion = (100 << 48) | (85 << 32) | (104 << 16) | (20800 << 0) + version := (uint64(fixedInfo.FileVersionMS) << 32) | uint64(fixedInfo.FileVersionLS) + if version >= minimumGoodVersion || version < minimumPlausibleVersion { + return + } + log.Println("Warning: the VirtIO network driver (NetKVM) is out of date and may cause known problems; please update to v100.85.104.20800 or later") + return + } +} + +func pitfallWeakHostSend(family winipcfg.AddressFamily, conf *conf.Config, ourLUID winipcfg.LUID) { + routingTable, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return + } + type endpointRoute struct { + addr netip.Addr + name string + lowestMetric uint32 + highestCIDR uint8 + weakHostSend bool + finalIsOurs bool + } + endpoints := make([]endpointRoute, 0, len(conf.Peers)) + for _, peer := range conf.Peers { + addr, err := netip.ParseAddr(peer.Endpoint.Host) + if err != nil || (addr.Is4() && family != windows.AF_INET) || (addr.Is6() && family != windows.AF_INET6) { + continue + } + endpoints = append(endpoints, endpointRoute{addr: addr, lowestMetric: ^uint32(0)}) + } + for i := range routingTable { + var ( + ifrow *winipcfg.MibIfRow2 + ifacerow *winipcfg.MibIPInterfaceRow + metric uint32 + ) + for j := range endpoints { + r, e := &routingTable[i], &endpoints[j] + if r.DestinationPrefix.PrefixLength < e.highestCIDR { + continue + } + if !r.DestinationPrefix.Prefix().Contains(e.addr) { + continue + } + if ifrow == nil { + ifrow, err = r.InterfaceLUID.Interface() + if err != nil { + continue + } + } + if ifrow.OperStatus != winipcfg.IfOperStatusUp { + continue + } + if ifacerow == nil { + ifacerow, err = r.InterfaceLUID.IPInterface(family) + if err != nil { + continue + } + metric = r.Metric + ifacerow.Metric + } + if r.DestinationPrefix.PrefixLength == e.highestCIDR && metric > e.lowestMetric { + continue + } + e.lowestMetric = metric + e.highestCIDR = r.DestinationPrefix.PrefixLength + e.finalIsOurs = r.InterfaceLUID == ourLUID + if !e.finalIsOurs { + e.name = ifrow.Alias() + e.weakHostSend = ifacerow.ForwardingEnabled || ifacerow.WeakHostSend + } + } + } + problematicInterfaces := make(map[string]bool, len(endpoints)) + for _, e := range endpoints { + if e.weakHostSend && e.finalIsOurs { + problematicInterfaces[e.name] = true + } + } + for iface := range problematicInterfaces { + log.Printf("Warning: the %q interface has Forwarding/WeakHostSend enabled, which will cause routing loops", iface) + } +} diff --git a/tunnel/scriptrunner.go b/tunnel/scriptrunner.go index 450d8e21..eb97d98d 100644 --- a/tunnel/scriptrunner.go +++ b/tunnel/scriptrunner.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel diff --git a/tunnel/service.go b/tunnel/service.go index 63cd243f..a56ed1f3 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "log" - "net" "os" "runtime" "time" @@ -17,16 +16,12 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/version" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) type tunnelService struct { @@ -34,12 +29,12 @@ type tunnelService struct { } func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { - changes <- svc.Status{State: svc.StartPending} + serviceState := svc.StartPending + changes <- svc.Status{State: serviceState} - var dev *device.Device - var uapi net.Listener var watcher *interfaceWatcher - var nativeTun *tun.NativeTun + var adapter *driver.Adapter + var luid winipcfg.LUID var config *conf.Config var err error serviceError := services.ErrorSuccess @@ -50,7 +45,8 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, if logErr != nil { log.Println(logErr) } - changes <- svc.Status{State: svc.StopPending} + serviceState = svc.StopPending + changes <- svc.Status{State: serviceState} stopIt := make(chan bool, 1) go func() { @@ -84,60 +80,63 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } }() - if logErr == nil && dev != nil && config != nil { + if logErr == nil && adapter != nil && config != nil { logErr = runScriptCommand(config.Interface.PreDown, config.Name) } if watcher != nil { watcher.Destroy() } - if uapi != nil { - uapi.Close() - } - if dev != nil { - dev.Close() + if adapter != nil { + adapter.Close() } - if logErr == nil && dev != nil && config != nil { + if logErr == nil && adapter != nil && config != nil { _ = runScriptCommand(config.Interface.PostDown, config.Name) } stopIt <- true log.Println("Shutting down") }() - err = ringlogger.InitGlobalLogger("TUN") + var logFile string + logFile, err = conf.LogFile(true) if err != nil { serviceError = services.ErrorRingloggerOpen return } - - config, err = conf.LoadFromPath(service.Path) + err = ringlogger.InitGlobalLogger(logFile, "TUN") if err != nil { - serviceError = services.ErrorLoadConfiguration + serviceError = services.ErrorRingloggerOpen return } - config.DeduplicateNetworkEntries() - err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path) + + config, err = conf.LoadFromPath(service.Path) if err != nil { serviceError = services.ErrorLoadConfiguration return } + config.DeduplicateNetworkEntries() log.SetPrefix(fmt.Sprintf("[%s] ", config.Name)) - log.Println("Starting", version.UserAgent()) + services.PrintStarting() - if m, err := mgr.Connect(); err == nil { - if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the Wintun installation will block forever, because - * installing a Wintun device starts a service too. Apparently at boot time, Windows - * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce - * that we're running before starting additional services. - */ - log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) - changes <- svc.Status{State: svc.Running} + if services.StartedAtBoot() { + if m, err := mgr.Connect(); err == nil { + if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. + */ + log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) + serviceState = svc.Running + changes <- svc.Status{State: serviceState} + } + m.Disconnect() } - m.Disconnect() } + evaluateStaticPitfalls() + log.Println("Watching network interfaces") watcher, err = watchInterface() if err != nil { @@ -146,34 +145,40 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Resolving DNS names") - uapiConf, err := config.ToUAPI() + err = config.ResolveEndpoints() if err != nil { serviceError = services.ErrorDNSLookup return } - log.Println("Creating Wintun interface") - var wintun tun.Device - for i := 0; i < 5; i++ { + log.Println("Creating network adapter") + for i := 0; i < 15; i++ { if i > 0 { time.Sleep(time.Second) - log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) } - wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0) - if err == nil || windows.DurationSinceBoot() > time.Minute*4 { + adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) + if err == nil || !services.StartedAtBoot() { break } } if err != nil { - serviceError = services.ErrorCreateWintun + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter return } - nativeTun = wintun.(*tun.NativeTun) - wintunVersion, err := nativeTun.RunningVersion() + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() if err != nil { - log.Printf("Warning: unable to determine Wintun version: %v", err) + log.Printf("Warning: unable to determine driver version: %v", err) } else { - log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff) + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } err = runScriptCommand(config.Interface.PreUp, config.Name) @@ -182,7 +187,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - err = enableFirewall(config, nativeTun) + err = enableFirewall(config, luid) if err != nil { serviceError = services.ErrorFirewall return @@ -195,37 +200,18 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - log.Println("Creating interface instance") - bind := conn.NewDefaultBind() - dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf}) - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(config.Name) + err = adapter.SetConfiguration(config.ToDriverConfiguration()) if err != nil { - serviceError = services.ErrorUAPIListen + serviceError = services.ErrorDeviceSetConfig return } - err = dev.IpcSet(uapiConf) + err = adapter.SetAdapterState(driver.AdapterStateUp) if err != nil { - serviceError = services.ErrorDeviceSetConfig + serviceError = services.ErrorDeviceBringUp return } - - log.Println("Bringing peers up") - dev.Up() - - watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun) - - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue - } - go dev.IpcHandle(conn) - } - }() + watcher.Configure(adapter, config, luid) err = runScriptCommand(config.Interface.PostUp, config.Name) if err != nil { @@ -233,9 +219,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} - log.Println("Startup complete") + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} + var started bool for { select { case c := <-r: @@ -247,8 +233,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-dev.Wait(): - return + case <-watcher.started: + if !started { + serviceState = svc.Running + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} + log.Println("Startup complete") + started = true + } case e := <-watcher.errors: serviceError, err = e.serviceError, e.err return @@ -261,7 +252,7 @@ func Run(confPath string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } diff --git a/tunnel/winipcfg/interface_change_handler.go b/tunnel/winipcfg/interface_change_handler.go index 4d229e78..af29801a 100644 --- a/tunnel/winipcfg/interface_change_handler.go +++ b/tunnel/winipcfg/interface_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index 02ba65b4..0c898b89 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -1,13 +1,13 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( "errors" - "net" + "net/netip" "strings" "golang.org/x/sys/windows" @@ -76,10 +76,10 @@ func LUIDFromIndex(index uint32) (LUID, error) { // IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry) -func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { +func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) { row := &MibUnicastIPAddressRow{InterfaceLUID: luid} - err := row.Address.SetIP(ip, 0) + err := row.Address.SetAddr(addr) if err != nil { return nil, err } @@ -94,22 +94,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { // AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddress(address net.IPNet) error { +func (luid LUID) AddIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid - err := row.Address.SetIP(address.IP, 0) + row.DadState = DadStatePreferred + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Create() } // AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { +func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error { for i := range addresses { err := luid.AddIPAddress(addresses[i]) if err != nil { @@ -120,7 +122,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { } // SetIPAddresses method sets new unicast IP addresses to the interface. -func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { +func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error { err := luid.FlushIPAddresses(windows.AF_UNSPEC) if err != nil { return err @@ -129,16 +131,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { } // SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. -func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error { +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error { err := luid.FlushIPAddresses(family) if err != nil { return err } for i := range addresses { - asV4 := addresses[i].IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !addresses[i].Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddIPAddress(addresses[i]) @@ -151,17 +152,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I // DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). -func (luid LUID) DeleteIPAddress(address net.IPNet) error { +func (luid LUID) DeleteIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid - err := row.Address.SetIP(address.IP, 0) + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } // Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry(). - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Delete() } @@ -185,15 +185,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error { // Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2). // NOTE: If the corresponding route isn't found, the method will return error. -func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) { +func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return nil, err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return nil, err } @@ -207,15 +209,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2 // AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature. // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2) -func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error { +func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -250,10 +252,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat return err } for _, rd := range routesData { - asV4 := rd.Destination.IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !rd.Destination.Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) @@ -266,15 +267,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat // DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). -func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error { +func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -307,17 +308,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error { } // DNS method returns all DNS server addresses associated with the adapter. -func (luid LUID) DNS() ([]net.IP, error) { +func (luid LUID) DNS() ([]netip.Addr, error) { addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) if err != nil { return nil, err } - r := make([]net.IP, 0, len(addresses)) + r := make([]netip.Addr, 0, len(addresses)) for _, addr := range addresses { if addr.LUID == luid { for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next { if ip := dns.Address.IP(); ip != nil { - r = append(r, ip) + if a, ok := netip.AddrFromSlice(ip); ok { + r = append(r, a) + } } else { return nil, windows.ERROR_INVALID_PARAMETER } @@ -328,17 +331,15 @@ func (luid LUID) DNS() ([]net.IP, error) { } // SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family. -func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error { +func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error { if family != windows.AF_INET && family != windows.AF_INET6 { return windows.ERROR_PROTOCOL_UNREACHABLE } var filteredServers []string for _, server := range servers { - if v4 := server.To4(); v4 != nil && family == windows.AF_INET { - filteredServers = append(filteredServers, v4.String()) - } else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - filteredServers = append(filteredServers, v6.String()) + if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) { + filteredServers = append(filteredServers, server.String()) } } servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ",")) @@ -353,17 +354,17 @@ func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string if err != nil { return err } - var maybeV6 uint64 + dnsInterfaceSettings := &DnsInterfaceSettings{ + Version: DnsInterfaceSettingsVersion1, + Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList, + NameServer: servers16, + SearchList: domains16, + } if family == windows.AF_INET6 { - maybeV6 = disFlagsIPv6 + dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6 } // For >= Windows 10 1809 - err = setInterfaceDnsSettings(*guid, &dnsInterfaceSettings{ - Version: disVersion1, - Flags: disFlagsNameServer | disFlagsSearchList | maybeV6, - NameServer: servers16, - SearchList: domains16, - }) + err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings) if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { return err } diff --git a/tunnel/winipcfg/mksyscall.go b/tunnel/winipcfg/mksyscall.go index e9e06676..d62d38df 100644 --- a/tunnel/winipcfg/mksyscall.go +++ b/tunnel/winipcfg/mksyscall.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go index 1f3d12d0..4f8e5b13 100644 --- a/tunnel/winipcfg/netsh.go +++ b/tunnel/winipcfg/netsh.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg @@ -10,7 +10,7 @@ import ( "errors" "fmt" "io" - "net" + "net/netip" "os/exec" "path/filepath" "strings" @@ -57,7 +57,7 @@ const ( netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" ) -func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error { +func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error { var templateFlush string if family == windows.AF_INET { templateFlush = netshCmdTemplateFlush4 @@ -72,10 +72,10 @@ func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) e } cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) + if dnses[i].Is4() && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String())) + } else if dnses[i].Is6() && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String())) } } return runNetsh(cmds) diff --git a/tunnel/winipcfg/route_change_handler.go b/tunnel/winipcfg/route_change_handler.go index 75dcf82c..4b78331e 100644 --- a/tunnel/winipcfg/route_change_handler.go +++ b/tunnel/winipcfg/route_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 02f7f788..8e8f4a59 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -1,12 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( - "net" + "encoding/binary" + "fmt" + "net/netip" + "strconv" "unsafe" "golang.org/x/sys/windows" @@ -582,11 +585,15 @@ const ( // RouteData structure describes a route to add type RouteData struct { - Destination net.IPNet - NextHop net.IP + Destination netip.Prefix + NextHop netip.Addr Metric uint32 } +func (routeData *RouteData) String() string { + return fmt.Sprintf("%+v", *routeData) +} + // IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter. // https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix type IPAdapterDNSSuffix struct { @@ -679,8 +686,7 @@ func (row *MibIPInterfaceRow) Set() error { // get method returns all table rows as a Go slice. func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) { - unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries)) - return + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -717,8 +723,7 @@ func (row *MibIfRow2) get() (ret error) { // get method returns all table rows as a Go slice. func (tab *mibIfTable2) get() (s []MibIfRow2) { - unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries)) - return + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -734,45 +739,82 @@ type RawSockaddrInet struct { data [26]byte } -// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port. +func ntohs(i uint16) uint16 { + return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:]) +} + +func htons(i uint16) uint16 { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, i) + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port. // All other members of the structure are set to zero. -func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { - if v4 := ip.To4(); v4 != nil { +func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error { + if addrPort.Addr().Is4() { addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) addr4.Family = windows.AF_INET - copy(addr4.Addr[:], v4) - addr4.Port = port + addr4.Addr = addrPort.Addr().As4() + addr4.Port = htons(addrPort.Port()) for i := 0; i < 8; i++ { addr4.Zero[i] = 0 } return nil - } - - if v6 := ip.To16(); v6 != nil { + } else if addrPort.Addr().Is6() { addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) addr6.Family = windows.AF_INET6 - addr6.Port = port + addr6.Addr = addrPort.Addr().As16() + addr6.Port = htons(addrPort.Port()) addr6.Flowinfo = 0 - copy(addr6.Addr[:], v6) - addr6.Scope_id = 0 + scopeId := uint32(0) + if z := addrPort.Addr().Zone(); z != "" { + if s, err := strconv.ParseUint(z, 10, 32); err == nil { + scopeId = uint32(s) + } + } + addr6.Scope_id = scopeId return nil } - return windows.ERROR_INVALID_PARAMETER } -// IP method returns IPv4 or IPv6 address. -// If the address is neither IPv4 not IPv6 nil is returned. -func (addr *RawSockaddrInet) IP() net.IP { +// SetAddr method sets family and address to the given IPv4 or IPv6 address. +// All other members of the structure are set to zero. +func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error { + return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0)) +} + +// AddrPort returns the IP address and port. +func (addr *RawSockaddrInet) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(addr.Addr(), addr.Port()) +} + +// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither. +func (addr *RawSockaddrInet) Addr() netip.Addr { switch addr.Family { case windows.AF_INET: - return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:] - + return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr) case windows.AF_INET6: - return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:] + raw := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) + a := netip.AddrFrom16(raw.Addr) + if raw.Scope_id != 0 { + a = a.WithZone(strconv.FormatUint(uint64(raw.Scope_id), 10)) + } + return a } + return netip.Addr{} +} - return nil +// Port returns the port if the address if IPv4 or IPv6, or 0 if neither. +func (addr *RawSockaddrInet) Port() uint16 { + switch addr.Family { + case windows.AF_INET: + return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port) + case windows.AF_INET6: + return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port) + } + return 0 } // Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer. @@ -807,8 +849,7 @@ func (row *MibUnicastIPAddressRow) Delete() error { // get method returns all table rows as a Go slice. func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) { - unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries)) - return + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -837,8 +878,7 @@ func (row *MibAnycastIPAddressRow) Delete() error { // get method returns all table rows as a Go slice. func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) { - unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries)) - return + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -850,32 +890,30 @@ func (tab *mibAnycastIPAddressTable) free() { // IPAddressPrefix structure stores an IP address prefix. // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix type IPAddressPrefix struct { - Prefix RawSockaddrInet + RawPrefix RawSockaddrInet PrefixLength uint8 _ [2]byte } -// SetIPNet method sets IP address prefix using net.IPNet. -func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error { - err := prefix.Prefix.SetIP(net.IP, 0) +// SetPrefix method sets IP address prefix using netip.Prefix. +func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error { + err := prefix.RawPrefix.SetAddr(netPrefix.Addr()) if err != nil { return err } - ones, _ := net.Mask.Size() - prefix.PrefixLength = uint8(ones) + prefix.PrefixLength = uint8(netPrefix.Bits()) return nil } -// IPNet method returns IP address prefix as net.IPNet. -// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately. -func (prefix *IPAddressPrefix) IPNet() net.IPNet { - switch prefix.Prefix.Family { +// Prefix returns IP address prefix as netip.Prefix. +func (prefix *IPAddressPrefix) Prefix() netip.Prefix { + switch prefix.RawPrefix.Family { case windows.AF_INET: - return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)} + return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) case windows.AF_INET6: - return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)} + return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) } - return net.IPNet{} + return netip.Prefix{} } // MibIPforwardRow2 structure stores information about an IP route entry. @@ -930,8 +968,7 @@ func (row *MibIPforwardRow2) Delete() error { // get method returns all table rows as a Go slice. func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) { - unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries)) - return + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -941,11 +978,11 @@ func (tab *mibIPforwardTable2) free() { } // -// Undocumented DNS API +// DNS API // -// dnsInterfaceSettings is mean to be used with setInterfaceDnsSettings -type dnsInterfaceSettings struct { +// DnsInterfaceSettings is meant to be used with SetInterfaceDnsSettings +type DnsInterfaceSettings struct { Version uint32 _ [4]byte Flags uint64 @@ -960,37 +997,22 @@ type dnsInterfaceSettings struct { } const ( - disVersion1 = 1 - disVersion2 = 2 - - disFlagsIPv6 = 0x1 - disFlagsNameServer = 0x2 - disFlagsSearchList = 0x4 - disFlagsRegistrationEnabled = 0x8 - disFlagsRegisterAdapterName = 0x10 - disFlagsDomain = 0x20 - disFlagsHostname = 0x40 // ?? - disFlagsEnableLLMNR = 0x80 - disFlagsQueryAdapterName = 0x100 - disFlagsProfileNameServer = 0x200 - disFlagsVersion2 = 0x400 // ?? - v2 only - disFlagsMoreFlags = 0x800 // ?? - v2 only + DnsInterfaceSettingsVersion1 = 1 // for DnsInterfaceSettings + DnsInterfaceSettingsVersion2 = 2 // for DnsInterfaceSettingsEx + DnsInterfaceSettingsVersion3 = 3 // for DnsInterfaceSettings3 + + DnsInterfaceSettingsFlagIPv6 = 0x0001 + DnsInterfaceSettingsFlagNameserver = 0x0002 + DnsInterfaceSettingsFlagSearchList = 0x0004 + DnsInterfaceSettingsFlagRegistrationEnabled = 0x0008 + DnsInterfaceSettingsFlagRegisterAdapterName = 0x0010 + DnsInterfaceSettingsFlagDomain = 0x0020 + DnsInterfaceSettingsFlagHostname = 0x0040 + DnsInterfaceSettingsFlagEnableLLMNR = 0x0080 + DnsInterfaceSettingsFlagQueryAdapterName = 0x0100 + DnsInterfaceSettingsFlagProfileNameserver = 0x0200 + DnsInterfaceSettingsFlagDisableUnconstrainedQueries = 0x0400 // v2 only + DnsInterfaceSettingsFlagSupplementalSearchList = 0x0800 // v2 only + DnsInterfaceSettingsFlagDOH = 0x1000 // v3 only + DnsInterfaceSettingsFlagDOHProfile = 0x2000 // v3 only ) - -// unsafeSlice updates the slice slicePtr to be a slice -// referencing the provided data with its length & capacity set to -// lenCap. -// -// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, -// update callers to use unsafe.Slice instead of this. -func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { - type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int - } - h := (*sliceHeader)(slicePtr) - h.Data = data - h.Len = lenCap - h.Cap = lenCap -} diff --git a/tunnel/winipcfg/types_32.go b/tunnel/winipcfg/types_32.go index 51a8d31c..1a8d4443 100644 --- a/tunnel/winipcfg/types_32.go +++ b/tunnel/winipcfg/types_32.go @@ -1,8 +1,8 @@ -// +build 386 arm +//go:build 386 || arm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_64.go b/tunnel/winipcfg/types_64.go index 6623ce54..3a1fe07f 100644 --- a/tunnel/winipcfg/types_64.go +++ b/tunnel/winipcfg/types_64.go @@ -1,8 +1,8 @@ -// +build amd64 arm64 +//go:build amd64 || arm64 /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_test.go b/tunnel/winipcfg/types_test.go index 26268dbe..b72d73f5 100644 --- a/tunnel/winipcfg/types_test.go +++ b/tunnel/winipcfg/types_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg @@ -957,7 +957,6 @@ func TestIPAddressPrefix(t *testing.T) { offset := uintptr(unsafe.Pointer(&s.PrefixLength)) - sp if offset != ipAddressPrefixPrefixLengthOffset { t.Errorf("IPAddressPrefix.PrefixLength offset is %d although %d is expected", offset, ipAddressPrefixPrefixLengthOffset) - } } diff --git a/tunnel/winipcfg/types_test_32.go b/tunnel/winipcfg/types_test_32.go index fa3f91d5..9e62bfef 100644 --- a/tunnel/winipcfg/types_test_32.go +++ b/tunnel/winipcfg/types_test_32.go @@ -1,8 +1,8 @@ -// +build 386 arm +//go:build 386 || arm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_test_64.go b/tunnel/winipcfg/types_test_64.go index d20cdb30..8a181575 100644 --- a/tunnel/winipcfg/types_test_64.go +++ b/tunnel/winipcfg/types_test_64.go @@ -1,8 +1,8 @@ -// +build amd64 arm64 +//go:build amd64 || arm64 /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/unicast_address_change_handler.go b/tunnel/winipcfg/unicast_address_change_handler.go index 517ff037..cf4fcb3a 100644 --- a/tunnel/winipcfg/unicast_address_change_handler.go +++ b/tunnel/winipcfg/unicast_address_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go index dd09d3a0..e24157b9 100644 --- a/tunnel/winipcfg/winipcfg.go +++ b/tunnel/winipcfg/winipcfg.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg @@ -170,18 +170,18 @@ func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) { //sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2 // -// Undocumented DNS API +// DNS-related functions // -//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? -//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? -//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? // The GUID is passed by value, not by reference, which means different // things on different calling conventions. On amd64, this means it's // passed by reference anyway, while on arm, arm64, and 386, it's split // into words. -func setInterfaceDnsSettings(guid windows.GUID, settings *dnsInterfaceSettings) error { +func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error { words := (*[4]uintptr)(unsafe.Pointer(&guid)) switch runtime.GOARCH { case "amd64": diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go index 7689a0c1..b49daf33 100644 --- a/tunnel/winipcfg/winipcfg_test.go +++ b/tunnel/winipcfg/winipcfg_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ /* @@ -8,8 +8,8 @@ Some tests in this file require: - A dedicated network adapter - Any network adapter will do. It may be virtual (Wintun etc.). The adapter name - must contain string "winipcfg_test". + Any network adapter will do. It may be virtual (WireGuardNT, Wintun, + etc.). The adapter name must contain string "winipcfg_test". Tests will add, remove, flush DNS servers, change adapter IP address, manipulate routes etc. The adapter will not be returned to previous state, so use an expendable one. @@ -22,8 +22,7 @@ Some tests in this file require: package winipcfg import ( - "bytes" - "net" + "net/netip" "strings" "syscall" "testing" @@ -38,22 +37,13 @@ const ( // TODO: Add IPv6 tests. var ( - unexistentIPAddresToAdd = net.IPNet{ - IP: net.IP{172, 16, 1, 114}, - Mask: net.IPMask{255, 255, 255, 0}, - } - unexistentRouteIPv4ToAdd = RouteData{ - Destination: net.IPNet{ - IP: net.IP{172, 16, 200, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - NextHop: net.IP{172, 16, 1, 2}, - Metric: 0, - } - dnsesToSet = []net.IP{ - net.IPv4(8, 8, 8, 8), - net.IPv4(8, 8, 4, 4), + nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24") + nonexistentRouteIPv4ToAdd = RouteData{ + Destination: netip.MustParsePrefix("172.16.200.0/24"), + NextHop: netip.MustParseAddr("172.16.1.2"), + Metric: 0, } + dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")} ) func runningElevated() bool { @@ -74,7 +64,7 @@ func getTestInterface() (*IPAdapterAddresses, error) { marker := strings.ToLower(testInterfaceMarker) for _, ifc := range ifcs { - if strings.Index(strings.ToLower(ifc.FriendlyName()), marker) != -1 { + if strings.Contains(strings.ToLower(ifc.FriendlyName()), marker) { return ifc, nil } } @@ -380,9 +370,9 @@ func TestAddDeleteIPAddress(t *testing.T) { return } - addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String()) return } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.IPAddress() returned an error: %w", err) @@ -410,7 +400,7 @@ func TestAddDeleteIPAddress(t *testing.T) { for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { count-- } - err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd}) + err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd}) if err != nil { t.Errorf("LUID.AddIPAddresses() returned an error: %w", err) } @@ -424,26 +414,26 @@ func TestAddDeleteIPAddress(t *testing.T) { if count != 1 { t.Errorf("After adding there are %d new interface(s).", count) } - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err != nil { t.Errorf("LUID.IPAddress() returned an error: %w", err) } else if addr == nil { - t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String()) } if !created { t.Errorf("Notification handler has not been called on add.") } - err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd) + err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd) if err != nil { t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String()) } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.IPAddress() returned an error: %w", err) } @@ -460,14 +450,13 @@ func TestGetRoutes(t *testing.T) { } func TestAddDeleteRoute(t *testing.T) { - findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) { + findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) { var family AddressFamily - switch { - case dest.IP.To4() != nil: + if dest.Addr().Is4() { family = windows.AF_INET - case dest.IP.To16() != nil: + } else if dest.Addr().Is6() { family = windows.AF_INET6 - default: + } else { return nil, windows.ERROR_INVALID_PARAMETER } r, err := GetIPForwardTable2(family) @@ -475,9 +464,8 @@ func TestAddDeleteRoute(t *testing.T) { return nil, err } matches := make([]MibIPforwardRow2, 0, len(r)) - ones, _ := dest.Mask.Size() for _, route := range r { - if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) { + if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() { matches = append(matches, route) } } @@ -494,20 +482,20 @@ func TestAddDeleteRoute(t *testing.T) { return } - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { - t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?") + t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?") return } else if err != windows.ERROR_NOT_FOUND { t.Errorf("LUID.Route() returned an error: %w", err) return } - routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { - t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes)) + t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes)) } var created, deleted bool @@ -524,42 +512,42 @@ func TestAddDeleteRoute(t *testing.T) { } else { defer cb.Unregister() } - err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric) + err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric) if err != nil { t.Errorf("LUID.AddRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == windows.ERROR_NOT_FOUND { t.Error("LUID.Route() returned nil although the route is added successfully.") } else if err != nil { t.Errorf("LUID.Route() returned an error: %w", err) - } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) { + } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop { t.Error("LUID.Route() returned a wrong route!") } if !created { t.Errorf("Route handler has not been called on add.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 1 { t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1) - } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) { - t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String()) + } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() { + t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String()) } - err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err != nil { t.Errorf("LUID.DeleteRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { t.Error("LUID.Route() returned a route although it is removed successfully.") } else if err != windows.ERROR_NOT_FOUND { @@ -569,7 +557,7 @@ func TestAddDeleteRoute(t *testing.T) { t.Errorf("Route handler has not been called on delete.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { @@ -606,7 +594,7 @@ func TestFlushDNS(t *testing.T) { t.Errorf("LUID.DNS() returned an error: %w", err) } for _, a := range dns { - if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) { + if a.Is4() { n++ } } @@ -651,7 +639,7 @@ func TestSetDNS(t *testing.T) { t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes)) } else { for i := range dnsesToSet { - if !dnsesToSet[i].Equal(newDNSes[i]) { + if dnsesToSet[i] != newDNSes[i] { t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String()) } } diff --git a/tunnel/winipcfg/zwinipcfg_windows.go b/tunnel/winipcfg/zwinipcfg_windows.go index ac89fec1..3a0d8680 100644 --- a/tunnel/winipcfg/zwinipcfg_windows.go +++ b/tunnel/winipcfg/zwinipcfg_windows.go @@ -289,7 +289,7 @@ func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, caller return } -func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) { +func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) { ret = procSetInterfaceDnsSettings.Find() if ret != nil { return @@ -301,24 +301,24 @@ func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr return } -func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) { +func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) { ret = procSetInterfaceDnsSettings.Find() if ret != nil { return } - r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) { +func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) { ret = procSetInterfaceDnsSettings.Find() if ret != nil { return } - r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) if r0 != 0 { ret = syscall.Errno(r0) } diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go deleted file mode 100644 index 4e56ff65..00000000 --- a/tunnel/wintun_test.go +++ /dev/null @@ -1,202 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. - */ - -package tunnel_test - -import ( - "bytes" - "crypto/rand" - "encoding/binary" - "fmt" - "net" - "sync" - "testing" - "time" - - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun" - - "golang.zx2c4.com/wireguard/windows/elevate" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -func TestWintunOrdering(t *testing.T) { - var tunDevice tun.Device - err := elevate.DoAsSystem(func() error { - var err error - tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500) - return err - }) - if err != nil { - t.Fatal(err) - } - defer tunDevice.Close() - nativeTunDevice := tunDevice.(*tun.NativeTun) - luid := winipcfg.LUID(nativeTunDevice.LUID()) - ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24") - err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}}) - if err != nil { - t.Fatal(err) - } - err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}}) - if err != nil { - t.Fatal(err) - } - var token [32]byte - _, err = rand.Read(token[:]) - if err != nil { - t.Fatal(err) - } - var sockWrite net.Conn - for i := 0; i < 1000; i++ { - sockWrite, err = net.Dial("udp", "10.82.31.5:9999") - if err == nil { - defer sockWrite.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var sockRead *net.UDPConn - for i := 0; i < 1000; i++ { - var listenAddress *net.UDPAddr - listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999") - if err != nil { - continue - } - sockRead, err = net.ListenUDP("udp", listenAddress) - if err == nil { - defer sockRead.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var wait sync.WaitGroup - wait.Add(4) - doneSockWrite := false - doneTunWrite := false - fatalErrors := make(chan error, 2) - errors := make(chan error, 2) - go func() { - defer wait.Done() - buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0) - for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ { - binary.LittleEndian.PutUint64(buffer[32:], sendingIndex) - _, err := sockWrite.Write(buffer[:]) - if err != nil { - fatalErrors <- err - } - } - }() - go func() { - defer wait.Done() - packet := [20 + 8 + 32 + 8]byte{ - 0x45, 0, 0, 20 + 8 + 32 + 8, - 0, 0, 0, 0, - 0x80, 0x11, 0, 0, - 10, 82, 31, 5, - 10, 82, 31, 4, - 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0, - } - copy(packet[28:], token[:]) - for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ { - binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex)) - var checksum uint32 - for i := 0; i < 20; i += 2 { - if i != 10 { - checksum += uint32(binary.BigEndian.Uint16(packet[i:])) - } - } - binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff))) - binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex) - n, err := tunDevice.Write(packet[:], 0) - if err != nil { - fatalErrors <- err - } - if n == 0 { - time.Sleep(time.Millisecond * 300) - } - } - }() - const packetsPerTest = 1 << 21 - go func() { - defer func() { - doneSockWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := tunDevice.Read(buffer[:], 0) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - tokenPos := bytes.Index(packet, token[:]) - if tokenPos == -1 || tokenPos+32+8 > len(packet) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - go func() { - defer func() { - doneTunWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := sockRead.Read(buffer[:]) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - done := make(chan bool, 2) - doneFunc := func() { - wait.Wait() - done <- true - } - defer doneFunc() - go doneFunc() - for { - select { - case err := <-fatalErrors: - t.Fatal(err) - case err := <-errors: - t.Error(err) - case <-done: - return - } - } -} |