diff options
Diffstat (limited to '')
34 files changed, 1136 insertions, 1230 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go index 4be2c36a..a3ce6295 100644 --- a/tunnel/addressconfig.go +++ b/tunnel/addressconfig.go @@ -1,42 +1,30 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "bytes" + "fmt" "log" - "net" - "sort" + "net/netip" + "time" "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) { +func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) { if len(addresses) == 0 { return } - includedInAddresses := func(a net.IPNet) bool { - // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer! - for _, addr := range addresses { - ip := addr.IP - if ip4 := ip.To4(); ip4 != nil { - ip = ip4 - } - mA, _ := addr.Mask.Size() - mB, _ := a.Mask.Size() - if bytes.Equal(ip, a.IP) && mA == mB { - return true - } - } - return false + addrHash := make(map[netip.Addr]bool, len(addresses)) + for i := range addresses { + addrHash[addresses[i].Addr()] = true } interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault) if err != nil { @@ -47,155 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add continue } for address := iface.FirstUnicastAddress; address != nil; address = address.Next { - ip := address.Address.IP() - ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))} - if includedInAddresses(ipnet) { - log.Printf("Cleaning up stale address %s from interface ā%sā", ipnet.String(), iface.FriendlyName()) - iface.LUID.DeleteIPAddress(ipnet) + if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] { + prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength)) + log.Printf("Cleaning up stale address %s from interface ā%sā", prefix.String(), iface.FriendlyName()) + iface.LUID.DeleteIPAddress(prefix) } } } } -func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error { - luid := winipcfg.LUID(tun.LUID()) +func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error { + retryOnFailure := services.StartedAtBoot() + tryTimes := 0 +startOver: + var err error + if tryTimes > 0 { + log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + time.Sleep(time.Second) + retryOnFailure = retryOnFailure && tryTimes < 15 + } + tryTimes++ estimatedRouteCount := 0 for _, peer := range conf.Peers { estimatedRouteCount += len(peer.AllowedIPs) } - routes := make([]winipcfg.RouteData, 0, estimatedRouteCount) - addresses := make([]net.IPNet, len(conf.Interface.Addresses)) - var haveV4Address, haveV6Address bool - for i, addr := range conf.Interface.Addresses { - addresses[i] = addr.IPNet() - if addr.Bits() == 32 { - haveV4Address = true - } else if addr.Bits() == 128 { - haveV6Address = true - } - } + routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount) foundDefault4 := false foundDefault6 := false for _, peer := range conf.Peers { for _, allowedip := range peer.AllowedIPs { - if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) { - continue - } route := winipcfg.RouteData{ - Destination: allowedip.IPNet(), + Destination: allowedip.Masked(), Metric: 0, } - if allowedip.Bits() == 32 { - if allowedip.Cidr == 0 { + if allowedip.Addr().Is4() { + if allowedip.Bits() == 0 { foundDefault4 = true } - route.NextHop = net.IPv4zero - } else if allowedip.Bits() == 128 { - if allowedip.Cidr == 0 { + route.NextHop = netip.IPv4Unspecified() + } else if allowedip.Addr().Is6() { + if allowedip.Bits() == 0 { foundDefault6 = true } - route.NextHop = net.IPv6zero + route.NextHop = netip.IPv6Unspecified() } - routes = append(routes, route) + routes[route] = true } } - err := luid.SetIPAddressesForFamily(family, addresses) - if err == windows.ERROR_OBJECT_ALREADY_EXISTS { - cleanupAddressesOnDisconnectedInterfaces(family, addresses) - err = luid.SetIPAddressesForFamily(family, addresses) - } - if err != nil { - return err + deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) + for route := range routes { + r := route + deduplicatedRoutes = append(deduplicatedRoutes, &r) } - deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes)) - sort.Slice(routes, func(i, j int) bool { - return routes[i].Metric < routes[j].Metric || - bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || - bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || - bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1 - }) - for i := 0; i < len(routes); i++ { - if i > 0 && routes[i].Metric == routes[i-1].Metric && - bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) && - bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && - bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { - continue + if !conf.Interface.TableOff { + err = luid.SetRoutesForFamily(family, deduplicatedRoutes) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set routes: %w", err) } - deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) } - err = luid.SetRoutesForFamily(family, deduplicatedRoutes) - if err != nil { - return nil + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + if err == windows.ERROR_OBJECT_ALREADY_EXISTS { + cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses) + err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses) + } + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set ips: %w", err) } - ipif, err := luid.IPInterface(family) + var ipif *winipcfg.MibIPInterfaceRow + ipif, err = luid.IPInterface(family) if err != nil { return err } + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + ipif.DadTransmits = 0 + ipif.ManagedAddressConfigurationSupported = false + ipif.OtherStatefulConfigurationSupported = false if conf.Interface.MTU > 0 { ipif.NLMTU = uint32(conf.Interface.MTU) - tun.ForceMTU(int(ipif.NLMTU)) } - if family == windows.AF_INET { - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - } else if family == windows.AF_INET6 { - if foundDefault6 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - ipif.DadTransmits = 0 - ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 } err = ipif.Set() - if err != nil { - return err + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set metric and MTU: %w", err) } - dnsSearch := "" - if len(conf.Interface.DNSSearch) > 0 { - dnsSearch = conf.Interface.DNSSearch[0] - } - err = luid.SetDNSDomain(dnsSearch) - if err != nil { - return nil + err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch) + if err == windows.ERROR_NOT_FOUND && retryOnFailure { + goto startOver + } else if err != nil { + return fmt.Errorf("unable to set DNS: %w", err) } - if len(conf.Interface.DNSSearch) > 1 { - log.Printf("Warning: %d DNS search domains were specified, but only one is supported, so the first one (%s) was used.", len(conf.Interface.DNSSearch), dnsSearch) - } - err = luid.SetDNSForFamily(family, conf.Interface.DNS) - if err != nil { - return err - } - return nil } -func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error { - restrictAll := false - if len(conf.Peers) == 1 { - nextallowedip: +func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error { + doNotRestrict := true + if len(conf.Peers) == 1 && !conf.Interface.TableOff { for _, allowedip := range conf.Peers[0].AllowedIPs { - if allowedip.Cidr == 0 { - for _, b := range allowedip.IP { - if b != 0 { - continue nextallowedip - } - } - restrictAll = true + if allowedip.Bits() == 0 && allowedip == allowedip.Masked() { + doNotRestrict = false break } } } - if restrictAll && len(conf.Interface.DNS) == 0 { - log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.") - } - return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll) + log.Println("Enabling firewall rules") + return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS) } diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go index 8c0f34c0..405d33a3 100644 --- a/tunnel/deterministicguid.go +++ b/tunnel/deterministicguid.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel @@ -18,8 +18,10 @@ import ( "golang.zx2c4.com/wireguard/windows/conf" ) -const deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com" -const fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com" +const ( + deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com" + fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com" +) // Escape hatch for external consumers, not us. var UseFixedGUIDInsteadOfDeterministic = false @@ -80,13 +82,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID { b2Number(len(peer.AllowedIPs)) sortedAllowedIPs := peer.AllowedIPs sort.Slice(sortedAllowedIPs, func(i, j int) bool { - if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj { + if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj { return bi < bj } - if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr { - return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr + if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() { + return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits() } - return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0 + return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0 }) for _, allowedip := range sortedAllowedIPs { b2String(allowedip.String()) diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go index 7da391ca..8a4967ba 100644 --- a/tunnel/firewall/blocker.go +++ b/tunnel/firewall/blocker.go @@ -1,13 +1,13 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall import ( "errors" - "net" + "net/netip" "unsafe" "golang.org/x/sys/windows" @@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) { return bo, nil } -func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool) error { +func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error { if wfpSession != 0 { return errors.New("The firewall has already been enabled") } @@ -122,26 +122,24 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool return wrapErr(err) } - if len(restrictToDNSServers) > 0 { - err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14) - if err != nil { - return wrapErr(err) + if !doNotRestrict { + if len(restrictToDNSServers) > 0 { + err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14) + if err != nil { + return wrapErr(err) + } } - } - if restrictAll { err = permitLoopback(session, baseObjects, 13) if err != nil { return wrapErr(err) } - } - err = permitTunInterface(session, baseObjects, 12, luid) - if err != nil { - return wrapErr(err) - } + err = permitTunInterface(session, baseObjects, 12, luid) + if err != nil { + return wrapErr(err) + } - if restrictAll { err = permitDHCPIPv4(session, baseObjects, 12) if err != nil { return wrapErr(err) @@ -164,9 +162,7 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool return wrapErr(err) } */ - } - if restrictAll { err = blockAll(session, baseObjects, 0) if err != nil { return wrapErr(err) diff --git a/tunnel/firewall/helpers.go b/tunnel/firewall/helpers.go index 0c9e8e3f..46e43aa5 100644 --- a/tunnel/firewall/helpers.go +++ b/tunnel/firewall/helpers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -66,9 +66,9 @@ func wrapErr(err error) error { } _, file, line, ok := runtime.Caller(1) if !ok { - return fmt.Errorf("Firewall error at unknown location: %v", err) + return fmt.Errorf("Firewall error at unknown location: %w", err) } - return fmt.Errorf("Firewall error at %s:%d: %v", file, line, err) + return fmt.Errorf("Firewall error at %s:%d: %w", file, line, err) } func getCurrentProcessSecurityDescriptor() (*windows.SECURITY_DESCRIPTOR, error) { diff --git a/tunnel/firewall/mksyscall.go b/tunnel/firewall/mksyscall.go index 060c3b1c..fc108007 100644 --- a/tunnel/firewall/mksyscall.go +++ b/tunnel/firewall/mksyscall.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go index 7bca508b..41632f98 100644 --- a/tunnel/firewall/rules.go +++ b/tunnel/firewall/rules.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -8,7 +8,7 @@ package firewall import ( "encoding/binary" "errors" - "net" + "net/netip" "runtime" "unsafe" @@ -582,7 +582,6 @@ func permitDHCPIPv6(session uintptr, baseObjects *baseObjects, weight uint8) err } func permitNdp(session uintptr, baseObjects *baseObjects, weight uint8) error { - /* TODO: actually handle the hop limit somehow! The rules should vaguely be: * - icmpv6 133: must be outgoing, dst must be FF02::2/128, hop limit must be 255 * - icmpv6 134: must be incoming, src must be FE80::/10, hop limit must be 255 @@ -985,7 +984,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error { } // Block all DNS traffic except towards specified DNS servers. -func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error { +func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow, weightDeny uint8) error { if weightDeny >= weightAllow { return errors.New("The allow weight must be greater than the deny weight") } @@ -1106,8 +1105,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV4 = append(allowConditionsV4, denyConditions...) for _, ip := range except { - ip4 := ip.To4() - if ip4 == nil { + if !ip.Is4() { continue } allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{ @@ -1115,7 +1113,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight matchType: cFWP_MATCH_EQUAL, conditionValue: wtFwpConditionValue0{ _type: cFWP_UINT32, - value: uintptr(binary.BigEndian.Uint32(ip4)), + value: uintptr(binary.BigEndian.Uint32(ip.AsSlice())), }, }) } @@ -1124,11 +1122,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except)) allowConditionsV6 = append(allowConditionsV6, denyConditions...) for _, ip := range except { - if ip.To4() != nil { + if !ip.Is6() { continue } - var address wtFwpByteArray16 - copy(address.byteArray16[:], ip) + address := wtFwpByteArray16{byteArray16: ip.As16()} allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{ fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS, matchType: cFWP_MATCH_EQUAL, diff --git a/tunnel/firewall/syscall_windows.go b/tunnel/firewall/syscall_windows.go index 1d2696a1..4d8eea42 100644 --- a/tunnel/firewall/syscall_windows.go +++ b/tunnel/firewall/syscall_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows.go b/tunnel/firewall/types_windows.go index 9192c023..54e2aad7 100644 --- a/tunnel/firewall/types_windows.go +++ b/tunnel/firewall/types_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -148,8 +148,10 @@ var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{ Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57}, } -var cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT -var cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT +var ( + cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT + cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT +) // 7bc43cbf-37ba-45f1-b74a-82ff518eeb10 var cFWPM_CONDITION_L2_FLAGS = windows.GUID{ diff --git a/tunnel/firewall/types_windows_386.go b/tunnel/firewall/types_windows_32.go index e8e90663..29ae553a 100644 --- a/tunnel/firewall/types_windows_386.go +++ b/tunnel/firewall/types_windows_32.go @@ -1,6 +1,8 @@ +//go:build 386 || arm + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows_amd64.go b/tunnel/firewall/types_windows_64.go index 13fde97a..a476a745 100644 --- a/tunnel/firewall/types_windows_amd64.go +++ b/tunnel/firewall/types_windows_64.go @@ -1,6 +1,8 @@ +//go:build amd64 || arm64 + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall diff --git a/tunnel/firewall/types_windows_test.go b/tunnel/firewall/types_windows_test.go index 97cb032c..afa1988f 100644 --- a/tunnel/firewall/types_windows_test.go +++ b/tunnel/firewall/types_windows_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package firewall @@ -11,7 +11,6 @@ import ( ) func TestWtFwpByteBlobSize(t *testing.T) { - const actualWtFwpByteBlobSize = unsafe.Sizeof(wtFwpByteBlob{}) if actualWtFwpByteBlobSize != wtFwpByteBlob_Size { @@ -21,7 +20,6 @@ func TestWtFwpByteBlobSize(t *testing.T) { } func TestWtFwpByteBlobOffsets(t *testing.T) { - s := wtFwpByteBlob{} sp := uintptr(unsafe.Pointer(&s)) @@ -34,7 +32,6 @@ func TestWtFwpByteBlobOffsets(t *testing.T) { } func TestWtFwpmAction0Size(t *testing.T) { - const actualWtFwpmAction0Size = unsafe.Sizeof(wtFwpmAction0{}) if actualWtFwpmAction0Size != wtFwpmAction0_Size { @@ -44,7 +41,6 @@ func TestWtFwpmAction0Size(t *testing.T) { } func TestWtFwpmAction0Offsets(t *testing.T) { - s := wtFwpmAction0{} sp := uintptr(unsafe.Pointer(&s)) @@ -58,7 +54,6 @@ func TestWtFwpmAction0Offsets(t *testing.T) { } func TestWtFwpBitmapArray64Size(t *testing.T) { - const actualWtFwpBitmapArray64Size = unsafe.Sizeof(wtFwpBitmapArray64{}) if actualWtFwpBitmapArray64Size != wtFwpBitmapArray64_Size { @@ -68,7 +63,6 @@ func TestWtFwpBitmapArray64Size(t *testing.T) { } func TestWtFwpByteArray6Size(t *testing.T) { - const actualWtFwpByteArray6Size = unsafe.Sizeof(wtFwpByteArray6{}) if actualWtFwpByteArray6Size != wtFwpByteArray6_Size { @@ -78,7 +72,6 @@ func TestWtFwpByteArray6Size(t *testing.T) { } func TestWtFwpByteArray16Size(t *testing.T) { - const actualWtFwpByteArray16Size = unsafe.Sizeof(wtFwpByteArray16{}) if actualWtFwpByteArray16Size != wtFwpByteArray16_Size { @@ -88,7 +81,6 @@ func TestWtFwpByteArray16Size(t *testing.T) { } func TestWtFwpConditionValue0Size(t *testing.T) { - const actualWtFwpConditionValue0Size = unsafe.Sizeof(wtFwpConditionValue0{}) if actualWtFwpConditionValue0Size != wtFwpConditionValue0_Size { @@ -98,7 +90,6 @@ func TestWtFwpConditionValue0Size(t *testing.T) { } func TestWtFwpConditionValue0Offsets(t *testing.T) { - s := wtFwpConditionValue0{} sp := uintptr(unsafe.Pointer(&s)) @@ -111,7 +102,6 @@ func TestWtFwpConditionValue0Offsets(t *testing.T) { } func TestWtFwpV4AddrAndMaskSize(t *testing.T) { - const actualWtFwpV4AddrAndMaskSize = unsafe.Sizeof(wtFwpV4AddrAndMask{}) if actualWtFwpV4AddrAndMaskSize != wtFwpV4AddrAndMask_Size { @@ -121,7 +111,6 @@ func TestWtFwpV4AddrAndMaskSize(t *testing.T) { } func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) { - s := wtFwpV4AddrAndMask{} sp := uintptr(unsafe.Pointer(&s)) @@ -135,7 +124,6 @@ func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) { } func TestWtFwpV6AddrAndMaskSize(t *testing.T) { - const actualWtFwpV6AddrAndMaskSize = unsafe.Sizeof(wtFwpV6AddrAndMask{}) if actualWtFwpV6AddrAndMaskSize != wtFwpV6AddrAndMask_Size { @@ -145,7 +133,6 @@ func TestWtFwpV6AddrAndMaskSize(t *testing.T) { } func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) { - s := wtFwpV6AddrAndMask{} sp := uintptr(unsafe.Pointer(&s)) @@ -159,7 +146,6 @@ func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) { } func TestWtFwpValue0Size(t *testing.T) { - const actualWtFwpValue0Size = unsafe.Sizeof(wtFwpValue0{}) if actualWtFwpValue0Size != wtFwpValue0_Size { @@ -168,7 +154,6 @@ func TestWtFwpValue0Size(t *testing.T) { } func TestWtFwpValue0Offsets(t *testing.T) { - s := wtFwpValue0{} sp := uintptr(unsafe.Pointer(&s)) @@ -181,7 +166,6 @@ func TestWtFwpValue0Offsets(t *testing.T) { } func TestWtFwpmDisplayData0Size(t *testing.T) { - const actualWtFwpmDisplayData0Size = unsafe.Sizeof(wtFwpmDisplayData0{}) if actualWtFwpmDisplayData0Size != wtFwpmDisplayData0_Size { @@ -191,7 +175,6 @@ func TestWtFwpmDisplayData0Size(t *testing.T) { } func TestWtFwpmDisplayData0Offsets(t *testing.T) { - s := wtFwpmDisplayData0{} sp := uintptr(unsafe.Pointer(&s)) @@ -205,7 +188,6 @@ func TestWtFwpmDisplayData0Offsets(t *testing.T) { } func TestWtFwpmFilterCondition0Size(t *testing.T) { - const actualWtFwpmFilterCondition0Size = unsafe.Sizeof(wtFwpmFilterCondition0{}) if actualWtFwpmFilterCondition0Size != wtFwpmFilterCondition0_Size { @@ -215,7 +197,6 @@ func TestWtFwpmFilterCondition0Size(t *testing.T) { } func TestWtFwpmFilterCondition0Offsets(t *testing.T) { - s := wtFwpmFilterCondition0{} sp := uintptr(unsafe.Pointer(&s)) @@ -237,7 +218,6 @@ func TestWtFwpmFilterCondition0Offsets(t *testing.T) { } func TestWtFwpmFilter0Size(t *testing.T) { - const actualWtFwpmFilter0Size = unsafe.Sizeof(wtFwpmFilter0{}) if actualWtFwpmFilter0Size != wtFwpmFilter0_Size { @@ -247,7 +227,6 @@ func TestWtFwpmFilter0Size(t *testing.T) { } func TestWtFwpmFilter0Offsets(t *testing.T) { - s := wtFwpmFilter0{} sp := uintptr(unsafe.Pointer(&s)) @@ -364,7 +343,6 @@ func TestWtFwpmFilter0Offsets(t *testing.T) { } func TestWtFwpProvider0Size(t *testing.T) { - const actualWtFwpProvider0Size = unsafe.Sizeof(wtFwpProvider0{}) if actualWtFwpProvider0Size != wtFwpProvider0_Size { @@ -374,7 +352,6 @@ func TestWtFwpProvider0Size(t *testing.T) { } func TestWtFwpProvider0Offsets(t *testing.T) { - s := wtFwpProvider0{} sp := uintptr(unsafe.Pointer(&s)) @@ -412,7 +389,6 @@ func TestWtFwpProvider0Offsets(t *testing.T) { } func TestWtFwpmSession0Size(t *testing.T) { - const actualWtFwpmSession0Size = unsafe.Sizeof(wtFwpmSession0{}) if actualWtFwpmSession0Size != wtFwpmSession0_Size { @@ -422,7 +398,6 @@ func TestWtFwpmSession0Size(t *testing.T) { } func TestWtFwpmSession0Offsets(t *testing.T) { - s := wtFwpmSession0{} sp := uintptr(unsafe.Pointer(&s)) @@ -482,7 +457,6 @@ func TestWtFwpmSession0Offsets(t *testing.T) { } func TestWtFwpmSublayer0Size(t *testing.T) { - const actualWtFwpmSublayer0Size = unsafe.Sizeof(wtFwpmSublayer0{}) if actualWtFwpmSublayer0Size != wtFwpmSublayer0_Size { @@ -492,7 +466,6 @@ func TestWtFwpmSublayer0Size(t *testing.T) { } func TestWtFwpmSublayer0Offsets(t *testing.T) { - s := wtFwpmSublayer0{} sp := uintptr(unsafe.Pointer(&s)) diff --git a/tunnel/firewall/zsyscall_windows.go b/tunnel/firewall/zsyscall_windows.go index 846d4ae8..9e60132d 100644 --- a/tunnel/firewall/zsyscall_windows.go +++ b/tunnel/firewall/zsyscall_windows.go @@ -19,6 +19,7 @@ const ( var ( errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL ) // errnoErr returns common boxed Errno values, to prevent @@ -26,7 +27,7 @@ var ( func errnoErr(e syscall.Errno) error { switch e { case 0: - return nil + return errERROR_EINVAL case errnoERROR_IO_PENDING: return errERROR_IO_PENDING } @@ -39,62 +40,38 @@ func errnoErr(e syscall.Errno) error { var ( modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll") - procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0") procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0") - procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0") - procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0") - procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0") + procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0") procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0") + procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0") + procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0") + procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0") + procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0") + procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0") procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0") procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0") - procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0") - procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0") ) -func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) { - r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0) - if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - func fwpmEngineClose0(engineHandle uintptr) (err error) { r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } -func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) { - r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd)) +func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) { + r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } -func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) { - r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0) +func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) { + r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } @@ -104,38 +81,26 @@ func fwpmFreeMemory0(p unsafe.Pointer) { return } -func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) { - r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0) +func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) { + r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } -func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0) +func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd)) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } -func fwpmTransactionCommit0(engineHandle uintptr) (err error) { - r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0) +func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd)) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } @@ -143,23 +108,23 @@ func fwpmTransactionCommit0(engineHandle uintptr) (err error) { func fwpmTransactionAbort0(engineHandle uintptr) (err error) { r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } -func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) { - r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd)) +func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0) + if r1 != 0 { + err = errnoErr(e1) + } + return +} + +func fwpmTransactionCommit0(engineHandle uintptr) (err error) { + r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0) if r1 != 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } + err = errnoErr(e1) } return } diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go index 1f632725..a831d06e 100644 --- a/tunnel/interfacewatcher.go +++ b/tunnel/interfacewatcher.go @@ -1,20 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( + "errors" + "fmt" "log" "sync" + "time" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/services" "golang.zx2c4.com/wireguard/windows/tunnel/firewall" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" @@ -24,63 +24,30 @@ type interfaceWatcherError struct { serviceError services.Error err error } + type interfaceWatcherEvent struct { luid winipcfg.LUID family winipcfg.AddressFamily } + type interfaceWatcher struct { - errors chan interfaceWatcherError + errors chan interfaceWatcherError + started chan winipcfg.AddressFamily - device *device.Device - conf *conf.Config - tun *tun.NativeTun + conf *conf.Config + adapter *driver.Adapter + luid winipcfg.LUID setupMutex sync.Mutex interfaceChangeCallback winipcfg.ChangeCallback changeCallbacks4 []winipcfg.ChangeCallback changeCallbacks6 []winipcfg.ChangeCallback storedEvents []interfaceWatcherEvent -} - -func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool { - var ( - foundV401 bool - foundV41281 bool - foundV600001 bool - foundV680001 bool - foundV400 bool - foundV600 bool - v40 = [4]byte{} - v60 = [16]byte{} - v48 = [4]byte{0x80} - v68 = [16]byte{0x80} - ) - for _, peer := range peers { - for _, allowedip := range peer.AllowedIPs { - if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { - foundV680001 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV401 = true - } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { - foundV41281 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { - foundV600 = true - } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { - foundV400 = true - } - } - } - if family == windows.AF_INET { - return foundV400 || (foundV401 && foundV41281) - } else if family == windows.AF_INET6 { - return foundV600 || (foundV600001 && foundV680001) - } - return false + watchdog *time.Timer } func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { + iw.watchdog.Stop() var changeCallbacks *[]winipcfg.ChangeCallback var ipversion string if family == windows.AF_INET { @@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) { } var err error - log.Printf("Monitoring default %s routes", ipversion) - *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun) - if err != nil { - iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err} - return + if iw.conf.Interface.MTU == 0 { + log.Printf("Monitoring MTU of default %s routes", ipversion) + *changeCallbacks, err = monitorMTU(family, iw.luid) + if err != nil { + iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err} + return + } } log.Printf("Setting device %s addresses", ipversion) - err = configureInterface(family, iw.conf, iw.tun) + err = configureInterface(family, iw.conf, iw.luid) if err != nil { iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err} return } + evaluateDynamicPitfalls(family, iw.conf, iw.luid) + + iw.started <- family } func watchInterface() (*interfaceWatcher, error) { iw := &interfaceWatcher{ - errors: make(chan interfaceWatcherError, 2), + errors: make(chan interfaceWatcherError, 2), + started: make(chan winipcfg.AddressFamily, 4), } + iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() { + iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")} + }) + iw.watchdog.Stop() var err error iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { iw.setupMutex.Lock() @@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) { if notificationType != winipcfg.MibAddInstance { return } - if iw.tun == nil { + if iw.luid == 0 { iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family}) return } - if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) { + if iface.InterfaceLUID != iw.luid { return } iw.setup(iface.Family) + + if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown { + log.Println("Reinitializing adapter configuration") + err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration()) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err)) + } + err = iw.adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err)) + } + } }) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to register interface change callback: %w", err) } return iw, nil } -func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) { +func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) { iw.setupMutex.Lock() defer iw.setupMutex.Unlock() + iw.watchdog.Reset(time.Minute) - iw.device, iw.conf, iw.tun = device, conf, tun + iw.adapter, iw.conf, iw.luid = adapter, conf, luid for _, event := range iw.storedEvents { - if event.luid == winipcfg.LUID(iw.tun.LUID()) { + if event.luid == luid { iw.setup(event.family) } } @@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, func (iw *interfaceWatcher) Destroy() { iw.setupMutex.Lock() + iw.watchdog.Stop() changeCallbacks4 := iw.changeCallbacks4 changeCallbacks6 := iw.changeCallbacks6 interfaceChangeCallback := iw.interfaceChangeCallback - tun := iw.tun + luid := iw.luid iw.setupMutex.Unlock() if interfaceChangeCallback != nil { @@ -186,15 +177,15 @@ func (iw *interfaceWatcher) Destroy() { changeCallbacks6 = changeCallbacks6[1:] } firewall.DisableFirewall() - if tun != nil && iw.tun == tun { + if luid != 0 && iw.luid == luid { // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active // routes, so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(tun.LUID()) luid.FlushRoutes(windows.AF_INET) luid.FlushIPAddresses(windows.AF_INET) + luid.FlushDNS(windows.AF_INET) luid.FlushRoutes(windows.AF_INET6) luid.FlushIPAddresses(windows.AF_INET6) - luid.FlushDNS() + luid.FlushDNS(windows.AF_INET6) } iw.setupMutex.Unlock() } diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go deleted file mode 100644 index 613d0283..00000000 --- a/tunnel/ipcpermissions.go +++ /dev/null @@ -1,63 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package tunnel - -import ( - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc" - - "golang.zx2c4.com/wireguard/windows/conf" -) - -func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error { - if conf.PathIsEncrypted(filename) { - return nil - } - - fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - return err - } - fileOwner, _, err := fileSd.Owner() - if err != nil { - return err - } - if fileOwner.IsWellKnown(windows.WinLocalSystemSid) { - return nil - } - additionalEntries := []windows.EXPLICIT_ACCESS{{ - AccessPermissions: windows.GENERIC_ALL, - AccessMode: windows.GRANT_ACCESS, - Trustee: windows.TRUSTEE{ - TrusteeForm: windows.TRUSTEE_IS_SID, - TrusteeType: windows.TRUSTEE_IS_USER, - TrusteeValue: windows.TrusteeValueFromSID(fileOwner), - }, - }} - - sd, err := ipc.UAPISecurityDescriptor.ToAbsolute() - if err != nil { - return err - } - dacl, defaulted, _ := sd.DACL() - - newDacl, err := windows.ACLFromEntries(additionalEntries, dacl) - if err != nil { - return err - } - err = sd.SetDACL(newDacl, true, defaulted) - if err != nil { - return err - } - sd, err = sd.ToSelfRelative() - if err != nil { - return err - } - ipc.UAPISecurityDescriptor = sd - - return nil -} diff --git a/tunnel/defaultroutemonitor.go b/tunnel/mtumonitor.go index 3af9042c..c07823a2 100644 --- a/tunnel/defaultroutemonitor.go +++ b/tunnel/mtumonitor.go @@ -1,30 +1,23 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "log" - "sync" - "time" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error { +func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error { r, err := winipcfg.GetIPForwardTable2(family) if err != nil { return err } lowestMetric := ^uint32(0) - index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. - luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. + index := uint32(0) + luid := winipcfg.LUID(0) for i := range r { if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID { continue @@ -50,40 +43,24 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU } *lastLUID = luid *lastIndex = index - blackhole := blackholeWhenLoop && index == 0 - bind, _ := device.Bind().(conn.BindSocketToInterface) - if bind == nil { - return nil - } - if family == windows.AF_INET { - log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole) - return bind.BindSocketToInterface4(index, blackhole) - } else if family == windows.AF_INET6 { - log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole) - return bind.BindSocketToInterface6(index, blackhole) - } return nil } -func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) { +func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) { var minMTU uint32 if family == windows.AF_INET { minMTU = 576 } else if family == windows.AF_INET6 { minMTU = 1280 } - ourLUID := winipcfg.LUID(tun.LUID()) lastLUID := winipcfg.LUID(0) lastIndex := ^uint32(0) lastMTU := uint32(0) doIt := func() error { - err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop) + err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex) if err != nil { return err } - if !autoMTU { - return nil - } mtu := uint32(0) if lastLUID != 0 { iface, err := lastLUID.Interface() @@ -107,7 +84,6 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, if err != nil { return err } - tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad. lastMTU = mtu } return nil @@ -116,32 +92,9 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, if err != nil { return nil, err } - - firstBurst := time.Time{} - burstMutex := sync.Mutex{} - burstTimer := time.AfterFunc(time.Hour*200, func() { - burstMutex.Lock() - firstBurst = time.Time{} - doIt() - burstMutex.Unlock() - }) - burstTimer.Stop() - bump := func() { - burstMutex.Lock() - burstTimer.Reset(time.Millisecond * 150) - if firstBurst.IsZero() { - firstBurst = time.Now() - } else if time.Since(firstBurst) > time.Second*2 { - firstBurst = time.Time{} - burstTimer.Stop() - doIt() - } - burstMutex.Unlock() - } - cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) { if route != nil && route.DestinationPrefix.PrefixLength == 0 { - bump() + doIt() } }) if err != nil { @@ -149,7 +102,7 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, } cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) { if notificationType == winipcfg.MibParameterNotification { - bump() + doIt() } }) if err != nil { diff --git a/tunnel/pitfalls.go b/tunnel/pitfalls.go new file mode 100644 index 00000000..fdef6eb2 --- /dev/null +++ b/tunnel/pitfalls.go @@ -0,0 +1,177 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "log" + "net/netip" + "strings" + "unsafe" + + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/svc/mgr" + "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func evaluateStaticPitfalls() { + go func() { + pitfallDnsCacheDisabled() + pitfallVirtioNetworkDriver() + }() +} + +func evaluateDynamicPitfalls(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) { + go func() { + pitfallWeakHostSend(family, conf, luid) + }() +} + +func pitfallDnsCacheDisabled() { + scm, err := mgr.Connect() + if err != nil { + return + } + defer scm.Disconnect() + svc := mgr.Service{Name: "dnscache"} + svc.Handle, err = windows.OpenService(scm.Handle, windows.StringToUTF16Ptr(svc.Name), windows.SERVICE_QUERY_CONFIG) + if err != nil { + return + } + defer svc.Close() + cfg, err := svc.Config() + if err != nil { + return + } + if cfg.StartType != mgr.StartDisabled { + return + } + + log.Printf("Warning: the %q (dnscache) service is disabled; please re-enable it", cfg.DisplayName) +} + +func pitfallVirtioNetworkDriver() { + var modules []windows.RTL_PROCESS_MODULE_INFORMATION + for bufferSize := uint32(128 * 1024); ; { + moduleBuffer := make([]byte, bufferSize) + err := windows.NtQuerySystemInformation(windows.SystemModuleInformation, unsafe.Pointer(&moduleBuffer[0]), bufferSize, &bufferSize) + switch err { + case windows.STATUS_INFO_LENGTH_MISMATCH: + continue + case nil: + break + default: + return + } + mods := (*windows.RTL_PROCESS_MODULES)(unsafe.Pointer(&moduleBuffer[0])) + modules = unsafe.Slice(&mods.Modules[0], mods.NumberOfModules) + break + } + for i := range modules { + if !strings.EqualFold(windows.ByteSliceToString(modules[i].FullPathName[modules[i].OffsetToFileName:]), "netkvm.sys") { + continue + } + driverPath := `\\?\GLOBALROOT` + windows.ByteSliceToString(modules[i].FullPathName[:]) + var zero windows.Handle + infoSize, err := windows.GetFileVersionInfoSize(driverPath, &zero) + if err != nil { + return + } + versionInfo := make([]byte, infoSize) + err = windows.GetFileVersionInfo(driverPath, 0, infoSize, unsafe.Pointer(&versionInfo[0])) + if err != nil { + return + } + var fixedInfo *windows.VS_FIXEDFILEINFO + fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo)) + err = windows.VerQueryValue(unsafe.Pointer(&versionInfo[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen) + if err != nil { + return + } + const minimumPlausibleVersion = 40 << 48 + const minimumGoodVersion = (100 << 48) | (85 << 32) | (104 << 16) | (20800 << 0) + version := (uint64(fixedInfo.FileVersionMS) << 32) | uint64(fixedInfo.FileVersionLS) + if version >= minimumGoodVersion || version < minimumPlausibleVersion { + return + } + log.Println("Warning: the VirtIO network driver (NetKVM) is out of date and may cause known problems; please update to v100.85.104.20800 or later") + return + } +} + +func pitfallWeakHostSend(family winipcfg.AddressFamily, conf *conf.Config, ourLUID winipcfg.LUID) { + routingTable, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return + } + type endpointRoute struct { + addr netip.Addr + name string + lowestMetric uint32 + highestCIDR uint8 + weakHostSend bool + finalIsOurs bool + } + endpoints := make([]endpointRoute, 0, len(conf.Peers)) + for _, peer := range conf.Peers { + addr, err := netip.ParseAddr(peer.Endpoint.Host) + if err != nil || (addr.Is4() && family != windows.AF_INET) || (addr.Is6() && family != windows.AF_INET6) { + continue + } + endpoints = append(endpoints, endpointRoute{addr: addr, lowestMetric: ^uint32(0)}) + } + for i := range routingTable { + var ( + ifrow *winipcfg.MibIfRow2 + ifacerow *winipcfg.MibIPInterfaceRow + metric uint32 + ) + for j := range endpoints { + r, e := &routingTable[i], &endpoints[j] + if r.DestinationPrefix.PrefixLength < e.highestCIDR { + continue + } + if !r.DestinationPrefix.Prefix().Contains(e.addr) { + continue + } + if ifrow == nil { + ifrow, err = r.InterfaceLUID.Interface() + if err != nil { + continue + } + } + if ifrow.OperStatus != winipcfg.IfOperStatusUp { + continue + } + if ifacerow == nil { + ifacerow, err = r.InterfaceLUID.IPInterface(family) + if err != nil { + continue + } + metric = r.Metric + ifacerow.Metric + } + if r.DestinationPrefix.PrefixLength == e.highestCIDR && metric > e.lowestMetric { + continue + } + e.lowestMetric = metric + e.highestCIDR = r.DestinationPrefix.PrefixLength + e.finalIsOurs = r.InterfaceLUID == ourLUID + if !e.finalIsOurs { + e.name = ifrow.Alias() + e.weakHostSend = ifacerow.ForwardingEnabled || ifacerow.WeakHostSend + } + } + } + problematicInterfaces := make(map[string]bool, len(endpoints)) + for _, e := range endpoints { + if e.weakHostSend && e.finalIsOurs { + problematicInterfaces[e.name] = true + } + } + for iface := range problematicInterfaces { + log.Printf("Warning: the %q interface has Forwarding/WeakHostSend enabled, which will cause routing loops", iface) + } +} diff --git a/tunnel/scriptrunner.go b/tunnel/scriptrunner.go new file mode 100644 index 00000000..eb97d98d --- /dev/null +++ b/tunnel/scriptrunner.go @@ -0,0 +1,77 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. + */ + +package tunnel + +import ( + "bufio" + "fmt" + "log" + "os" + "path/filepath" + "syscall" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/windows/conf" +) + +func runScriptCommand(command, interfaceName string) error { + if len(command) == 0 { + return nil + } + if !conf.AdminBool("DangerousScriptExecution") { + log.Printf("Skipping execution of script, because dangerous script execution is safely disabled: %#q", command) + return nil + } + log.Printf("Executing: %#q", command) + comspec, _ := os.LookupEnv("COMSPEC") + if len(comspec) == 0 { + system32, err := windows.GetSystemDirectory() + if err != nil { + return err + } + comspec = filepath.Join(system32, "cmd.exe") + } + + devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0) + if err != nil { + return err + } + defer devNull.Close() + reader, writer, err := os.Pipe() + if err != nil { + return err + } + process, err := os.StartProcess(comspec, nil /* CmdLine below */, &os.ProcAttr{ + Files: []*os.File{devNull, writer, writer}, + Env: append(os.Environ(), "WIREGUARD_TUNNEL_NAME="+interfaceName), + Sys: &syscall.SysProcAttr{ + HideWindow: true, + CmdLine: fmt.Sprintf("cmd /c %s", command), + }, + }) + writer.Close() + if err != nil { + reader.Close() + return err + } + go func() { + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + log.Printf("cmd> %s", scanner.Text()) + } + }() + state, err := process.Wait() + reader.Close() + if err != nil { + return err + } + if state.ExitCode() == 0 { + return nil + } + log.Printf("Command error exit status: %d", state.ExitCode()) + return windows.ERROR_GENERIC_COMMAND_FAILED +} diff --git a/tunnel/service.go b/tunnel/service.go index e535894b..a56ed1f3 100644 --- a/tunnel/service.go +++ b/tunnel/service.go @@ -1,33 +1,27 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package tunnel import ( - "bufio" "bytes" "fmt" "log" - "net" "os" "runtime" - "runtime/debug" - "strings" "time" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/windows/conf" + "golang.zx2c4.com/wireguard/windows/driver" "golang.zx2c4.com/wireguard/windows/elevate" "golang.zx2c4.com/wireguard/windows/ringlogger" "golang.zx2c4.com/wireguard/windows/services" - "golang.zx2c4.com/wireguard/windows/version" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) type tunnelService struct { @@ -35,12 +29,13 @@ type tunnelService struct { } func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) { - changes <- svc.Status{State: svc.StartPending} + serviceState := svc.StartPending + changes <- svc.Status{State: serviceState} - var dev *device.Device - var uapi net.Listener var watcher *interfaceWatcher - var nativeTun *tun.NativeTun + var adapter *driver.Adapter + var luid winipcfg.LUID + var config *conf.Config var err error serviceError := services.ErrorSuccess @@ -50,7 +45,8 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, if logErr != nil { log.Println(logErr) } - changes <- svc.Status{State: svc.StopPending} + serviceState = svc.StopPending + changes <- svc.Status{State: serviceState} stopIt := make(chan bool, 1) go func() { @@ -84,65 +80,63 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } }() + if logErr == nil && adapter != nil && config != nil { + logErr = runScriptCommand(config.Interface.PreDown, config.Name) + } if watcher != nil { watcher.Destroy() } - if uapi != nil { - uapi.Close() + if adapter != nil { + adapter.Close() } - if dev != nil { - dev.Close() + if logErr == nil && adapter != nil && config != nil { + _ = runScriptCommand(config.Interface.PostDown, config.Name) } stopIt <- true log.Println("Shutting down") }() - err = ringlogger.InitGlobalLogger("TUN") + var logFile string + logFile, err = conf.LogFile(true) if err != nil { serviceError = services.ErrorRingloggerOpen return } - defer func() { - if x := recover(); x != nil { - for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) { - if len(strings.TrimSpace(line)) > 0 { - log.Println(line) - } - } - panic(x) - } - }() - - conf, err := conf.LoadFromPath(service.Path) + err = ringlogger.InitGlobalLogger(logFile, "TUN") if err != nil { - serviceError = services.ErrorLoadConfiguration + serviceError = services.ErrorRingloggerOpen return } - conf.DeduplicateNetworkEntries() - err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path) + + config, err = conf.LoadFromPath(service.Path) if err != nil { serviceError = services.ErrorLoadConfiguration return } + config.DeduplicateNetworkEntries() - logPrefix := fmt.Sprintf("[%s] ", conf.Name) - log.SetPrefix(logPrefix) + log.SetPrefix(fmt.Sprintf("[%s] ", config.Name)) - log.Println("Starting", version.UserAgent()) + services.PrintStarting() - if m, err := mgr.Connect(); err == nil { - if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { - /* If we don't do this, then the Wintun installation will block forever, because - * installing a Wintun device starts a service too. Apparently at boot time, Windows - * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce - * that we're running before starting additional services. - */ - log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) - changes <- svc.Status{State: svc.Running} + if services.StartedAtBoot() { + if m, err := mgr.Connect(); err == nil { + if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked { + /* If we don't do this, then the driver installation will block forever, because + * installing a network adapter starts the driver service too. Apparently at boot time, + * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't + * announce that we're running before starting additional services. + */ + log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner) + serviceState = svc.Running + changes <- svc.Status{State: serviceState} + } + m.Disconnect() } - m.Disconnect() } + evaluateStaticPitfalls() + log.Println("Watching network interfaces") watcher, err = watchInterface() if err != nil { @@ -151,28 +145,49 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, } log.Println("Resolving DNS names") - uapiConf, err := conf.ToUAPI() + err = config.ResolveEndpoints() if err != nil { serviceError = services.ErrorDNSLookup return } - log.Println("Creating Wintun interface") - wintun, err := tun.CreateTUNWithRequestedGUID(conf.Name, deterministicGUID(conf), 0) + log.Println("Creating network adapter") + for i := 0; i < 15; i++ { + if i > 0 { + time.Sleep(time.Second) + log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err) + } + adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config)) + if err == nil || !services.StartedAtBoot() { + break + } + } if err != nil { - serviceError = services.ErrorCreateWintun + err = fmt.Errorf("Error creating adapter: %w", err) + serviceError = services.ErrorCreateNetworkAdapter return } - nativeTun = wintun.(*tun.NativeTun) - wintunVersion, ndisVersion, err := nativeTun.Version() + luid = adapter.LUID() + driverVersion, err := driver.RunningVersion() if err != nil { - log.Printf("Warning: unable to determine Wintun version: %v", err) + log.Printf("Warning: unable to determine driver version: %v", err) } else { - log.Printf("Using Wintun/%s (NDIS %s)", wintunVersion, ndisVersion) + log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff) + } + err = adapter.SetLogging(driver.AdapterLogOn) + if err != nil { + err = fmt.Errorf("Error enabling adapter logging: %w", err) + serviceError = services.ErrorCreateNetworkAdapter + return } - log.Println("Enabling firewall rules") - err = enableFirewall(conf, nativeTun) + err = runScriptCommand(config.Interface.PreUp, config.Name) + if err != nil { + serviceError = services.ErrorRunScript + return + } + + err = enableFirewall(config, luid) if err != nil { serviceError = services.ErrorFirewall return @@ -185,43 +200,28 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, return } - log.Println("Creating interface instance") - logOutput := log.New(ringlogger.Global, logPrefix, 0) - logger := &device.Logger{logOutput, logOutput, logOutput} - dev = device.NewDevice(wintun, logger) - log.Println("Setting interface configuration") - uapi, err = ipc.UAPIListen(conf.Name) + err = adapter.SetConfiguration(config.ToDriverConfiguration()) if err != nil { - serviceError = services.ErrorUAPIListen + serviceError = services.ErrorDeviceSetConfig return } - ipcErr := dev.IpcSetOperation(bufio.NewReader(strings.NewReader(uapiConf))) - if ipcErr != nil { - err = ipcErr - serviceError = services.ErrorDeviceSetConfig + err = adapter.SetAdapterState(driver.AdapterStateUp) + if err != nil { + serviceError = services.ErrorDeviceBringUp return } + watcher.Configure(adapter, config, luid) - log.Println("Bringing peers up") - dev.Up() - - watcher.Configure(dev, conf, nativeTun) - - log.Println("Listening for UAPI requests") - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - continue - } - go dev.IpcHandle(conn) - } - }() + err = runScriptCommand(config.Interface.PostUp, config.Name) + if err != nil { + serviceError = services.ErrorRunScript + return + } - changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} - log.Println("Startup complete") + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} + var started bool for { select { case c := <-r: @@ -233,8 +233,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, default: log.Printf("Unexpected service control request #%d\n", c) } - case <-dev.Wait(): - return + case <-watcher.started: + if !started { + serviceState = svc.Running + changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown} + log.Println("Startup complete") + started = true + } case e := <-watcher.errors: serviceError, err = e.serviceError, e.err return @@ -247,7 +252,7 @@ func Run(confPath string) error { if err != nil { return err } - serviceName, err := services.ServiceNameOfTunnel(name) + serviceName, err := conf.ServiceNameOfTunnel(name) if err != nil { return err } diff --git a/tunnel/winipcfg/interface_change_handler.go b/tunnel/winipcfg/interface_change_handler.go index 9406c18a..af29801a 100644 --- a/tunnel/winipcfg/interface_change_handler.go +++ b/tunnel/winipcfg/interface_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go index 8f5ba61b..0c898b89 100644 --- a/tunnel/winipcfg/luid.go +++ b/tunnel/winipcfg/luid.go @@ -1,17 +1,16 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( "errors" - "fmt" - "net" + "net/netip" + "strings" "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" ) // LUID represents a network interface. @@ -64,12 +63,23 @@ func LUIDFromGUID(guid *windows.GUID) (LUID, error) { return luid, nil } +// LUIDFromIndex function converts a local index for a network interface to the locally unique identifier (LUID) for the interface. +// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceindextoluid +func LUIDFromIndex(index uint32) (LUID, error) { + var luid LUID + err := convertInterfaceIndexToLUID(index, &luid) + if err != nil { + return 0, err + } + return luid, nil +} + // IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry) -func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { +func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) { row := &MibUnicastIPAddressRow{InterfaceLUID: luid} - err := row.Address.SetIP(ip, 0) + err := row.Address.SetAddr(addr) if err != nil { return nil, err } @@ -84,22 +94,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) { // AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddress(address net.IPNet) error { +func (luid LUID) AddIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid - err := row.Address.SetIP(address.IP, 0) + row.DadState = DadStatePreferred + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Create() } // AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry). -func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { +func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error { for i := range addresses { err := luid.AddIPAddress(addresses[i]) if err != nil { @@ -110,7 +122,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error { } // SetIPAddresses method sets new unicast IP addresses to the interface. -func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { +func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error { err := luid.FlushIPAddresses(windows.AF_UNSPEC) if err != nil { return err @@ -119,16 +131,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error { } // SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface. -func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error { +func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error { err := luid.FlushIPAddresses(family) if err != nil { return err } for i := range addresses { - asV4 := addresses[i].IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !addresses[i].Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddIPAddress(addresses[i]) @@ -141,17 +152,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I // DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry). -func (luid LUID) DeleteIPAddress(address net.IPNet) error { +func (luid LUID) DeleteIPAddress(address netip.Prefix) error { row := &MibUnicastIPAddressRow{} row.Init() row.InterfaceLUID = luid - err := row.Address.SetIP(address.IP, 0) + err := row.Address.SetAddr(address.Addr()) if err != nil { return err } // Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry(). - ones, _ := address.Mask.Size() - row.OnLinkPrefixLength = uint8(ones) + row.OnLinkPrefixLength = uint8(address.Bits()) return row.Delete() } @@ -175,15 +185,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error { // Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2). // NOTE: If the corresponding route isn't found, the method will return error. -func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) { +func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + row.ValidLifetime = 0xffffffff + row.PreferredLifetime = 0xffffffff + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return nil, err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return nil, err } @@ -197,15 +209,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2 // AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature. // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2) -func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error { +func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -240,10 +252,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat return err } for _, rd := range routesData { - asV4 := rd.Destination.IP.To4() - if asV4 == nil && family == windows.AF_INET { + if !rd.Destination.Addr().Is4() && family == windows.AF_INET { continue - } else if asV4 != nil && family == windows.AF_INET6 { + } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 { continue } err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric) @@ -256,15 +267,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat // DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function // (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2). -func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error { +func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error { row := &MibIPforwardRow2{} row.Init() row.InterfaceLUID = luid - err := row.DestinationPrefix.SetIPNet(destination) + err := row.DestinationPrefix.SetPrefix(destination) if err != nil { return err } - err = row.NextHop.SetIP(nextHop, 0) + err = row.NextHop.SetAddr(nextHop) if err != nil { return err } @@ -297,17 +308,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error { } // DNS method returns all DNS server addresses associated with the adapter. -func (luid LUID) DNS() ([]net.IP, error) { +func (luid LUID) DNS() ([]netip.Addr, error) { addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) if err != nil { return nil, err } - r := make([]net.IP, 0, len(addresses)) + r := make([]netip.Addr, 0, len(addresses)) for _, addr := range addresses { if addr.LUID == luid { for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next { if ip := dns.Address.IP(); ip != nil { - r = append(r, ip) + if a, ok := netip.AddrFromSlice(ip); ok { + r = append(r, a) + } } else { return nil, windows.ERROR_INVALID_PARAMETER } @@ -317,141 +330,58 @@ func (luid LUID) DNS() ([]net.IP, error) { return r, nil } -const ( - netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both" - netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both" - netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no" - netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" -) - -// FlushDNS method clears all DNS servers associated with the adapter. -func (luid LUID) FlushDNS() error { - cmds := make([]string, 0, 2) - ipif4, err := luid.IPInterface(windows.AF_INET) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex)) - } - ipif6, err := luid.IPInterface(windows.AF_INET6) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex)) - } - - if len(cmds) == 0 { - return nil - } - return runNetsh(cmds) -} - -// AddDNS method associates additional DNS servers with the adapter. -func (luid LUID) AddDNS(dnses []net.IP) error { - var ipif4, ipif6 *MibIPInterfaceRow - var err error - cmds := make([]string, 0, len(dnses)) - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil { - if ipif4 == nil { - ipif4, err = luid.IPInterface(windows.AF_INET) - if err != nil { - return err - } - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v6 != nil { - if ipif6 == nil { - ipif6, err = luid.IPInterface(windows.AF_INET6) - if err != nil { - return err - } - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String())) - } - } - - if len(cmds) == 0 { - return nil +// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family. +func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error { + if family != windows.AF_INET && family != windows.AF_INET6 { + return windows.ERROR_PROTOCOL_UNREACHABLE } - return runNetsh(cmds) -} -// SetDNS method clears previous and associates new DNS servers with the adapter. -func (luid LUID) SetDNS(dnses []net.IP) error { - cmds := make([]string, 0, 2+len(dnses)) - ipif4, err := luid.IPInterface(windows.AF_INET) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex)) - } - ipif6, err := luid.IPInterface(windows.AF_INET6) - if err == nil { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex)) - } - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil { - if ipif4 == nil { - return windows.ERROR_NOT_SUPPORTED - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v6 != nil { - if ipif6 == nil { - return windows.ERROR_NOT_SUPPORTED - } - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String())) + var filteredServers []string + for _, server := range servers { + if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) { + filteredServers = append(filteredServers, server.String()) } } - - if len(cmds) == 0 { - return nil - } - return runNetsh(cmds) -} - -// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family. -func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error { - var templateFlush string - if family == windows.AF_INET { - templateFlush = netshCmdTemplateFlush4 - } else if family == windows.AF_INET6 { - templateFlush = netshCmdTemplateFlush6 - } - - cmds := make([]string, 0, 1+len(dnses)) - ipif, err := luid.IPInterface(family) + servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ",")) if err != nil { return err } - cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) - for i := 0; i < len(dnses); i++ { - if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String())) - } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 { - cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String())) - } + domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ",")) + if err != nil { + return err } - return runNetsh(cmds) -} - -// SetDNSDomain method sets the interface-specific DNS domain. -func (luid LUID) SetDNSDomain(domain string) error { guid, err := luid.GUID() if err != nil { - return fmt.Errorf("Error converting luid to guid: %v", err) + return err } - key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE) - if err != nil { - return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err) + dnsInterfaceSettings := &DnsInterfaceSettings{ + Version: DnsInterfaceSettingsVersion1, + Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList, + NameServer: servers16, + SearchList: domains16, } - paths, _, err := key.GetStringsValue("IpConfig") - key.Close() - if err != nil { - return fmt.Errorf("Error reading IpConfig registry key: %v", err) + if family == windows.AF_INET6 { + dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6 } - if len(paths) == 0 { - return errors.New("No TCP/IP interfaces found on adapter") + // For >= Windows 10 1809 + err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings) + if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) { + return err } - key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE) + + // For < Windows 10 1809 + err = luid.fallbackSetDNSForFamily(family, servers) if err != nil { - return fmt.Errorf("Unable to open TCP/IP network registry key: %v", err) + return err } - err = key.SetStringValue("Domain", domain) - key.Close() - return err + if len(domains) > 0 { + return luid.fallbackSetDNSDomain(domains[0]) + } else { + return luid.fallbackSetDNSDomain("") + } +} + +// FlushDNS method clears all DNS servers associated with the adapter. +func (luid LUID) FlushDNS(family AddressFamily) error { + return luid.SetDNS(family, nil, nil) } diff --git a/tunnel/winipcfg/mksyscall.go b/tunnel/winipcfg/mksyscall.go index 8edb1cf2..d62d38df 100644 --- a/tunnel/winipcfg/mksyscall.go +++ b/tunnel/winipcfg/mksyscall.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go index 4714c520..4f8e5b13 100644 --- a/tunnel/winipcfg/netsh.go +++ b/tunnel/winipcfg/netsh.go @@ -1,49 +1,108 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( "bytes" + "errors" "fmt" "io" + "net/netip" "os/exec" "path/filepath" "strings" + "syscall" "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" ) -// I wish we didn't have to do this. netiohlp.dll (what's used by netsh.exe) has some nice tricks with writing directly -// to the registry and the nsi kernel object, but it's not clear copying those makes for a stable interface. WMI doesn't -// work with v6. CMI isn't in Windows 7. func runNetsh(cmds []string) error { system32, err := windows.GetSystemDirectory() if err != nil { return err } - cmd := exec.Command(filepath.Join(system32, "netsh.exe")) // I wish we could append (, "-f", "CONIN$") but Go sets up the process context wrong. + cmd := exec.Command(filepath.Join(system32, "netsh.exe")) + cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} + stdin, err := cmd.StdinPipe() if err != nil { - return fmt.Errorf("runNetsh stdin pipe - %v", err) + return fmt.Errorf("runNetsh stdin pipe - %w", err) } go func() { defer stdin.Close() io.WriteString(stdin, strings.Join(append(cmds, "exit\r\n"), "\r\n")) }() output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("runNetsh run - %v", err) - } // Horrible kludges, sorry. - cleaned := bytes.ReplaceAll(output, []byte("netsh>"), []byte{}) + cleaned := bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'}) + cleaned = bytes.ReplaceAll(cleaned, []byte("netsh>"), []byte{}) cleaned = bytes.ReplaceAll(cleaned, []byte("There are no Domain Name Servers (DNS) configured on this computer."), []byte{}) cleaned = bytes.TrimSpace(cleaned) - if len(cleaned) != 0 { - return fmt.Errorf("runNetsh returned error strings.\ninput:\n%s\noutput\n:%s", - strings.Join(cmds, "\n"), bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'})) + if len(cleaned) != 0 && err == nil { + return fmt.Errorf("netsh: %#q", string(cleaned)) + } else if err != nil { + return fmt.Errorf("netsh: %v: %#q", err, string(cleaned)) } return nil } + +const ( + netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both" + netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no" + netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no" +) + +func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error { + var templateFlush string + if family == windows.AF_INET { + templateFlush = netshCmdTemplateFlush4 + } else if family == windows.AF_INET6 { + templateFlush = netshCmdTemplateFlush6 + } + + cmds := make([]string, 0, 1+len(dnses)) + ipif, err := luid.IPInterface(family) + if err != nil { + return err + } + cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex)) + for i := 0; i < len(dnses); i++ { + if dnses[i].Is4() && family == windows.AF_INET { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String())) + } else if dnses[i].Is6() && family == windows.AF_INET6 { + cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String())) + } + } + return runNetsh(cmds) +} + +func (luid LUID) fallbackSetDNSDomain(domain string) error { + guid, err := luid.GUID() + if err != nil { + return fmt.Errorf("Error converting luid to guid: %w", err) + } + key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE) + if err != nil { + return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err) + } + paths, _, err := key.GetStringsValue("IpConfig") + key.Close() + if err != nil { + return fmt.Errorf("Error reading IpConfig registry key: %w", err) + } + if len(paths) == 0 { + return errors.New("No TCP/IP interfaces found on adapter") + } + key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE) + if err != nil { + return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err) + } + err = key.SetStringValue("Domain", domain) + key.Close() + return err +} diff --git a/tunnel/winipcfg/route_change_handler.go b/tunnel/winipcfg/route_change_handler.go index 1b4bad95..4b78331e 100644 --- a/tunnel/winipcfg/route_change_handler.go +++ b/tunnel/winipcfg/route_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go index 81f9335d..8e8f4a59 100644 --- a/tunnel/winipcfg/types.go +++ b/tunnel/winipcfg/types.go @@ -1,13 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( - "bytes" - "net" + "encoding/binary" + "fmt" + "net/netip" + "strconv" "unsafe" "golang.org/x/sys/windows" @@ -581,19 +583,17 @@ const ( ScopeLevelCount = 16 ) -// Theoretical array index limitations -const ( - maxIndexCount8 = (1 << 31) - 1 - maxIndexCount16 = (1 << 30) - 1 -) - // RouteData structure describes a route to add type RouteData struct { - Destination net.IPNet - NextHop net.IP + Destination netip.Prefix + NextHop netip.Addr Metric uint32 } +func (routeData *RouteData) String() string { + return fmt.Sprintf("%+v", *routeData) +} + // IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter. // https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix type IPAdapterDNSSuffix struct { @@ -609,15 +609,7 @@ func (obj *IPAdapterDNSSuffix) String() string { // AdapterName method returns the name of the adapter with which these addresses are associated. // Unlike an adapter's friendly name, the adapter name returned by AdapterName is permanent and cannot be modified by the user. func (addr *IPAdapterAddresses) AdapterName() string { - if addr.adapterName == nil { - return "" - } - slice := (*(*[maxIndexCount8]uint8)(unsafe.Pointer(addr.adapterName)))[:] - null := bytes.IndexByte(slice, 0) - if null != -1 { - slice = slice[:null] - } - return string(slice) + return windows.BytePtrToString(addr.adapterName) } // DNSSuffix method returns adapter DNS suffix associated with this adapter. @@ -625,7 +617,7 @@ func (addr *IPAdapterAddresses) DNSSuffix() string { if addr.dnsSuffix == nil { return "" } - return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.dnsSuffix)))[:]) + return windows.UTF16PtrToString(addr.dnsSuffix) } // Description method returns description for the adapter. @@ -633,7 +625,7 @@ func (addr *IPAdapterAddresses) Description() string { if addr.description == nil { return "" } - return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.description)))[:]) + return windows.UTF16PtrToString(addr.description) } // FriendlyName method returns a user-friendly name for the adapter. For example: "Local Area Connection 1." @@ -642,7 +634,7 @@ func (addr *IPAdapterAddresses) FriendlyName() string { if addr.friendlyName == nil { return "" } - return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.friendlyName)))[:]) + return windows.UTF16PtrToString(addr.friendlyName) } // PhysicalAddress method returns the Media Access Control (MAC) address for the adapter. @@ -693,9 +685,8 @@ func (row *MibIPInterfaceRow) Set() error { } // get method returns all table rows as a Go slice. -func (tab *mibIPInterfaceTable) get() []MibIPInterfaceRow { - const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIPInterfaceRow{}) - return (*[maxCount]MibIPInterfaceRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries] +func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -731,9 +722,8 @@ func (row *MibIfRow2) get() (ret error) { } // get method returns all table rows as a Go slice. -func (tab *mibIfTable2) get() []MibIfRow2 { - const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIfRow2{}) - return (*[maxCount]MibIfRow2)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries] +func (tab *mibIfTable2) get() (s []MibIfRow2) { + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -749,45 +739,82 @@ type RawSockaddrInet struct { data [26]byte } -// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port. +func ntohs(i uint16) uint16 { + return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:]) +} + +func htons(i uint16) uint16 { + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, i) + return *(*uint16)(unsafe.Pointer(&b[0])) +} + +// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port. // All other members of the structure are set to zero. -func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { - if v4 := ip.To4(); v4 != nil { +func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error { + if addrPort.Addr().Is4() { addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) addr4.Family = windows.AF_INET - copy(addr4.Addr[:], v4) - addr4.Port = port + addr4.Addr = addrPort.Addr().As4() + addr4.Port = htons(addrPort.Port()) for i := 0; i < 8; i++ { addr4.Zero[i] = 0 } return nil - } - - if v6 := ip.To16(); v6 != nil { + } else if addrPort.Addr().Is6() { addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) addr6.Family = windows.AF_INET6 - addr6.Port = port + addr6.Addr = addrPort.Addr().As16() + addr6.Port = htons(addrPort.Port()) addr6.Flowinfo = 0 - copy(addr6.Addr[:], v6) - addr6.Scope_id = 0 + scopeId := uint32(0) + if z := addrPort.Addr().Zone(); z != "" { + if s, err := strconv.ParseUint(z, 10, 32); err == nil { + scopeId = uint32(s) + } + } + addr6.Scope_id = scopeId return nil } - return windows.ERROR_INVALID_PARAMETER } -// IP method returns IPv4 or IPv6 address. -// If the address is neither IPv4 not IPv6 nil is returned. -func (addr *RawSockaddrInet) IP() net.IP { +// SetAddr method sets family and address to the given IPv4 or IPv6 address. +// All other members of the structure are set to zero. +func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error { + return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0)) +} + +// AddrPort returns the IP address and port. +func (addr *RawSockaddrInet) AddrPort() netip.AddrPort { + return netip.AddrPortFrom(addr.Addr(), addr.Port()) +} + +// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither. +func (addr *RawSockaddrInet) Addr() netip.Addr { switch addr.Family { case windows.AF_INET: - return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:] - + return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr) case windows.AF_INET6: - return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:] + raw := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) + a := netip.AddrFrom16(raw.Addr) + if raw.Scope_id != 0 { + a = a.WithZone(strconv.FormatUint(uint64(raw.Scope_id), 10)) + } + return a } + return netip.Addr{} +} - return nil +// Port returns the port if the address if IPv4 or IPv6, or 0 if neither. +func (addr *RawSockaddrInet) Port() uint16 { + switch addr.Family { + case windows.AF_INET: + return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port) + case windows.AF_INET6: + return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port) + } + return 0 } // Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer. @@ -821,9 +848,8 @@ func (row *MibUnicastIPAddressRow) Delete() error { } // get method returns all table rows as a Go slice. -func (tab *mibUnicastIPAddressTable) get() []MibUnicastIPAddressRow { - const maxCount = maxIndexCount8 / unsafe.Sizeof(MibUnicastIPAddressRow{}) - return (*[maxCount]MibUnicastIPAddressRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries] +func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -851,9 +877,8 @@ func (row *MibAnycastIPAddressRow) Delete() error { } // get method returns all table rows as a Go slice. -func (tab *mibAnycastIPAddressTable) get() []MibAnycastIPAddressRow { - const maxCount = maxIndexCount8 / unsafe.Sizeof(MibAnycastIPAddressRow{}) - return (*[maxCount]MibAnycastIPAddressRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries] +func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) { + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -865,32 +890,30 @@ func (tab *mibAnycastIPAddressTable) free() { // IPAddressPrefix structure stores an IP address prefix. // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix type IPAddressPrefix struct { - Prefix RawSockaddrInet + RawPrefix RawSockaddrInet PrefixLength uint8 _ [2]byte } -// SetIPNet method sets IP address prefix using net.IPNet. -func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error { - err := prefix.Prefix.SetIP(net.IP, 0) +// SetPrefix method sets IP address prefix using netip.Prefix. +func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error { + err := prefix.RawPrefix.SetAddr(netPrefix.Addr()) if err != nil { return err } - ones, _ := net.Mask.Size() - prefix.PrefixLength = uint8(ones) + prefix.PrefixLength = uint8(netPrefix.Bits()) return nil } -// IPNet method returns IP address prefix as net.IPNet. -// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately. -func (prefix *IPAddressPrefix) IPNet() net.IPNet { - switch prefix.Prefix.Family { +// Prefix returns IP address prefix as netip.Prefix. +func (prefix *IPAddressPrefix) Prefix() netip.Prefix { + switch prefix.RawPrefix.Family { case windows.AF_INET: - return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)} + return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) case windows.AF_INET6: - return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)} + return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength)) } - return net.IPNet{} + return netip.Prefix{} } // MibIPforwardRow2 structure stores information about an IP route entry. @@ -944,9 +967,8 @@ func (row *MibIPforwardRow2) Delete() error { } // get method returns all table rows as a Go slice. -func (tab *mibIPforwardTable2) get() []MibIPforwardRow2 { - const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIPforwardRow2{}) - return (*[maxCount]MibIPforwardRow2)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries] +func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) { + return unsafe.Slice(&tab.table[0], tab.numEntries) } // free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes. @@ -954,3 +976,43 @@ func (tab *mibIPforwardTable2) get() []MibIPforwardRow2 { func (tab *mibIPforwardTable2) free() { freeMibTable(unsafe.Pointer(tab)) } + +// +// DNS API +// + +// DnsInterfaceSettings is meant to be used with SetInterfaceDnsSettings +type DnsInterfaceSettings struct { + Version uint32 + _ [4]byte + Flags uint64 + Domain *uint16 + NameServer *uint16 + SearchList *uint16 + RegistrationEnabled uint32 + RegisterAdapterName uint32 + EnableLLMNR uint32 + QueryAdapterName uint32 + ProfileNameServer *uint16 +} + +const ( + DnsInterfaceSettingsVersion1 = 1 // for DnsInterfaceSettings + DnsInterfaceSettingsVersion2 = 2 // for DnsInterfaceSettingsEx + DnsInterfaceSettingsVersion3 = 3 // for DnsInterfaceSettings3 + + DnsInterfaceSettingsFlagIPv6 = 0x0001 + DnsInterfaceSettingsFlagNameserver = 0x0002 + DnsInterfaceSettingsFlagSearchList = 0x0004 + DnsInterfaceSettingsFlagRegistrationEnabled = 0x0008 + DnsInterfaceSettingsFlagRegisterAdapterName = 0x0010 + DnsInterfaceSettingsFlagDomain = 0x0020 + DnsInterfaceSettingsFlagHostname = 0x0040 + DnsInterfaceSettingsFlagEnableLLMNR = 0x0080 + DnsInterfaceSettingsFlagQueryAdapterName = 0x0100 + DnsInterfaceSettingsFlagProfileNameserver = 0x0200 + DnsInterfaceSettingsFlagDisableUnconstrainedQueries = 0x0400 // v2 only + DnsInterfaceSettingsFlagSupplementalSearchList = 0x0800 // v2 only + DnsInterfaceSettingsFlagDOH = 0x1000 // v3 only + DnsInterfaceSettingsFlagDOHProfile = 0x2000 // v3 only +) diff --git a/tunnel/winipcfg/types_386.go b/tunnel/winipcfg/types_32.go index 3a4b5733..1a8d4443 100644 --- a/tunnel/winipcfg/types_386.go +++ b/tunnel/winipcfg/types_32.go @@ -1,6 +1,8 @@ +//go:build 386 || arm + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_amd64.go b/tunnel/winipcfg/types_64.go index 11242891..3a1fe07f 100644 --- a/tunnel/winipcfg/types_amd64.go +++ b/tunnel/winipcfg/types_64.go @@ -1,6 +1,8 @@ +//go:build amd64 || arm64 + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_test.go b/tunnel/winipcfg/types_test.go index c7494e8c..b72d73f5 100644 --- a/tunnel/winipcfg/types_test.go +++ b/tunnel/winipcfg/types_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg @@ -957,7 +957,6 @@ func TestIPAddressPrefix(t *testing.T) { offset := uintptr(unsafe.Pointer(&s.PrefixLength)) - sp if offset != ipAddressPrefixPrefixLengthOffset { t.Errorf("IPAddressPrefix.PrefixLength offset is %d although %d is expected", offset, ipAddressPrefixPrefixLengthOffset) - } } diff --git a/tunnel/winipcfg/types_test_386.go b/tunnel/winipcfg/types_test_32.go index db5f5f86..9e62bfef 100644 --- a/tunnel/winipcfg/types_test_386.go +++ b/tunnel/winipcfg/types_test_32.go @@ -1,6 +1,8 @@ +//go:build 386 || arm + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/types_test_amd64.go b/tunnel/winipcfg/types_test_64.go index acc74118..8a181575 100644 --- a/tunnel/winipcfg/types_test_amd64.go +++ b/tunnel/winipcfg/types_test_64.go @@ -1,6 +1,8 @@ +//go:build amd64 || arm64 + /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/unicast_address_change_handler.go b/tunnel/winipcfg/unicast_address_change_handler.go index 5f8f2c96..cf4fcb3a 100644 --- a/tunnel/winipcfg/unicast_address_change_handler.go +++ b/tunnel/winipcfg/unicast_address_change_handler.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go index 2fc0c875..e24157b9 100644 --- a/tunnel/winipcfg/winipcfg.go +++ b/tunnel/winipcfg/winipcfg.go @@ -1,11 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ package winipcfg import ( + "runtime" "unsafe" "golang.org/x/sys/windows" @@ -29,6 +30,7 @@ import ( //sys getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) = iphlpapi.GetIfTable2Ex //sys convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid //sys convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceGuidToLuid +//sys convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceIndexToLuid // GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer. // https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses @@ -166,3 +168,29 @@ func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) { // https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2 //sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2 + +// +// DNS-related functions +// + +//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? +//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings? + +// The GUID is passed by value, not by reference, which means different +// things on different calling conventions. On amd64, this means it's +// passed by reference anyway, while on arm, arm64, and 386, it's split +// into words. +func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error { + words := (*[4]uintptr)(unsafe.Pointer(&guid)) + switch runtime.GOARCH { + case "amd64": + return setInterfaceDnsSettingsByPtr(&guid, settings) + case "arm64": + return setInterfaceDnsSettingsByQwords(words[0], words[1], settings) + case "arm", "386": + return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings) + default: + panic("unknown calling convention") + } +} diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go index 0251aecf..b49daf33 100644 --- a/tunnel/winipcfg/winipcfg_test.go +++ b/tunnel/winipcfg/winipcfg_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. */ /* @@ -8,8 +8,8 @@ Some tests in this file require: - A dedicated network adapter - Any network adapter will do. It may be virtual (Wintun etc.). The adapter name - must contain string "winipcfg_test". + Any network adapter will do. It may be virtual (WireGuardNT, Wintun, + etc.). The adapter name must contain string "winipcfg_test". Tests will add, remove, flush DNS servers, change adapter IP address, manipulate routes etc. The adapter will not be returned to previous state, so use an expendable one. @@ -22,9 +22,9 @@ Some tests in this file require: package winipcfg import ( - "bytes" - "net" + "net/netip" "strings" + "syscall" "testing" "time" @@ -37,22 +37,13 @@ const ( // TODO: Add IPv6 tests. var ( - unexistentIPAddresToAdd = net.IPNet{ - IP: net.IP{172, 16, 1, 114}, - Mask: net.IPMask{255, 255, 255, 0}, - } - unexistentRouteIPv4ToAdd = RouteData{ - Destination: net.IPNet{ - IP: net.IP{172, 16, 200, 0}, - Mask: net.IPMask{255, 255, 255, 0}, - }, - NextHop: net.IP{172, 16, 1, 2}, - Metric: 0, - } - dnsesToSet = []net.IP{ - net.IPv4(8, 8, 8, 8), - net.IPv4(8, 8, 4, 4), + nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24") + nonexistentRouteIPv4ToAdd = RouteData{ + Destination: netip.MustParsePrefix("172.16.200.0/24"), + NextHop: netip.MustParseAddr("172.16.1.2"), + Metric: 0, } + dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")} ) func runningElevated() bool { @@ -73,7 +64,7 @@ func getTestInterface() (*IPAdapterAddresses, error) { marker := strings.ToLower(testInterfaceMarker) for _, ifc := range ifcs { - if strings.Index(strings.ToLower(ifc.FriendlyName()), marker) != -1 { + if strings.Contains(strings.ToLower(ifc.FriendlyName()), marker) { return ifc, nil } } @@ -93,7 +84,7 @@ func getTestIPInterface(family AddressFamily) (*MibIPInterfaceRow, error) { func TestAdaptersAddresses(t *testing.T) { ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagIncludeAll) if err != nil { - t.Errorf("GetAdaptersAddresses() returned error: %v", err) + t.Errorf("GetAdaptersAddresses() returned error: %w", err) } else if ifcs == nil { t.Errorf("GetAdaptersAddresses() returned nil.") } else if len(ifcs) == 0 { @@ -107,7 +98,7 @@ func TestAdaptersAddresses(t *testing.T) { i.PhysicalAddress() i.DHCPv6ClientDUID() for dnsSuffix := i.FirstDNSSuffix; dnsSuffix != nil; dnsSuffix = dnsSuffix.Next { - dnsSuffix.String() + _ = dnsSuffix.String() } } } @@ -117,7 +108,7 @@ func TestAdaptersAddresses(t *testing.T) { for _, i := range ifcs { ifc, err := i.LUID.Interface() if err != nil { - t.Errorf("LUID.Interface() returned an error: %v", err) + t.Errorf("LUID.Interface() returned an error: %w", err) continue } else if ifc == nil { t.Errorf("LUID.Interface() returned nil.") @@ -128,7 +119,7 @@ func TestAdaptersAddresses(t *testing.T) { for _, i := range ifcs { guid, err := i.LUID.GUID() if err != nil { - t.Errorf("LUID.GUID() returned an error: %v", err) + t.Errorf("LUID.GUID() returned an error: %w", err) continue } if guid == nil { @@ -138,7 +129,7 @@ func TestAdaptersAddresses(t *testing.T) { luid, err := LUIDFromGUID(guid) if err != nil { - t.Errorf("LUIDFromGUID() returned an error: %v", err) + t.Errorf("LUIDFromGUID() returned an error: %w", err) continue } if luid != i.LUID { @@ -151,7 +142,7 @@ func TestAdaptersAddresses(t *testing.T) { func TestIPInterface(t *testing.T) { ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault) if err != nil { - t.Errorf("GetAdaptersAddresses() returned error: %v", err) + t.Errorf("GetAdaptersAddresses() returned error: %w", err) } for _, i := range ifcs { @@ -161,12 +152,12 @@ func TestIPInterface(t *testing.T) { continue } if err != nil { - t.Errorf("LUID.IPInterface(%s) returned an error: %v", i.FriendlyName(), err) + t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err) } _, err = i.LUID.IPInterface(windows.AF_INET6) if err != nil { - t.Errorf("LUID.IPInterface(%s) returned an error: %v", i.FriendlyName(), err) + t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err) } } } @@ -174,7 +165,7 @@ func TestIPInterface(t *testing.T) { func TestIPInterfaces(t *testing.T) { tab, err := GetIPInterfaceTable(windows.AF_UNSPEC) if err != nil { - t.Errorf("GetIPInterfaceTable() returned an error: %v", err) + t.Errorf("GetIPInterfaceTable() returned an error: %w", err) return } else if tab == nil { t.Error("GetIPInterfaceTable() returned nil.") @@ -189,7 +180,7 @@ func TestIPInterfaces(t *testing.T) { func TestIPChangeMetric(t *testing.T) { ipifc, err := getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) return } if !runningElevated() { @@ -208,13 +199,13 @@ func TestIPChangeMetric(t *testing.T) { } }) if err != nil { - t.Errorf("RegisterInterfaceChangeCallback() returned error: %v", err) + t.Errorf("RegisterInterfaceChangeCallback() returned error: %w", err) return } defer func() { err = cb.Unregister() if err != nil { - t.Errorf("UnregisterInterfaceChangeCallback() returned error: %v", err) + t.Errorf("UnregisterInterfaceChangeCallback() returned error: %w", err) } }() @@ -230,14 +221,14 @@ func TestIPChangeMetric(t *testing.T) { ipifc.Metric = newMetric err = ipifc.Set() if err != nil { - t.Errorf("MibIPInterfaceRow.Set() returned an error: %v", err) + t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) ipifc, err = getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) return } if ipifc.Metric != newMetric { @@ -255,14 +246,14 @@ func TestIPChangeMetric(t *testing.T) { ipifc.Metric = metric err = ipifc.Set() if err != nil { - t.Errorf("MibIPInterfaceRow.Set() returned an error: %v", err) + t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) ipifc, err = getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) return } if ipifc.Metric != metric { @@ -279,7 +270,7 @@ func TestIPChangeMetric(t *testing.T) { func TestIPChangeMTU(t *testing.T) { ipifc, err := getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) return } if !runningElevated() { @@ -292,14 +283,14 @@ func TestIPChangeMTU(t *testing.T) { ipifc.NLMTU = mtuToSet err = ipifc.Set() if err != nil { - t.Errorf("Interface.Set() returned error: %v", err) + t.Errorf("Interface.Set() returned error: %w", err) } time.Sleep(500 * time.Millisecond) ipifc, err = getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) return } if ipifc.NLMTU != mtuToSet { @@ -309,14 +300,14 @@ func TestIPChangeMTU(t *testing.T) { ipifc.NLMTU = prevMTU err = ipifc.Set() if err != nil { - t.Errorf("Interface.Set() returned error: %v", err) + t.Errorf("Interface.Set() returned error: %w", err) } time.Sleep(500 * time.Millisecond) ipifc, err = getTestIPInterface(windows.AF_INET) if err != nil { - t.Errorf("getTestIPInterface() returned an error: %v", err) + t.Errorf("getTestIPInterface() returned an error: %w", err) } if ipifc.NLMTU != prevMTU { t.Errorf("Interface.NLMTU is %d although %d is expected.", ipifc.NLMTU, prevMTU) @@ -326,13 +317,13 @@ func TestIPChangeMTU(t *testing.T) { func TestGetIfRow(t *testing.T) { ifc, err := getTestInterface() if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) + t.Errorf("getTestInterface() returned an error: %w", err) return } row, err := ifc.LUID.Interface() if err != nil { - t.Errorf("LUID.Interface() returned an error: %v", err) + t.Errorf("LUID.Interface() returned an error: %w", err) return } @@ -345,7 +336,7 @@ func TestGetIfRow(t *testing.T) { func TestGetIfRows(t *testing.T) { tab, err := GetIfTable2Ex(MibIfEntryNormal) if err != nil { - t.Errorf("GetIfTable2Ex() returned an error: %v", err) + t.Errorf("GetIfTable2Ex() returned an error: %w", err) return } else if tab == nil { t.Errorf("GetIfTable2Ex() returned nil") @@ -363,7 +354,7 @@ func TestGetIfRows(t *testing.T) { func TestUnicastIPAddress(t *testing.T) { _, err := GetUnicastIPAddressTable(windows.AF_UNSPEC) if err != nil { - t.Errorf("GetUnicastAddresses() returned an error: %v", err) + t.Errorf("GetUnicastAddresses() returned an error: %w", err) return } } @@ -371,7 +362,7 @@ func TestUnicastIPAddress(t *testing.T) { func TestAddDeleteIPAddress(t *testing.T) { ifc, err := getTestInterface() if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) + t.Errorf("getTestInterface() returned an error: %w", err) return } if !runningElevated() { @@ -379,12 +370,12 @@ func TestAddDeleteIPAddress(t *testing.T) { return } - addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String()) return } else if err != windows.ERROR_NOT_FOUND { - t.Errorf("LUID.IPAddress() returned an error: %v", err) + t.Errorf("LUID.IPAddress() returned an error: %w", err) return } @@ -401,7 +392,7 @@ func TestAddDeleteIPAddress(t *testing.T) { } }) if err != nil { - t.Errorf("RegisterUnicastAddressChangeCallback() returned an error: %v", err) + t.Errorf("RegisterUnicastAddressChangeCallback() returned an error: %w", err) } else { defer cb.Unregister() } @@ -409,9 +400,9 @@ func TestAddDeleteIPAddress(t *testing.T) { for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next { count-- } - err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd}) + err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd}) if err != nil { - t.Errorf("LUID.AddIPAddresses() returned an error: %v", err) + t.Errorf("LUID.AddIPAddresses() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) @@ -423,28 +414,28 @@ func TestAddDeleteIPAddress(t *testing.T) { if count != 1 { t.Errorf("After adding there are %d new interface(s).", count) } - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err != nil { - t.Errorf("LUID.IPAddress() returned an error: %v", err) + t.Errorf("LUID.IPAddress() returned an error: %w", err) } else if addr == nil { - t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String()) } if !created { t.Errorf("Notification handler has not been called on add.") } - err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd) + err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd) if err != nil { - t.Errorf("LUID.DeleteIPAddress() returned an error: %v", err) + t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP) + addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr()) if err == nil { - t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String()) + t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String()) } else if err != windows.ERROR_NOT_FOUND { - t.Errorf("LUID.IPAddress() returned an error: %v", err) + t.Errorf("LUID.IPAddress() returned an error: %w", err) } if !deleted { t.Errorf("Notification handler has not been called on delete.") @@ -454,19 +445,18 @@ func TestAddDeleteIPAddress(t *testing.T) { func TestGetRoutes(t *testing.T) { _, err := GetIPForwardTable2(windows.AF_UNSPEC) if err != nil { - t.Errorf("GetIPForwardTable2() returned error: %v", err) + t.Errorf("GetIPForwardTable2() returned error: %w", err) } } func TestAddDeleteRoute(t *testing.T) { - findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) { + findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) { var family AddressFamily - switch { - case dest.IP.To4() != nil: + if dest.Addr().Is4() { family = windows.AF_INET - case dest.IP.To16() != nil: + } else if dest.Addr().Is6() { family = windows.AF_INET6 - default: + } else { return nil, windows.ERROR_INVALID_PARAMETER } r, err := GetIPForwardTable2(family) @@ -474,9 +464,8 @@ func TestAddDeleteRoute(t *testing.T) { return nil, err } matches := make([]MibIPforwardRow2, 0, len(r)) - ones, _ := dest.Mask.Size() for _, route := range r { - if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) { + if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() { matches = append(matches, route) } } @@ -485,7 +474,7 @@ func TestAddDeleteRoute(t *testing.T) { ifc, err := getTestInterface() if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) + t.Errorf("getTestInterface() returned an error: %w", err) return } if !runningElevated() { @@ -493,20 +482,20 @@ func TestAddDeleteRoute(t *testing.T) { return } - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { - t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?") + t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?") return } else if err != windows.ERROR_NOT_FOUND { - t.Errorf("LUID.Route() returned an error: %v", err) + t.Errorf("LUID.Route() returned an error: %w", err) return } - routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { - t.Errorf("findRoute() returned an error: %v", err) + t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { - t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes)) + t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes)) } var created, deleted bool @@ -519,58 +508,58 @@ func TestAddDeleteRoute(t *testing.T) { } }) if err != nil { - t.Errorf("RegisterRouteChangeCallback() returned an error: %v", err) + t.Errorf("RegisterRouteChangeCallback() returned an error: %w", err) } else { defer cb.Unregister() } - err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric) + err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric) if err != nil { - t.Errorf("LUID.AddRoute() returned an error: %v", err) + t.Errorf("LUID.AddRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == windows.ERROR_NOT_FOUND { t.Error("LUID.Route() returned nil although the route is added successfully.") } else if err != nil { - t.Errorf("LUID.Route() returned an error: %v", err) - } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) { + t.Errorf("LUID.Route() returned an error: %w", err) + } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop { t.Error("LUID.Route() returned a wrong route!") } if !created { t.Errorf("Route handler has not been called on add.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { - t.Errorf("findRoute() returned an error: %v", err) + t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 1 { t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1) - } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) { - t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String()) + } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() { + t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String()) } - err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err != nil { - t.Errorf("LUID.DeleteRoute() returned an error: %v", err) + t.Errorf("LUID.DeleteRoute() returned an error: %w", err) } time.Sleep(500 * time.Millisecond) - _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop) + _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop) if err == nil { t.Error("LUID.Route() returned a route although it is removed successfully.") } else if err != windows.ERROR_NOT_FOUND { - t.Errorf("LUID.Route() returned an error: %v", err) + t.Errorf("LUID.Route() returned an error: %w", err) } if !deleted { t.Errorf("Route handler has not been called on delete.") } - routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination) + routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination) if err != nil { - t.Errorf("findRoute() returned an error: %v", err) + t.Errorf("findRoute() returned an error: %w", err) } else if len(routes) != 0 { t.Errorf("findRoute() returned %d items although the route is deleted successfully.", len(routes)) } @@ -579,7 +568,7 @@ func TestAddDeleteRoute(t *testing.T) { func TestFlushDNS(t *testing.T) { ifc, err := getTestInterface() if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) + t.Errorf("getTestInterface() returned an error: %w", err) return } if !runningElevated() { @@ -589,12 +578,12 @@ func TestFlushDNS(t *testing.T) { prevDNSes, err := ifc.LUID.DNS() if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) + t.Errorf("LUID.DNS() returned an error: %w", err) } - err = ifc.LUID.FlushDNS() + err = ifc.LUID.FlushDNS(syscall.AF_INET) if err != nil { - t.Errorf("LUID.FlushDNS() returned an error: %v", err) + t.Errorf("LUID.FlushDNS() returned an error: %w", err) } ifc, _ = getTestInterface() @@ -602,10 +591,10 @@ func TestFlushDNS(t *testing.T) { n := 0 dns, err := ifc.LUID.DNS() if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) + t.Errorf("LUID.DNS() returned an error: %w", err) } for _, a := range dns { - if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) { + if a.Is4() { n++ } } @@ -613,51 +602,7 @@ func TestFlushDNS(t *testing.T) { t.Errorf("DNSServerAddresses contains %d items, although FlushDNS is executed successfully.", n) } - err = ifc.LUID.SetDNS(prevDNSes) - if err != nil { - t.Errorf("LUID.SetDNS() returned an error: %v.", err) - } -} - -func TestAddDNS(t *testing.T) { - ifc, err := getTestInterface() - if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) - return - } - if !runningElevated() { - t.Errorf("%s requires elevation", t.Name()) - return - } - - prevDNSes, err := ifc.LUID.DNS() - if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) - } - expectedDNSes := append(prevDNSes, dnsesToSet...) - - err = ifc.LUID.AddDNS(dnsesToSet) - if err != nil { - t.Errorf("LUID.AddDNS() returned an error: %v", err) - return - } - - ifc, _ = getTestInterface() - - newDNSes, err := ifc.LUID.DNS() - if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) - } else if len(newDNSes) != len(expectedDNSes) { - t.Errorf("expectedDNSes contains %d items, while DNSServerAddresses contains %d.", len(expectedDNSes), len(newDNSes)) - } else { - for i := range expectedDNSes { - if !expectedDNSes[i].Equal(newDNSes[i]) { - t.Errorf("expectedDNSes[%d] = %s while DNSServerAddresses[%d] = %s.", i, expectedDNSes[i].String(), i, newDNSes[i].String()) - } - } - } - - err = ifc.LUID.SetDNS(prevDNSes) + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) if err != nil { t.Errorf("LUID.SetDNS() returned an error: %v.", err) } @@ -666,7 +611,7 @@ func TestAddDNS(t *testing.T) { func TestSetDNS(t *testing.T) { ifc, err := getTestInterface() if err != nil { - t.Errorf("getTestInterface() returned an error: %v", err) + t.Errorf("getTestInterface() returned an error: %w", err) return } if !runningElevated() { @@ -676,12 +621,12 @@ func TestSetDNS(t *testing.T) { prevDNSes, err := ifc.LUID.DNS() if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) + t.Errorf("LUID.DNS() returned an error: %w", err) } - err = ifc.LUID.SetDNS(dnsesToSet) + err = ifc.LUID.SetDNS(windows.AF_INET, dnsesToSet, nil) if err != nil { - t.Errorf("LUID.SetDNS() returned an error: %v", err) + t.Errorf("LUID.SetDNS() returned an error: %w", err) return } @@ -689,18 +634,18 @@ func TestSetDNS(t *testing.T) { newDNSes, err := ifc.LUID.DNS() if err != nil { - t.Errorf("LUID.DNS() returned an error: %v", err) + t.Errorf("LUID.DNS() returned an error: %w", err) } else if len(newDNSes) != len(dnsesToSet) { t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes)) } else { for i := range dnsesToSet { - if !dnsesToSet[i].Equal(newDNSes[i]) { + if dnsesToSet[i] != newDNSes[i] { t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String()) } } } - err = ifc.LUID.SetDNS(prevDNSes) + err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil) if err != nil { t.Errorf("LUID.SetDNS() returned an error: %v.", err) } @@ -709,7 +654,7 @@ func TestSetDNS(t *testing.T) { func TestAnycastIPAddress(t *testing.T) { _, err := GetAnycastIPAddressTable(windows.AF_UNSPEC) if err != nil { - t.Errorf("GetAnycastIPAddressTable() returned an error: %v", err) + t.Errorf("GetAnycastIPAddressTable() returned an error: %w", err) return } } diff --git a/tunnel/winipcfg/zwinipcfg_windows.go b/tunnel/winipcfg/zwinipcfg_windows.go index 8f37ac26..3a0d8680 100644 --- a/tunnel/winipcfg/zwinipcfg_windows.go +++ b/tunnel/winipcfg/zwinipcfg_windows.go @@ -19,6 +19,7 @@ const ( var ( errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL ) // errnoErr returns common boxed Errno values, to prevent @@ -26,7 +27,7 @@ var ( func errnoErr(e syscall.Errno) error { switch e { case 0: - return nil + return errERROR_EINVAL case errnoERROR_IO_PENDING: return errERROR_IO_PENDING } @@ -39,182 +40,198 @@ func errnoErr(e syscall.Errno) error { var ( modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - procFreeMibTable = modiphlpapi.NewProc("FreeMibTable") - procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry") - procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable") - procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry") - procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry") - procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2") - procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex") - procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid") + procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") procConvertInterfaceGuidToLuid = modiphlpapi.NewProc("ConvertInterfaceGuidToLuid") - procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable") - procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry") - procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry") - procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry") + procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid") + procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid") + procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry") + procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2") procCreateUnicastIpAddressEntry = modiphlpapi.NewProc("CreateUnicastIpAddressEntry") + procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry") + procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2") procDeleteUnicastIpAddressEntry = modiphlpapi.NewProc("DeleteUnicastIpAddressEntry") - procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable") + procFreeMibTable = modiphlpapi.NewProc("FreeMibTable") procGetAnycastIpAddressEntry = modiphlpapi.NewProc("GetAnycastIpAddressEntry") - procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry") - procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry") + procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable") + procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2") + procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex") + procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2") procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2") + procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry") + procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable") + procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry") + procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable") procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry") - procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2") - procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2") - procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2") - procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2") + procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry") + procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry") procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange") - procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange") procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") - procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange") + procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings") + procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2") + procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry") + procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry") ) -func freeMibTable(memory unsafe.Pointer) { - syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0) +func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) { + r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } return } -func initializeIPInterfaceEntry(row *MibIPInterfaceRow) { - syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } return } -func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) +func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) { + r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { - r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getIfEntry2(row *MibIfRow2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0) +func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) +func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0) +func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) +func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) { - syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func freeMibTable(memory unsafe.Pointer) { + syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0) return } -func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getIfEntry2(row *MibIfRow2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) { - r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) +func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) +func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) +func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) { + r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0) if r0 != 0 { ret = syscall.Errno(r0) } @@ -226,82 +243,106 @@ func initializeIPForwardEntry(route *MibIPforwardRow2) { return } -func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) +func initializeIPInterfaceEntry(row *MibIPInterfaceRow) { + syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + return +} + +func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) { + syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) + return +} + +func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) +func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) +func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { + var _p0 uint32 + if initialNotification { + _p0 = 1 + } + r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) { - r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) +func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return + } + r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { - var _p0 uint32 - if initialNotification { - _p0 = 1 - } else { - _p0 = 0 +func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return } - r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings))) if r0 != 0 { ret = syscall.Errno(r0) } return } -func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { - var _p0 uint32 - if initialNotification { - _p0 = 1 - } else { - _p0 = 0 +func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) { + ret = procSetInterfaceDnsSettings.Find() + if ret != nil { + return } - r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) { - var _p0 uint32 - if initialNotification { - _p0 = 1 - } else { - _p0 = 0 +func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) { + r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0) + if r0 != 0 { + ret = syscall.Errno(r0) } - r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0) + return +} + +func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) { + r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } return } -func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) { - r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0) +func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) { + r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0) if r0 != 0 { ret = syscall.Errno(r0) } diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go deleted file mode 100644 index 11d7ab2c..00000000 --- a/tunnel/wintun_test.go +++ /dev/null @@ -1,202 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package tunnel_test - -import ( - "bytes" - "crypto/rand" - "encoding/binary" - "fmt" - "net" - "sync" - "testing" - "time" - - "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun" - - "golang.zx2c4.com/wireguard/windows/elevate" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -func TestWintunOrdering(t *testing.T) { - var tunDevice tun.Device - err := elevate.DoAsSystem(func() error { - var err error - tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500) - return err - }) - if err != nil { - t.Fatal(err) - } - defer tunDevice.Close() - nativeTunDevice := tunDevice.(*tun.NativeTun) - luid := winipcfg.LUID(nativeTunDevice.LUID()) - ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24") - err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}}) - if err != nil { - t.Fatal(err) - } - err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}}) - if err != nil { - t.Fatal(err) - } - var token [32]byte - _, err = rand.Read(token[:]) - if err != nil { - t.Fatal(err) - } - var sockWrite net.Conn - for i := 0; i < 1000; i++ { - sockWrite, err = net.Dial("udp", "10.82.31.5:9999") - if err == nil { - defer sockWrite.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var sockRead *net.UDPConn - for i := 0; i < 1000; i++ { - var listenAddress *net.UDPAddr - listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999") - if err != nil { - continue - } - sockRead, err = net.ListenUDP("udp", listenAddress) - if err == nil { - defer sockRead.Close() - break - } - time.Sleep(time.Millisecond * 100) - } - if err != nil { - t.Fatal(err) - } - var wait sync.WaitGroup - wait.Add(4) - doneSockWrite := false - doneTunWrite := false - fatalErrors := make(chan error, 2) - errors := make(chan error, 2) - go func() { - defer wait.Done() - buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0) - for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ { - binary.LittleEndian.PutUint64(buffer[32:], sendingIndex) - _, err := sockWrite.Write(buffer[:]) - if err != nil { - fatalErrors <- err - } - } - }() - go func() { - defer wait.Done() - packet := [20 + 8 + 32 + 8]byte{ - 0x45, 0, 0, 20 + 8 + 32 + 8, - 0, 0, 0, 0, - 0x80, 0x11, 0, 0, - 10, 82, 31, 5, - 10, 82, 31, 4, - 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0, - } - copy(packet[28:], token[:]) - for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ { - binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex)) - var checksum uint32 - for i := 0; i < 20; i += 2 { - if i != 10 { - checksum += uint32(binary.BigEndian.Uint16(packet[i:])) - } - } - binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff))) - binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex) - n, err := tunDevice.Write(packet[:], 0) - if err != nil { - fatalErrors <- err - } - if n == 0 { - time.Sleep(time.Millisecond * 300) - } - } - }() - const packetsPerTest = 1 << 21 - go func() { - defer func() { - doneSockWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := tunDevice.Read(buffer[:], 0) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - tokenPos := bytes.Index(packet, token[:]) - if tokenPos == -1 || tokenPos+32+8 > len(packet) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - go func() { - defer func() { - doneTunWrite = true - wait.Done() - }() - var expectedIndex uint64 - for i := uint64(0); i < packetsPerTest; { - var buffer [(1 << 16) - 1]byte - bytesRead, err := sockRead.Read(buffer[:]) - if err != nil { - fatalErrors <- err - } - if bytesRead < 0 || bytesRead > len(buffer) { - continue - } - packet := buffer[:bytesRead] - if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) { - continue - } - foundIndex := binary.LittleEndian.Uint64(packet[32:]) - if foundIndex < expectedIndex { - errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex) - } - expectedIndex = foundIndex + 1 - i++ - } - }() - done := make(chan bool, 2) - doneFunc := func() { - wait.Wait() - done <- true - } - defer doneFunc() - go doneFunc() - for { - select { - case err := <-fatalErrors: - t.Fatal(err) - case err := <-errors: - t.Error(err) - case <-done: - return - } - } -} |