aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tunnel/addressconfig.go197
-rw-r--r--tunnel/deterministicguid.go16
-rw-r--r--tunnel/firewall/blocker.go30
-rw-r--r--tunnel/firewall/helpers.go6
-rw-r--r--tunnel/firewall/mksyscall.go2
-rw-r--r--tunnel/firewall/rules.go17
-rw-r--r--tunnel/firewall/syscall_windows.go2
-rw-r--r--tunnel/firewall/types_windows.go8
-rw-r--r--tunnel/firewall/types_windows_32.go (renamed from tunnel/firewall/types_windows_386.go)4
-rw-r--r--tunnel/firewall/types_windows_64.go (renamed from tunnel/firewall/types_windows_amd64.go)4
-rw-r--r--tunnel/firewall/types_windows_test.go29
-rw-r--r--tunnel/firewall/zsyscall_windows.go107
-rw-r--r--tunnel/interfacewatcher.go119
-rw-r--r--tunnel/ipcpermissions.go63
-rw-r--r--tunnel/mtumonitor.go (renamed from tunnel/defaultroutemonitor.go)63
-rw-r--r--tunnel/pitfalls.go177
-rw-r--r--tunnel/scriptrunner.go77
-rw-r--r--tunnel/service.go183
-rw-r--r--tunnel/winipcfg/interface_change_handler.go2
-rw-r--r--tunnel/winipcfg/luid.go242
-rw-r--r--tunnel/winipcfg/mksyscall.go2
-rw-r--r--tunnel/winipcfg/netsh.go85
-rw-r--r--tunnel/winipcfg/route_change_handler.go2
-rw-r--r--tunnel/winipcfg/types.go202
-rw-r--r--tunnel/winipcfg/types_32.go (renamed from tunnel/winipcfg/types_386.go)4
-rw-r--r--tunnel/winipcfg/types_64.go (renamed from tunnel/winipcfg/types_amd64.go)4
-rw-r--r--tunnel/winipcfg/types_test.go3
-rw-r--r--tunnel/winipcfg/types_test_32.go (renamed from tunnel/winipcfg/types_test_386.go)4
-rw-r--r--tunnel/winipcfg/types_test_64.go (renamed from tunnel/winipcfg/types_test_amd64.go)4
-rw-r--r--tunnel/winipcfg/unicast_address_change_handler.go2
-rw-r--r--tunnel/winipcfg/winipcfg.go30
-rw-r--r--tunnel/winipcfg/winipcfg_test.go247
-rw-r--r--tunnel/winipcfg/zwinipcfg_windows.go227
-rw-r--r--tunnel/wintun_test.go202
34 files changed, 1136 insertions, 1230 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index 4be2c36a..a3ce6295 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -1,42 +1,30 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "bytes"
+ "fmt"
"log"
- "net"
- "sort"
+ "net/netip"
+ "time"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
+func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) {
if len(addresses) == 0 {
return
}
- includedInAddresses := func(a net.IPNet) bool {
- // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer!
- for _, addr := range addresses {
- ip := addr.IP
- if ip4 := ip.To4(); ip4 != nil {
- ip = ip4
- }
- mA, _ := addr.Mask.Size()
- mB, _ := a.Mask.Size()
- if bytes.Equal(ip, a.IP) && mA == mB {
- return true
- }
- }
- return false
+ addrHash := make(map[netip.Addr]bool, len(addresses))
+ for i := range addresses {
+ addrHash[addresses[i].Addr()] = true
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
@@ -47,155 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
- ip := address.Address.IP()
- ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
- if includedInAddresses(ipnet) {
- log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName())
- iface.LUID.DeleteIPAddress(ipnet)
+ if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] {
+ prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
+ log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", prefix.String(), iface.FriendlyName())
+ iface.LUID.DeleteIPAddress(prefix)
}
}
}
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
- luid := winipcfg.LUID(tun.LUID())
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error {
+ retryOnFailure := services.StartedAtBoot()
+ tryTimes := 0
+startOver:
+ var err error
+ if tryTimes > 0 {
+ log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ time.Sleep(time.Second)
+ retryOnFailure = retryOnFailure && tryTimes < 15
+ }
+ tryTimes++
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
}
- routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
- addresses := make([]net.IPNet, len(conf.Interface.Addresses))
- var haveV4Address, haveV6Address bool
- for i, addr := range conf.Interface.Addresses {
- addresses[i] = addr.IPNet()
- if addr.Bits() == 32 {
- haveV4Address = true
- } else if addr.Bits() == 128 {
- haveV6Address = true
- }
- }
+ routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount)
foundDefault4 := false
foundDefault6 := false
for _, peer := range conf.Peers {
for _, allowedip := range peer.AllowedIPs {
- if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
- continue
- }
route := winipcfg.RouteData{
- Destination: allowedip.IPNet(),
+ Destination: allowedip.Masked(),
Metric: 0,
}
- if allowedip.Bits() == 32 {
- if allowedip.Cidr == 0 {
+ if allowedip.Addr().Is4() {
+ if allowedip.Bits() == 0 {
foundDefault4 = true
}
- route.NextHop = net.IPv4zero
- } else if allowedip.Bits() == 128 {
- if allowedip.Cidr == 0 {
+ route.NextHop = netip.IPv4Unspecified()
+ } else if allowedip.Addr().Is6() {
+ if allowedip.Bits() == 0 {
foundDefault6 = true
}
- route.NextHop = net.IPv6zero
+ route.NextHop = netip.IPv6Unspecified()
}
- routes = append(routes, route)
+ routes[route] = true
}
}
- err := luid.SetIPAddressesForFamily(family, addresses)
- if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
- cleanupAddressesOnDisconnectedInterfaces(family, addresses)
- err = luid.SetIPAddressesForFamily(family, addresses)
- }
- if err != nil {
- return err
+ deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
+ for route := range routes {
+ r := route
+ deduplicatedRoutes = append(deduplicatedRoutes, &r)
}
- deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
- sort.Slice(routes, func(i, j int) bool {
- return routes[i].Metric < routes[j].Metric ||
- bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 ||
- bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 ||
- bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == -1
- })
- for i := 0; i < len(routes); i++ {
- if i > 0 && routes[i].Metric == routes[i-1].Metric &&
- bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
- bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
- bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
- continue
+ if !conf.Interface.TableOff {
+ err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set routes: %w", err)
}
- deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
- err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
- if err != nil {
- return nil
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
+ cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses)
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ }
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set ips: %w", err)
}
- ipif, err := luid.IPInterface(family)
+ var ipif *winipcfg.MibIPInterfaceRow
+ ipif, err = luid.IPInterface(family)
if err != nil {
return err
}
+ ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ ipif.DadTransmits = 0
+ ipif.ManagedAddressConfigurationSupported = false
+ ipif.OtherStatefulConfigurationSupported = false
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- tun.ForceMTU(int(ipif.NLMTU))
}
- if family == windows.AF_INET {
- if foundDefault4 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- } else if family == windows.AF_INET6 {
- if foundDefault6 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- ipif.DadTransmits = 0
- ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
+ ipif.UseAutomaticMetric = false
+ ipif.Metric = 0
}
err = ipif.Set()
- if err != nil {
- return err
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set metric and MTU: %w", err)
}
- dnsSearch := ""
- if len(conf.Interface.DNSSearch) > 0 {
- dnsSearch = conf.Interface.DNSSearch[0]
- }
- err = luid.SetDNSDomain(dnsSearch)
- if err != nil {
- return nil
+ err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set DNS: %w", err)
}
- if len(conf.Interface.DNSSearch) > 1 {
- log.Printf("Warning: %d DNS search domains were specified, but only one is supported, so the first one (%s) was used.", len(conf.Interface.DNSSearch), dnsSearch)
- }
- err = luid.SetDNSForFamily(family, conf.Interface.DNS)
- if err != nil {
- return err
- }
-
return nil
}
-func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
- restrictAll := false
- if len(conf.Peers) == 1 {
- nextallowedip:
+func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
+ doNotRestrict := true
+ if len(conf.Peers) == 1 && !conf.Interface.TableOff {
for _, allowedip := range conf.Peers[0].AllowedIPs {
- if allowedip.Cidr == 0 {
- for _, b := range allowedip.IP {
- if b != 0 {
- continue nextallowedip
- }
- }
- restrictAll = true
+ if allowedip.Bits() == 0 && allowedip == allowedip.Masked() {
+ doNotRestrict = false
break
}
}
}
- if restrictAll && len(conf.Interface.DNS) == 0 {
- log.Println("Warning: no DNS server specified, despite having an allowed IPs of 0.0.0.0/0 or ::/0. There may be connectivity issues.")
- }
- return firewall.EnableFirewall(tun.LUID(), conf.Interface.DNS, restrictAll)
+ log.Println("Enabling firewall rules")
+ return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
}
diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go
index 8c0f34c0..405d33a3 100644
--- a/tunnel/deterministicguid.go
+++ b/tunnel/deterministicguid.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
@@ -18,8 +18,10 @@ import (
"golang.zx2c4.com/wireguard/windows/conf"
)
-const deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com"
-const fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com"
+const (
+ deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com"
+ fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com"
+)
// Escape hatch for external consumers, not us.
var UseFixedGUIDInsteadOfDeterministic = false
@@ -80,13 +82,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID {
b2Number(len(peer.AllowedIPs))
sortedAllowedIPs := peer.AllowedIPs
sort.Slice(sortedAllowedIPs, func(i, j int) bool {
- if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj {
+ if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj {
return bi < bj
}
- if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr {
- return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr
+ if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() {
+ return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits()
}
- return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0
+ return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0
})
for _, allowedip := range sortedAllowedIPs {
b2String(allowedip.String())
diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go
index 7da391ca..8a4967ba 100644
--- a/tunnel/firewall/blocker.go
+++ b/tunnel/firewall/blocker.go
@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
import (
"errors"
- "net"
+ "net/netip"
"unsafe"
"golang.org/x/sys/windows"
@@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) {
return bo, nil
}
-func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool) error {
+func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error {
if wfpSession != 0 {
return errors.New("The firewall has already been enabled")
}
@@ -122,26 +122,24 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool
return wrapErr(err)
}
- if len(restrictToDNSServers) > 0 {
- err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14)
- if err != nil {
- return wrapErr(err)
+ if !doNotRestrict {
+ if len(restrictToDNSServers) > 0 {
+ err = blockDNS(restrictToDNSServers, session, baseObjects, 15, 14)
+ if err != nil {
+ return wrapErr(err)
+ }
}
- }
- if restrictAll {
err = permitLoopback(session, baseObjects, 13)
if err != nil {
return wrapErr(err)
}
- }
- err = permitTunInterface(session, baseObjects, 12, luid)
- if err != nil {
- return wrapErr(err)
- }
+ err = permitTunInterface(session, baseObjects, 12, luid)
+ if err != nil {
+ return wrapErr(err)
+ }
- if restrictAll {
err = permitDHCPIPv4(session, baseObjects, 12)
if err != nil {
return wrapErr(err)
@@ -164,9 +162,7 @@ func EnableFirewall(luid uint64, restrictToDNSServers []net.IP, restrictAll bool
return wrapErr(err)
}
*/
- }
- if restrictAll {
err = blockAll(session, baseObjects, 0)
if err != nil {
return wrapErr(err)
diff --git a/tunnel/firewall/helpers.go b/tunnel/firewall/helpers.go
index 0c9e8e3f..46e43aa5 100644
--- a/tunnel/firewall/helpers.go
+++ b/tunnel/firewall/helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -66,9 +66,9 @@ func wrapErr(err error) error {
}
_, file, line, ok := runtime.Caller(1)
if !ok {
- return fmt.Errorf("Firewall error at unknown location: %v", err)
+ return fmt.Errorf("Firewall error at unknown location: %w", err)
}
- return fmt.Errorf("Firewall error at %s:%d: %v", file, line, err)
+ return fmt.Errorf("Firewall error at %s:%d: %w", file, line, err)
}
func getCurrentProcessSecurityDescriptor() (*windows.SECURITY_DESCRIPTOR, error) {
diff --git a/tunnel/firewall/mksyscall.go b/tunnel/firewall/mksyscall.go
index 060c3b1c..fc108007 100644
--- a/tunnel/firewall/mksyscall.go
+++ b/tunnel/firewall/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go
index 7bca508b..41632f98 100644
--- a/tunnel/firewall/rules.go
+++ b/tunnel/firewall/rules.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -8,7 +8,7 @@ package firewall
import (
"encoding/binary"
"errors"
- "net"
+ "net/netip"
"runtime"
"unsafe"
@@ -582,7 +582,6 @@ func permitDHCPIPv6(session uintptr, baseObjects *baseObjects, weight uint8) err
}
func permitNdp(session uintptr, baseObjects *baseObjects, weight uint8) error {
-
/* TODO: actually handle the hop limit somehow! The rules should vaguely be:
* - icmpv6 133: must be outgoing, dst must be FF02::2/128, hop limit must be 255
* - icmpv6 134: must be incoming, src must be FE80::/10, hop limit must be 255
@@ -985,7 +984,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error {
}
// Block all DNS traffic except towards specified DNS servers.
-func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error {
+func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow, weightDeny uint8) error {
if weightDeny >= weightAllow {
return errors.New("The allow weight must be greater than the deny weight")
}
@@ -1106,8 +1105,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV4 = append(allowConditionsV4, denyConditions...)
for _, ip := range except {
- ip4 := ip.To4()
- if ip4 == nil {
+ if !ip.Is4() {
continue
}
allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{
@@ -1115,7 +1113,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
- value: uintptr(binary.BigEndian.Uint32(ip4)),
+ value: uintptr(binary.BigEndian.Uint32(ip.AsSlice())),
},
})
}
@@ -1124,11 +1122,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV6 = append(allowConditionsV6, denyConditions...)
for _, ip := range except {
- if ip.To4() != nil {
+ if !ip.Is6() {
continue
}
- var address wtFwpByteArray16
- copy(address.byteArray16[:], ip)
+ address := wtFwpByteArray16{byteArray16: ip.As16()}
allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
diff --git a/tunnel/firewall/syscall_windows.go b/tunnel/firewall/syscall_windows.go
index 1d2696a1..4d8eea42 100644
--- a/tunnel/firewall/syscall_windows.go
+++ b/tunnel/firewall/syscall_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows.go b/tunnel/firewall/types_windows.go
index 9192c023..54e2aad7 100644
--- a/tunnel/firewall/types_windows.go
+++ b/tunnel/firewall/types_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -148,8 +148,10 @@ var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
}
-var cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
-var cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
+var (
+ cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
+ cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
+)
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
diff --git a/tunnel/firewall/types_windows_386.go b/tunnel/firewall/types_windows_32.go
index e8e90663..29ae553a 100644
--- a/tunnel/firewall/types_windows_386.go
+++ b/tunnel/firewall/types_windows_32.go
@@ -1,6 +1,8 @@
+//go:build 386 || arm
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows_amd64.go b/tunnel/firewall/types_windows_64.go
index 13fde97a..a476a745 100644
--- a/tunnel/firewall/types_windows_amd64.go
+++ b/tunnel/firewall/types_windows_64.go
@@ -1,6 +1,8 @@
+//go:build amd64 || arm64
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows_test.go b/tunnel/firewall/types_windows_test.go
index 97cb032c..afa1988f 100644
--- a/tunnel/firewall/types_windows_test.go
+++ b/tunnel/firewall/types_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -11,7 +11,6 @@ import (
)
func TestWtFwpByteBlobSize(t *testing.T) {
-
const actualWtFwpByteBlobSize = unsafe.Sizeof(wtFwpByteBlob{})
if actualWtFwpByteBlobSize != wtFwpByteBlob_Size {
@@ -21,7 +20,6 @@ func TestWtFwpByteBlobSize(t *testing.T) {
}
func TestWtFwpByteBlobOffsets(t *testing.T) {
-
s := wtFwpByteBlob{}
sp := uintptr(unsafe.Pointer(&s))
@@ -34,7 +32,6 @@ func TestWtFwpByteBlobOffsets(t *testing.T) {
}
func TestWtFwpmAction0Size(t *testing.T) {
-
const actualWtFwpmAction0Size = unsafe.Sizeof(wtFwpmAction0{})
if actualWtFwpmAction0Size != wtFwpmAction0_Size {
@@ -44,7 +41,6 @@ func TestWtFwpmAction0Size(t *testing.T) {
}
func TestWtFwpmAction0Offsets(t *testing.T) {
-
s := wtFwpmAction0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -58,7 +54,6 @@ func TestWtFwpmAction0Offsets(t *testing.T) {
}
func TestWtFwpBitmapArray64Size(t *testing.T) {
-
const actualWtFwpBitmapArray64Size = unsafe.Sizeof(wtFwpBitmapArray64{})
if actualWtFwpBitmapArray64Size != wtFwpBitmapArray64_Size {
@@ -68,7 +63,6 @@ func TestWtFwpBitmapArray64Size(t *testing.T) {
}
func TestWtFwpByteArray6Size(t *testing.T) {
-
const actualWtFwpByteArray6Size = unsafe.Sizeof(wtFwpByteArray6{})
if actualWtFwpByteArray6Size != wtFwpByteArray6_Size {
@@ -78,7 +72,6 @@ func TestWtFwpByteArray6Size(t *testing.T) {
}
func TestWtFwpByteArray16Size(t *testing.T) {
-
const actualWtFwpByteArray16Size = unsafe.Sizeof(wtFwpByteArray16{})
if actualWtFwpByteArray16Size != wtFwpByteArray16_Size {
@@ -88,7 +81,6 @@ func TestWtFwpByteArray16Size(t *testing.T) {
}
func TestWtFwpConditionValue0Size(t *testing.T) {
-
const actualWtFwpConditionValue0Size = unsafe.Sizeof(wtFwpConditionValue0{})
if actualWtFwpConditionValue0Size != wtFwpConditionValue0_Size {
@@ -98,7 +90,6 @@ func TestWtFwpConditionValue0Size(t *testing.T) {
}
func TestWtFwpConditionValue0Offsets(t *testing.T) {
-
s := wtFwpConditionValue0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -111,7 +102,6 @@ func TestWtFwpConditionValue0Offsets(t *testing.T) {
}
func TestWtFwpV4AddrAndMaskSize(t *testing.T) {
-
const actualWtFwpV4AddrAndMaskSize = unsafe.Sizeof(wtFwpV4AddrAndMask{})
if actualWtFwpV4AddrAndMaskSize != wtFwpV4AddrAndMask_Size {
@@ -121,7 +111,6 @@ func TestWtFwpV4AddrAndMaskSize(t *testing.T) {
}
func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) {
-
s := wtFwpV4AddrAndMask{}
sp := uintptr(unsafe.Pointer(&s))
@@ -135,7 +124,6 @@ func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) {
}
func TestWtFwpV6AddrAndMaskSize(t *testing.T) {
-
const actualWtFwpV6AddrAndMaskSize = unsafe.Sizeof(wtFwpV6AddrAndMask{})
if actualWtFwpV6AddrAndMaskSize != wtFwpV6AddrAndMask_Size {
@@ -145,7 +133,6 @@ func TestWtFwpV6AddrAndMaskSize(t *testing.T) {
}
func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) {
-
s := wtFwpV6AddrAndMask{}
sp := uintptr(unsafe.Pointer(&s))
@@ -159,7 +146,6 @@ func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) {
}
func TestWtFwpValue0Size(t *testing.T) {
-
const actualWtFwpValue0Size = unsafe.Sizeof(wtFwpValue0{})
if actualWtFwpValue0Size != wtFwpValue0_Size {
@@ -168,7 +154,6 @@ func TestWtFwpValue0Size(t *testing.T) {
}
func TestWtFwpValue0Offsets(t *testing.T) {
-
s := wtFwpValue0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -181,7 +166,6 @@ func TestWtFwpValue0Offsets(t *testing.T) {
}
func TestWtFwpmDisplayData0Size(t *testing.T) {
-
const actualWtFwpmDisplayData0Size = unsafe.Sizeof(wtFwpmDisplayData0{})
if actualWtFwpmDisplayData0Size != wtFwpmDisplayData0_Size {
@@ -191,7 +175,6 @@ func TestWtFwpmDisplayData0Size(t *testing.T) {
}
func TestWtFwpmDisplayData0Offsets(t *testing.T) {
-
s := wtFwpmDisplayData0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -205,7 +188,6 @@ func TestWtFwpmDisplayData0Offsets(t *testing.T) {
}
func TestWtFwpmFilterCondition0Size(t *testing.T) {
-
const actualWtFwpmFilterCondition0Size = unsafe.Sizeof(wtFwpmFilterCondition0{})
if actualWtFwpmFilterCondition0Size != wtFwpmFilterCondition0_Size {
@@ -215,7 +197,6 @@ func TestWtFwpmFilterCondition0Size(t *testing.T) {
}
func TestWtFwpmFilterCondition0Offsets(t *testing.T) {
-
s := wtFwpmFilterCondition0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -237,7 +218,6 @@ func TestWtFwpmFilterCondition0Offsets(t *testing.T) {
}
func TestWtFwpmFilter0Size(t *testing.T) {
-
const actualWtFwpmFilter0Size = unsafe.Sizeof(wtFwpmFilter0{})
if actualWtFwpmFilter0Size != wtFwpmFilter0_Size {
@@ -247,7 +227,6 @@ func TestWtFwpmFilter0Size(t *testing.T) {
}
func TestWtFwpmFilter0Offsets(t *testing.T) {
-
s := wtFwpmFilter0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -364,7 +343,6 @@ func TestWtFwpmFilter0Offsets(t *testing.T) {
}
func TestWtFwpProvider0Size(t *testing.T) {
-
const actualWtFwpProvider0Size = unsafe.Sizeof(wtFwpProvider0{})
if actualWtFwpProvider0Size != wtFwpProvider0_Size {
@@ -374,7 +352,6 @@ func TestWtFwpProvider0Size(t *testing.T) {
}
func TestWtFwpProvider0Offsets(t *testing.T) {
-
s := wtFwpProvider0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -412,7 +389,6 @@ func TestWtFwpProvider0Offsets(t *testing.T) {
}
func TestWtFwpmSession0Size(t *testing.T) {
-
const actualWtFwpmSession0Size = unsafe.Sizeof(wtFwpmSession0{})
if actualWtFwpmSession0Size != wtFwpmSession0_Size {
@@ -422,7 +398,6 @@ func TestWtFwpmSession0Size(t *testing.T) {
}
func TestWtFwpmSession0Offsets(t *testing.T) {
-
s := wtFwpmSession0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -482,7 +457,6 @@ func TestWtFwpmSession0Offsets(t *testing.T) {
}
func TestWtFwpmSublayer0Size(t *testing.T) {
-
const actualWtFwpmSublayer0Size = unsafe.Sizeof(wtFwpmSublayer0{})
if actualWtFwpmSublayer0Size != wtFwpmSublayer0_Size {
@@ -492,7 +466,6 @@ func TestWtFwpmSublayer0Size(t *testing.T) {
}
func TestWtFwpmSublayer0Offsets(t *testing.T) {
-
s := wtFwpmSublayer0{}
sp := uintptr(unsafe.Pointer(&s))
diff --git a/tunnel/firewall/zsyscall_windows.go b/tunnel/firewall/zsyscall_windows.go
index 846d4ae8..9e60132d 100644
--- a/tunnel/firewall/zsyscall_windows.go
+++ b/tunnel/firewall/zsyscall_windows.go
@@ -19,6 +19,7 @@ const (
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+ errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
@@ -26,7 +27,7 @@ var (
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
- return nil
+ return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
@@ -39,62 +40,38 @@ func errnoErr(e syscall.Errno) error {
var (
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
- procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
- procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
- procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
- procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
+ procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
+ procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
+ procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
+ procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
+ procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
+ procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
- procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
- procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
)
-func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
- r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
- if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
func fwpmEngineClose0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
-func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
- r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
+func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
+ r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
-func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
- r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
+func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
+ r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
@@ -104,38 +81,26 @@ func fwpmFreeMemory0(p unsafe.Pointer) {
return
}
-func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
- r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
+func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
+ r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
-func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
- r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
+func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
-func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
- r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
+func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
@@ -143,23 +108,23 @@ func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
-func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
- r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
+func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
+ r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
+ if r1 != 0 {
+ err = errnoErr(e1)
+ }
+ return
+}
+
+func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
+ r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
if r1 != 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
+ err = errnoErr(e1)
}
return
}
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index 1f632725..a831d06e 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -1,20 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
+ "errors"
+ "fmt"
"log"
"sync"
+ "time"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -24,63 +24,30 @@ type interfaceWatcherError struct {
serviceError services.Error
err error
}
+
type interfaceWatcherEvent struct {
luid winipcfg.LUID
family winipcfg.AddressFamily
}
+
type interfaceWatcher struct {
- errors chan interfaceWatcherError
+ errors chan interfaceWatcherError
+ started chan winipcfg.AddressFamily
- device *device.Device
- conf *conf.Config
- tun *tun.NativeTun
+ conf *conf.Config
+ adapter *driver.Adapter
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
changeCallbacks4 []winipcfg.ChangeCallback
changeCallbacks6 []winipcfg.ChangeCallback
storedEvents []interfaceWatcherEvent
-}
-
-func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool {
- var (
- foundV401 bool
- foundV41281 bool
- foundV600001 bool
- foundV680001 bool
- foundV400 bool
- foundV600 bool
- v40 = [4]byte{}
- v60 = [16]byte{}
- v48 = [4]byte{0x80}
- v68 = [16]byte{0x80}
- )
- for _, peer := range peers {
- for _, allowedip := range peer.AllowedIPs {
- if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) {
- foundV680001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV401 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) {
- foundV41281 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV400 = true
- }
- }
- }
- if family == windows.AF_INET {
- return foundV400 || (foundV401 && foundV41281)
- } else if family == windows.AF_INET6 {
- return foundV600 || (foundV600001 && foundV680001)
- }
- return false
+ watchdog *time.Timer
}
func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
+ iw.watchdog.Stop()
var changeCallbacks *[]winipcfg.ChangeCallback
var ipversion string
if family == windows.AF_INET {
@@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.device, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
}
+ evaluateDynamicPitfalls(family, iw.conf, iw.luid)
+
+ iw.started <- family
}
func watchInterface() (*interfaceWatcher, error) {
iw := &interfaceWatcher{
- errors: make(chan interfaceWatcherError, 2),
+ errors: make(chan interfaceWatcherError, 2),
+ started: make(chan winipcfg.AddressFamily, 4),
}
+ iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() {
+ iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")}
+ })
+ iw.watchdog.Stop()
var err error
iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
iw.setupMutex.Lock()
@@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
+
+ if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
+ log.Println("Reinitializing adapter configuration")
+ err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
+ }
+ err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
+ }
+ }
})
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("unable to register interface change callback: %w", err)
}
return iw, nil
}
-func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
+ iw.watchdog.Reset(time.Minute)
- iw.device, iw.conf, iw.tun = device, conf, tun
+ iw.adapter, iw.conf, iw.luid = adapter, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(device *device.Device, conf *conf.Config,
func (iw *interfaceWatcher) Destroy() {
iw.setupMutex.Lock()
+ iw.watchdog.Stop()
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,15 +177,15 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
+ luid.FlushDNS(windows.AF_INET)
luid.FlushRoutes(windows.AF_INET6)
luid.FlushIPAddresses(windows.AF_INET6)
- luid.FlushDNS()
+ luid.FlushDNS(windows.AF_INET6)
}
iw.setupMutex.Unlock()
}
diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go
deleted file mode 100644
index 613d0283..00000000
--- a/tunnel/ipcpermissions.go
+++ /dev/null
@@ -1,63 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel
-
-import (
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/ipc"
-
- "golang.zx2c4.com/wireguard/windows/conf"
-)
-
-func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error {
- if conf.PathIsEncrypted(filename) {
- return nil
- }
-
- fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
- if err != nil {
- return err
- }
- fileOwner, _, err := fileSd.Owner()
- if err != nil {
- return err
- }
- if fileOwner.IsWellKnown(windows.WinLocalSystemSid) {
- return nil
- }
- additionalEntries := []windows.EXPLICIT_ACCESS{{
- AccessPermissions: windows.GENERIC_ALL,
- AccessMode: windows.GRANT_ACCESS,
- Trustee: windows.TRUSTEE{
- TrusteeForm: windows.TRUSTEE_IS_SID,
- TrusteeType: windows.TRUSTEE_IS_USER,
- TrusteeValue: windows.TrusteeValueFromSID(fileOwner),
- },
- }}
-
- sd, err := ipc.UAPISecurityDescriptor.ToAbsolute()
- if err != nil {
- return err
- }
- dacl, defaulted, _ := sd.DACL()
-
- newDacl, err := windows.ACLFromEntries(additionalEntries, dacl)
- if err != nil {
- return err
- }
- err = sd.SetDACL(newDacl, true, defaulted)
- if err != nil {
- return err
- }
- sd, err = sd.ToSelfRelative()
- if err != nil {
- return err
- }
- ipc.UAPISecurityDescriptor = sd
-
- return nil
-}
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/mtumonitor.go
index 3af9042c..c07823a2 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/mtumonitor.go
@@ -1,30 +1,23 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "log"
- "sync"
- "time"
-
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error {
+func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error {
r, err := winipcfg.GetIPForwardTable2(family)
if err != nil {
return err
}
lowestMetric := ^uint32(0)
- index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
- luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
+ index := uint32(0)
+ luid := winipcfg.LUID(0)
for i := range r {
if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
continue
@@ -50,40 +43,24 @@ func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLU
}
*lastLUID = luid
*lastIndex = index
- blackhole := blackholeWhenLoop && index == 0
- bind, _ := device.Bind().(conn.BindSocketToInterface)
- if bind == nil {
- return nil
- }
- if family == windows.AF_INET {
- log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole)
- return bind.BindSocketToInterface4(index, blackhole)
- } else if family == windows.AF_INET6 {
- log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole)
- return bind.BindSocketToInterface6(index, blackhole)
- }
return nil
}
-func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) {
+func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
var minMTU uint32
if family == windows.AF_INET {
minMTU = 576
} else if family == windows.AF_INET6 {
minMTU = 1280
}
- ourLUID := winipcfg.LUID(tun.LUID())
lastLUID := winipcfg.LUID(0)
lastIndex := ^uint32(0)
lastMTU := uint32(0)
doIt := func() error {
- err := bindSocketRoute(family, device, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop)
+ err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex)
if err != nil {
return err
}
- if !autoMTU {
- return nil
- }
mtu := uint32(0)
if lastLUID != 0 {
iface, err := lastLUID.Interface()
@@ -107,7 +84,6 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device,
if err != nil {
return err
}
- tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad.
lastMTU = mtu
}
return nil
@@ -116,32 +92,9 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device,
if err != nil {
return nil, err
}
-
- firstBurst := time.Time{}
- burstMutex := sync.Mutex{}
- burstTimer := time.AfterFunc(time.Hour*200, func() {
- burstMutex.Lock()
- firstBurst = time.Time{}
- doIt()
- burstMutex.Unlock()
- })
- burstTimer.Stop()
- bump := func() {
- burstMutex.Lock()
- burstTimer.Reset(time.Millisecond * 150)
- if firstBurst.IsZero() {
- firstBurst = time.Now()
- } else if time.Since(firstBurst) > time.Second*2 {
- firstBurst = time.Time{}
- burstTimer.Stop()
- doIt()
- }
- burstMutex.Unlock()
- }
-
cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
if route != nil && route.DestinationPrefix.PrefixLength == 0 {
- bump()
+ doIt()
}
})
if err != nil {
@@ -149,7 +102,7 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, device *device.Device,
}
cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
if notificationType == winipcfg.MibParameterNotification {
- bump()
+ doIt()
}
})
if err != nil {
diff --git a/tunnel/pitfalls.go b/tunnel/pitfalls.go
new file mode 100644
index 00000000..fdef6eb2
--- /dev/null
+++ b/tunnel/pitfalls.go
@@ -0,0 +1,177 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "log"
+ "net/netip"
+ "strings"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc/mgr"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+func evaluateStaticPitfalls() {
+ go func() {
+ pitfallDnsCacheDisabled()
+ pitfallVirtioNetworkDriver()
+ }()
+}
+
+func evaluateDynamicPitfalls(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) {
+ go func() {
+ pitfallWeakHostSend(family, conf, luid)
+ }()
+}
+
+func pitfallDnsCacheDisabled() {
+ scm, err := mgr.Connect()
+ if err != nil {
+ return
+ }
+ defer scm.Disconnect()
+ svc := mgr.Service{Name: "dnscache"}
+ svc.Handle, err = windows.OpenService(scm.Handle, windows.StringToUTF16Ptr(svc.Name), windows.SERVICE_QUERY_CONFIG)
+ if err != nil {
+ return
+ }
+ defer svc.Close()
+ cfg, err := svc.Config()
+ if err != nil {
+ return
+ }
+ if cfg.StartType != mgr.StartDisabled {
+ return
+ }
+
+ log.Printf("Warning: the %q (dnscache) service is disabled; please re-enable it", cfg.DisplayName)
+}
+
+func pitfallVirtioNetworkDriver() {
+ var modules []windows.RTL_PROCESS_MODULE_INFORMATION
+ for bufferSize := uint32(128 * 1024); ; {
+ moduleBuffer := make([]byte, bufferSize)
+ err := windows.NtQuerySystemInformation(windows.SystemModuleInformation, unsafe.Pointer(&moduleBuffer[0]), bufferSize, &bufferSize)
+ switch err {
+ case windows.STATUS_INFO_LENGTH_MISMATCH:
+ continue
+ case nil:
+ break
+ default:
+ return
+ }
+ mods := (*windows.RTL_PROCESS_MODULES)(unsafe.Pointer(&moduleBuffer[0]))
+ modules = unsafe.Slice(&mods.Modules[0], mods.NumberOfModules)
+ break
+ }
+ for i := range modules {
+ if !strings.EqualFold(windows.ByteSliceToString(modules[i].FullPathName[modules[i].OffsetToFileName:]), "netkvm.sys") {
+ continue
+ }
+ driverPath := `\\?\GLOBALROOT` + windows.ByteSliceToString(modules[i].FullPathName[:])
+ var zero windows.Handle
+ infoSize, err := windows.GetFileVersionInfoSize(driverPath, &zero)
+ if err != nil {
+ return
+ }
+ versionInfo := make([]byte, infoSize)
+ err = windows.GetFileVersionInfo(driverPath, 0, infoSize, unsafe.Pointer(&versionInfo[0]))
+ if err != nil {
+ return
+ }
+ var fixedInfo *windows.VS_FIXEDFILEINFO
+ fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
+ err = windows.VerQueryValue(unsafe.Pointer(&versionInfo[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
+ if err != nil {
+ return
+ }
+ const minimumPlausibleVersion = 40 << 48
+ const minimumGoodVersion = (100 << 48) | (85 << 32) | (104 << 16) | (20800 << 0)
+ version := (uint64(fixedInfo.FileVersionMS) << 32) | uint64(fixedInfo.FileVersionLS)
+ if version >= minimumGoodVersion || version < minimumPlausibleVersion {
+ return
+ }
+ log.Println("Warning: the VirtIO network driver (NetKVM) is out of date and may cause known problems; please update to v100.85.104.20800 or later")
+ return
+ }
+}
+
+func pitfallWeakHostSend(family winipcfg.AddressFamily, conf *conf.Config, ourLUID winipcfg.LUID) {
+ routingTable, err := winipcfg.GetIPForwardTable2(family)
+ if err != nil {
+ return
+ }
+ type endpointRoute struct {
+ addr netip.Addr
+ name string
+ lowestMetric uint32
+ highestCIDR uint8
+ weakHostSend bool
+ finalIsOurs bool
+ }
+ endpoints := make([]endpointRoute, 0, len(conf.Peers))
+ for _, peer := range conf.Peers {
+ addr, err := netip.ParseAddr(peer.Endpoint.Host)
+ if err != nil || (addr.Is4() && family != windows.AF_INET) || (addr.Is6() && family != windows.AF_INET6) {
+ continue
+ }
+ endpoints = append(endpoints, endpointRoute{addr: addr, lowestMetric: ^uint32(0)})
+ }
+ for i := range routingTable {
+ var (
+ ifrow *winipcfg.MibIfRow2
+ ifacerow *winipcfg.MibIPInterfaceRow
+ metric uint32
+ )
+ for j := range endpoints {
+ r, e := &routingTable[i], &endpoints[j]
+ if r.DestinationPrefix.PrefixLength < e.highestCIDR {
+ continue
+ }
+ if !r.DestinationPrefix.Prefix().Contains(e.addr) {
+ continue
+ }
+ if ifrow == nil {
+ ifrow, err = r.InterfaceLUID.Interface()
+ if err != nil {
+ continue
+ }
+ }
+ if ifrow.OperStatus != winipcfg.IfOperStatusUp {
+ continue
+ }
+ if ifacerow == nil {
+ ifacerow, err = r.InterfaceLUID.IPInterface(family)
+ if err != nil {
+ continue
+ }
+ metric = r.Metric + ifacerow.Metric
+ }
+ if r.DestinationPrefix.PrefixLength == e.highestCIDR && metric > e.lowestMetric {
+ continue
+ }
+ e.lowestMetric = metric
+ e.highestCIDR = r.DestinationPrefix.PrefixLength
+ e.finalIsOurs = r.InterfaceLUID == ourLUID
+ if !e.finalIsOurs {
+ e.name = ifrow.Alias()
+ e.weakHostSend = ifacerow.ForwardingEnabled || ifacerow.WeakHostSend
+ }
+ }
+ }
+ problematicInterfaces := make(map[string]bool, len(endpoints))
+ for _, e := range endpoints {
+ if e.weakHostSend && e.finalIsOurs {
+ problematicInterfaces[e.name] = true
+ }
+ }
+ for iface := range problematicInterfaces {
+ log.Printf("Warning: the %q interface has Forwarding/WeakHostSend enabled, which will cause routing loops", iface)
+ }
+}
diff --git a/tunnel/scriptrunner.go b/tunnel/scriptrunner.go
new file mode 100644
index 00000000..eb97d98d
--- /dev/null
+++ b/tunnel/scriptrunner.go
@@ -0,0 +1,77 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "bufio"
+ "fmt"
+ "log"
+ "os"
+ "path/filepath"
+ "syscall"
+
+ "golang.org/x/sys/windows"
+
+ "golang.zx2c4.com/wireguard/windows/conf"
+)
+
+func runScriptCommand(command, interfaceName string) error {
+ if len(command) == 0 {
+ return nil
+ }
+ if !conf.AdminBool("DangerousScriptExecution") {
+ log.Printf("Skipping execution of script, because dangerous script execution is safely disabled: %#q", command)
+ return nil
+ }
+ log.Printf("Executing: %#q", command)
+ comspec, _ := os.LookupEnv("COMSPEC")
+ if len(comspec) == 0 {
+ system32, err := windows.GetSystemDirectory()
+ if err != nil {
+ return err
+ }
+ comspec = filepath.Join(system32, "cmd.exe")
+ }
+
+ devNull, err := os.OpenFile(os.DevNull, os.O_RDWR, 0)
+ if err != nil {
+ return err
+ }
+ defer devNull.Close()
+ reader, writer, err := os.Pipe()
+ if err != nil {
+ return err
+ }
+ process, err := os.StartProcess(comspec, nil /* CmdLine below */, &os.ProcAttr{
+ Files: []*os.File{devNull, writer, writer},
+ Env: append(os.Environ(), "WIREGUARD_TUNNEL_NAME="+interfaceName),
+ Sys: &syscall.SysProcAttr{
+ HideWindow: true,
+ CmdLine: fmt.Sprintf("cmd /c %s", command),
+ },
+ })
+ writer.Close()
+ if err != nil {
+ reader.Close()
+ return err
+ }
+ go func() {
+ scanner := bufio.NewScanner(reader)
+ for scanner.Scan() {
+ log.Printf("cmd> %s", scanner.Text())
+ }
+ }()
+ state, err := process.Wait()
+ reader.Close()
+ if err != nil {
+ return err
+ }
+ if state.ExitCode() == 0 {
+ return nil
+ }
+ log.Printf("Command error exit status: %d", state.ExitCode())
+ return windows.ERROR_GENERIC_COMMAND_FAILED
+}
diff --git a/tunnel/service.go b/tunnel/service.go
index e535894b..a56ed1f3 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -1,33 +1,27 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "bufio"
"bytes"
"fmt"
"log"
- "net"
"os"
"runtime"
- "runtime/debug"
- "strings"
"time"
+ "golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/ipc"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
- "golang.zx2c4.com/wireguard/windows/version"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
type tunnelService struct {
@@ -35,12 +29,13 @@ type tunnelService struct {
}
func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
- changes <- svc.Status{State: svc.StartPending}
+ serviceState := svc.StartPending
+ changes <- svc.Status{State: serviceState}
- var dev *device.Device
- var uapi net.Listener
var watcher *interfaceWatcher
- var nativeTun *tun.NativeTun
+ var adapter *driver.Adapter
+ var luid winipcfg.LUID
+ var config *conf.Config
var err error
serviceError := services.ErrorSuccess
@@ -50,7 +45,8 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
if logErr != nil {
log.Println(logErr)
}
- changes <- svc.Status{State: svc.StopPending}
+ serviceState = svc.StopPending
+ changes <- svc.Status{State: serviceState}
stopIt := make(chan bool, 1)
go func() {
@@ -84,65 +80,63 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
}()
+ if logErr == nil && adapter != nil && config != nil {
+ logErr = runScriptCommand(config.Interface.PreDown, config.Name)
+ }
if watcher != nil {
watcher.Destroy()
}
- if uapi != nil {
- uapi.Close()
+ if adapter != nil {
+ adapter.Close()
}
- if dev != nil {
- dev.Close()
+ if logErr == nil && adapter != nil && config != nil {
+ _ = runScriptCommand(config.Interface.PostDown, config.Name)
}
stopIt <- true
log.Println("Shutting down")
}()
- err = ringlogger.InitGlobalLogger("TUN")
+ var logFile string
+ logFile, err = conf.LogFile(true)
if err != nil {
serviceError = services.ErrorRingloggerOpen
return
}
- defer func() {
- if x := recover(); x != nil {
- for _, line := range append([]string{fmt.Sprint(x)}, strings.Split(string(debug.Stack()), "\n")...) {
- if len(strings.TrimSpace(line)) > 0 {
- log.Println(line)
- }
- }
- panic(x)
- }
- }()
-
- conf, err := conf.LoadFromPath(service.Path)
+ err = ringlogger.InitGlobalLogger(logFile, "TUN")
if err != nil {
- serviceError = services.ErrorLoadConfiguration
+ serviceError = services.ErrorRingloggerOpen
return
}
- conf.DeduplicateNetworkEntries()
- err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path)
+
+ config, err = conf.LoadFromPath(service.Path)
if err != nil {
serviceError = services.ErrorLoadConfiguration
return
}
+ config.DeduplicateNetworkEntries()
- logPrefix := fmt.Sprintf("[%s] ", conf.Name)
- log.SetPrefix(logPrefix)
+ log.SetPrefix(fmt.Sprintf("[%s] ", config.Name))
- log.Println("Starting", version.UserAgent())
+ services.PrintStarting()
- if m, err := mgr.Connect(); err == nil {
- if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
- /* If we don't do this, then the Wintun installation will block forever, because
- * installing a Wintun device starts a service too. Apparently at boot time, Windows
- * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce
- * that we're running before starting additional services.
- */
- log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
- changes <- svc.Status{State: svc.Running}
+ if services.StartedAtBoot() {
+ if m, err := mgr.Connect(); err == nil {
+ if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
+ /* If we don't do this, then the driver installation will block forever, because
+ * installing a network adapter starts the driver service too. Apparently at boot time,
+ * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't
+ * announce that we're running before starting additional services.
+ */
+ log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
+ serviceState = svc.Running
+ changes <- svc.Status{State: serviceState}
+ }
+ m.Disconnect()
}
- m.Disconnect()
}
+ evaluateStaticPitfalls()
+
log.Println("Watching network interfaces")
watcher, err = watchInterface()
if err != nil {
@@ -151,28 +145,49 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
log.Println("Resolving DNS names")
- uapiConf, err := conf.ToUAPI()
+ err = config.ResolveEndpoints()
if err != nil {
serviceError = services.ErrorDNSLookup
return
}
- log.Println("Creating Wintun interface")
- wintun, err := tun.CreateTUNWithRequestedGUID(conf.Name, deterministicGUID(conf), 0)
+ log.Println("Creating network adapter")
+ for i := 0; i < 15; i++ {
+ if i > 0 {
+ time.Sleep(time.Second)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ }
+ adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config))
+ if err == nil || !services.StartedAtBoot() {
+ break
+ }
+ }
if err != nil {
- serviceError = services.ErrorCreateWintun
+ err = fmt.Errorf("Error creating adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
return
}
- nativeTun = wintun.(*tun.NativeTun)
- wintunVersion, ndisVersion, err := nativeTun.Version()
+ luid = adapter.LUID()
+ driverVersion, err := driver.RunningVersion()
if err != nil {
- log.Printf("Warning: unable to determine Wintun version: %v", err)
+ log.Printf("Warning: unable to determine driver version: %v", err)
} else {
- log.Printf("Using Wintun/%s (NDIS %s)", wintunVersion, ndisVersion)
+ log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
+ err = adapter.SetLogging(driver.AdapterLogOn)
+ if err != nil {
+ err = fmt.Errorf("Error enabling adapter logging: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
}
- log.Println("Enabling firewall rules")
- err = enableFirewall(conf, nativeTun)
+ err = runScriptCommand(config.Interface.PreUp, config.Name)
+ if err != nil {
+ serviceError = services.ErrorRunScript
+ return
+ }
+
+ err = enableFirewall(config, luid)
if err != nil {
serviceError = services.ErrorFirewall
return
@@ -185,43 +200,28 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- log.Println("Creating interface instance")
- logOutput := log.New(ringlogger.Global, logPrefix, 0)
- logger := &device.Logger{logOutput, logOutput, logOutput}
- dev = device.NewDevice(wintun, logger)
-
log.Println("Setting interface configuration")
- uapi, err = ipc.UAPIListen(conf.Name)
+ err = adapter.SetConfiguration(config.ToDriverConfiguration())
if err != nil {
- serviceError = services.ErrorUAPIListen
+ serviceError = services.ErrorDeviceSetConfig
return
}
- ipcErr := dev.IpcSetOperation(bufio.NewReader(strings.NewReader(uapiConf)))
- if ipcErr != nil {
- err = ipcErr
- serviceError = services.ErrorDeviceSetConfig
+ err = adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ serviceError = services.ErrorDeviceBringUp
return
}
+ watcher.Configure(adapter, config, luid)
- log.Println("Bringing peers up")
- dev.Up()
-
- watcher.Configure(dev, conf, nativeTun)
-
- log.Println("Listening for UAPI requests")
- go func() {
- for {
- conn, err := uapi.Accept()
- if err != nil {
- continue
- }
- go dev.IpcHandle(conn)
- }
- }()
+ err = runScriptCommand(config.Interface.PostUp, config.Name)
+ if err != nil {
+ serviceError = services.ErrorRunScript
+ return
+ }
- changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
- log.Println("Startup complete")
+ changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
+ var started bool
for {
select {
case c := <-r:
@@ -233,8 +233,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
default:
log.Printf("Unexpected service control request #%d\n", c)
}
- case <-dev.Wait():
- return
+ case <-watcher.started:
+ if !started {
+ serviceState = svc.Running
+ changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
+ log.Println("Startup complete")
+ started = true
+ }
case e := <-watcher.errors:
serviceError, err = e.serviceError, e.err
return
@@ -247,7 +252,7 @@ func Run(confPath string) error {
if err != nil {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
diff --git a/tunnel/winipcfg/interface_change_handler.go b/tunnel/winipcfg/interface_change_handler.go
index 9406c18a..af29801a 100644
--- a/tunnel/winipcfg/interface_change_handler.go
+++ b/tunnel/winipcfg/interface_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go
index 8f5ba61b..0c898b89 100644
--- a/tunnel/winipcfg/luid.go
+++ b/tunnel/winipcfg/luid.go
@@ -1,17 +1,16 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"errors"
- "fmt"
- "net"
+ "net/netip"
+ "strings"
"golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
)
// LUID represents a network interface.
@@ -64,12 +63,23 @@ func LUIDFromGUID(guid *windows.GUID) (LUID, error) {
return luid, nil
}
+// LUIDFromIndex function converts a local index for a network interface to the locally unique identifier (LUID) for the interface.
+// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-convertinterfaceindextoluid
+func LUIDFromIndex(index uint32) (LUID, error) {
+ var luid LUID
+ err := convertInterfaceIndexToLUID(index, &luid)
+ if err != nil {
+ return 0, err
+ }
+ return luid, nil
+}
+
// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry)
-func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
+func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) {
row := &MibUnicastIPAddressRow{InterfaceLUID: luid}
- err := row.Address.SetIP(ip, 0)
+ err := row.Address.SetAddr(addr)
if err != nil {
return nil, err
}
@@ -84,22 +94,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddress(address net.IPNet) error {
+func (luid LUID) AddIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
- err := row.Address.SetIP(address.IP, 0)
+ row.DadState = DadStatePreferred
+ row.ValidLifetime = 0xffffffff
+ row.PreferredLifetime = 0xffffffff
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Create()
}
// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error {
for i := range addresses {
err := luid.AddIPAddress(addresses[i])
if err != nil {
@@ -110,7 +122,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddresses method sets new unicast IP addresses to the interface.
-func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(windows.AF_UNSPEC)
if err != nil {
return err
@@ -119,16 +131,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
-func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
+func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(family)
if err != nil {
return err
}
for i := range addresses {
- asV4 := addresses[i].IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !addresses[i].Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddIPAddress(addresses[i])
@@ -141,17 +152,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
-func (luid LUID) DeleteIPAddress(address net.IPNet) error {
+func (luid LUID) DeleteIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
- err := row.Address.SetIP(address.IP, 0)
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
// Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry().
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Delete()
}
@@ -175,15 +185,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error {
// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2).
// NOTE: If the corresponding route isn't found, the method will return error.
-func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) {
+func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ row.ValidLifetime = 0xffffffff
+ row.PreferredLifetime = 0xffffffff
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return nil, err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return nil, err
}
@@ -197,15 +209,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2
// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature.
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2)
-func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error {
+func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -240,10 +252,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
return err
}
for _, rd := range routesData {
- asV4 := rd.Destination.IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !rd.Destination.Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
@@ -256,15 +267,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
-func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
+func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -297,17 +308,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error {
}
// DNS method returns all DNS server addresses associated with the adapter.
-func (luid LUID) DNS() ([]net.IP, error) {
+func (luid LUID) DNS() ([]netip.Addr, error) {
addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
return nil, err
}
- r := make([]net.IP, 0, len(addresses))
+ r := make([]netip.Addr, 0, len(addresses))
for _, addr := range addresses {
if addr.LUID == luid {
for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next {
if ip := dns.Address.IP(); ip != nil {
- r = append(r, ip)
+ if a, ok := netip.AddrFromSlice(ip); ok {
+ r = append(r, a)
+ }
} else {
return nil, windows.ERROR_INVALID_PARAMETER
}
@@ -317,141 +330,58 @@ func (luid LUID) DNS() ([]net.IP, error) {
return r, nil
}
-const (
- netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
- netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
- netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
- netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
-)
-
-// FlushDNS method clears all DNS servers associated with the adapter.
-func (luid LUID) FlushDNS() error {
- cmds := make([]string, 0, 2)
- ipif4, err := luid.IPInterface(windows.AF_INET)
- if err == nil {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex))
- }
- ipif6, err := luid.IPInterface(windows.AF_INET6)
- if err == nil {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex))
- }
-
- if len(cmds) == 0 {
- return nil
- }
- return runNetsh(cmds)
-}
-
-// AddDNS method associates additional DNS servers with the adapter.
-func (luid LUID) AddDNS(dnses []net.IP) error {
- var ipif4, ipif6 *MibIPInterfaceRow
- var err error
- cmds := make([]string, 0, len(dnses))
- for i := 0; i < len(dnses); i++ {
- if v4 := dnses[i].To4(); v4 != nil {
- if ipif4 == nil {
- ipif4, err = luid.IPInterface(windows.AF_INET)
- if err != nil {
- return err
- }
- }
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String()))
- } else if v6 := dnses[i].To16(); v6 != nil {
- if ipif6 == nil {
- ipif6, err = luid.IPInterface(windows.AF_INET6)
- if err != nil {
- return err
- }
- }
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String()))
- }
- }
-
- if len(cmds) == 0 {
- return nil
+// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family.
+func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error {
+ if family != windows.AF_INET && family != windows.AF_INET6 {
+ return windows.ERROR_PROTOCOL_UNREACHABLE
}
- return runNetsh(cmds)
-}
-// SetDNS method clears previous and associates new DNS servers with the adapter.
-func (luid LUID) SetDNS(dnses []net.IP) error {
- cmds := make([]string, 0, 2+len(dnses))
- ipif4, err := luid.IPInterface(windows.AF_INET)
- if err == nil {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush4, ipif4.InterfaceIndex))
- }
- ipif6, err := luid.IPInterface(windows.AF_INET6)
- if err == nil {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateFlush6, ipif6.InterfaceIndex))
- }
- for i := 0; i < len(dnses); i++ {
- if v4 := dnses[i].To4(); v4 != nil {
- if ipif4 == nil {
- return windows.ERROR_NOT_SUPPORTED
- }
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif4.InterfaceIndex, v4.String()))
- } else if v6 := dnses[i].To16(); v6 != nil {
- if ipif6 == nil {
- return windows.ERROR_NOT_SUPPORTED
- }
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif6.InterfaceIndex, v6.String()))
+ var filteredServers []string
+ for _, server := range servers {
+ if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) {
+ filteredServers = append(filteredServers, server.String())
}
}
-
- if len(cmds) == 0 {
- return nil
- }
- return runNetsh(cmds)
-}
-
-// SetDNSForFamily method clears previous and associates new DNS servers with the adapter for a specific family.
-func (luid LUID) SetDNSForFamily(family AddressFamily, dnses []net.IP) error {
- var templateFlush string
- if family == windows.AF_INET {
- templateFlush = netshCmdTemplateFlush4
- } else if family == windows.AF_INET6 {
- templateFlush = netshCmdTemplateFlush6
- }
-
- cmds := make([]string, 0, 1+len(dnses))
- ipif, err := luid.IPInterface(family)
+ servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ","))
if err != nil {
return err
}
- cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
- for i := 0; i < len(dnses); i++ {
- if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
- } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
- }
+ domains16, err := windows.UTF16PtrFromString(strings.Join(domains, ","))
+ if err != nil {
+ return err
}
- return runNetsh(cmds)
-}
-
-// SetDNSDomain method sets the interface-specific DNS domain.
-func (luid LUID) SetDNSDomain(domain string) error {
guid, err := luid.GUID()
if err != nil {
- return fmt.Errorf("Error converting luid to guid: %v", err)
+ return err
}
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE)
- if err != nil {
- return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
+ dnsInterfaceSettings := &DnsInterfaceSettings{
+ Version: DnsInterfaceSettingsVersion1,
+ Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList,
+ NameServer: servers16,
+ SearchList: domains16,
}
- paths, _, err := key.GetStringsValue("IpConfig")
- key.Close()
- if err != nil {
- return fmt.Errorf("Error reading IpConfig registry key: %v", err)
+ if family == windows.AF_INET6 {
+ dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6
}
- if len(paths) == 0 {
- return errors.New("No TCP/IP interfaces found on adapter")
+ // For >= Windows 10 1809
+ err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
+ if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
+ return err
}
- key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE)
+
+ // For < Windows 10 1809
+ err = luid.fallbackSetDNSForFamily(family, servers)
if err != nil {
- return fmt.Errorf("Unable to open TCP/IP network registry key: %v", err)
+ return err
}
- err = key.SetStringValue("Domain", domain)
- key.Close()
- return err
+ if len(domains) > 0 {
+ return luid.fallbackSetDNSDomain(domains[0])
+ } else {
+ return luid.fallbackSetDNSDomain("")
+ }
+}
+
+// FlushDNS method clears all DNS servers associated with the adapter.
+func (luid LUID) FlushDNS(family AddressFamily) error {
+ return luid.SetDNS(family, nil, nil)
}
diff --git a/tunnel/winipcfg/mksyscall.go b/tunnel/winipcfg/mksyscall.go
index 8edb1cf2..d62d38df 100644
--- a/tunnel/winipcfg/mksyscall.go
+++ b/tunnel/winipcfg/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go
index 4714c520..4f8e5b13 100644
--- a/tunnel/winipcfg/netsh.go
+++ b/tunnel/winipcfg/netsh.go
@@ -1,49 +1,108 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"bytes"
+ "errors"
"fmt"
"io"
+ "net/netip"
"os/exec"
"path/filepath"
"strings"
+ "syscall"
"golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/registry"
)
-// I wish we didn't have to do this. netiohlp.dll (what's used by netsh.exe) has some nice tricks with writing directly
-// to the registry and the nsi kernel object, but it's not clear copying those makes for a stable interface. WMI doesn't
-// work with v6. CMI isn't in Windows 7.
func runNetsh(cmds []string) error {
system32, err := windows.GetSystemDirectory()
if err != nil {
return err
}
- cmd := exec.Command(filepath.Join(system32, "netsh.exe")) // I wish we could append (, "-f", "CONIN$") but Go sets up the process context wrong.
+ cmd := exec.Command(filepath.Join(system32, "netsh.exe"))
+ cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
+
stdin, err := cmd.StdinPipe()
if err != nil {
- return fmt.Errorf("runNetsh stdin pipe - %v", err)
+ return fmt.Errorf("runNetsh stdin pipe - %w", err)
}
go func() {
defer stdin.Close()
io.WriteString(stdin, strings.Join(append(cmds, "exit\r\n"), "\r\n"))
}()
output, err := cmd.CombinedOutput()
- if err != nil {
- return fmt.Errorf("runNetsh run - %v", err)
- }
// Horrible kludges, sorry.
- cleaned := bytes.ReplaceAll(output, []byte("netsh>"), []byte{})
+ cleaned := bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'})
+ cleaned = bytes.ReplaceAll(cleaned, []byte("netsh>"), []byte{})
cleaned = bytes.ReplaceAll(cleaned, []byte("There are no Domain Name Servers (DNS) configured on this computer."), []byte{})
cleaned = bytes.TrimSpace(cleaned)
- if len(cleaned) != 0 {
- return fmt.Errorf("runNetsh returned error strings.\ninput:\n%s\noutput\n:%s",
- strings.Join(cmds, "\n"), bytes.ReplaceAll(output, []byte{'\r', '\n'}, []byte{'\n'}))
+ if len(cleaned) != 0 && err == nil {
+ return fmt.Errorf("netsh: %#q", string(cleaned))
+ } else if err != nil {
+ return fmt.Errorf("netsh: %v: %#q", err, string(cleaned))
}
return nil
}
+
+const (
+ netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
+ netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
+ netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
+ netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
+)
+
+func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
+ var templateFlush string
+ if family == windows.AF_INET {
+ templateFlush = netshCmdTemplateFlush4
+ } else if family == windows.AF_INET6 {
+ templateFlush = netshCmdTemplateFlush6
+ }
+
+ cmds := make([]string, 0, 1+len(dnses))
+ ipif, err := luid.IPInterface(family)
+ if err != nil {
+ return err
+ }
+ cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
+ for i := 0; i < len(dnses); i++ {
+ if dnses[i].Is4() && family == windows.AF_INET {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String()))
+ } else if dnses[i].Is6() && family == windows.AF_INET6 {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String()))
+ }
+ }
+ return runNetsh(cmds)
+}
+
+func (luid LUID) fallbackSetDNSDomain(domain string) error {
+ guid, err := luid.GUID()
+ if err != nil {
+ return fmt.Errorf("Error converting luid to guid: %w", err)
+ }
+ key, err := registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", guid), registry.QUERY_VALUE)
+ if err != nil {
+ return fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %w", err)
+ }
+ paths, _, err := key.GetStringsValue("IpConfig")
+ key.Close()
+ if err != nil {
+ return fmt.Errorf("Error reading IpConfig registry key: %w", err)
+ }
+ if len(paths) == 0 {
+ return errors.New("No TCP/IP interfaces found on adapter")
+ }
+ key, err = registry.OpenKey(registry.LOCAL_MACHINE, fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), registry.SET_VALUE)
+ if err != nil {
+ return fmt.Errorf("Unable to open TCP/IP network registry key: %w", err)
+ }
+ err = key.SetStringValue("Domain", domain)
+ key.Close()
+ return err
+}
diff --git a/tunnel/winipcfg/route_change_handler.go b/tunnel/winipcfg/route_change_handler.go
index 1b4bad95..4b78331e 100644
--- a/tunnel/winipcfg/route_change_handler.go
+++ b/tunnel/winipcfg/route_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 81f9335d..8e8f4a59 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -1,13 +1,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
- "bytes"
- "net"
+ "encoding/binary"
+ "fmt"
+ "net/netip"
+ "strconv"
"unsafe"
"golang.org/x/sys/windows"
@@ -581,19 +583,17 @@ const (
ScopeLevelCount = 16
)
-// Theoretical array index limitations
-const (
- maxIndexCount8 = (1 << 31) - 1
- maxIndexCount16 = (1 << 30) - 1
-)
-
// RouteData structure describes a route to add
type RouteData struct {
- Destination net.IPNet
- NextHop net.IP
+ Destination netip.Prefix
+ NextHop netip.Addr
Metric uint32
}
+func (routeData *RouteData) String() string {
+ return fmt.Sprintf("%+v", *routeData)
+}
+
// IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix
type IPAdapterDNSSuffix struct {
@@ -609,15 +609,7 @@ func (obj *IPAdapterDNSSuffix) String() string {
// AdapterName method returns the name of the adapter with which these addresses are associated.
// Unlike an adapter's friendly name, the adapter name returned by AdapterName is permanent and cannot be modified by the user.
func (addr *IPAdapterAddresses) AdapterName() string {
- if addr.adapterName == nil {
- return ""
- }
- slice := (*(*[maxIndexCount8]uint8)(unsafe.Pointer(addr.adapterName)))[:]
- null := bytes.IndexByte(slice, 0)
- if null != -1 {
- slice = slice[:null]
- }
- return string(slice)
+ return windows.BytePtrToString(addr.adapterName)
}
// DNSSuffix method returns adapter DNS suffix associated with this adapter.
@@ -625,7 +617,7 @@ func (addr *IPAdapterAddresses) DNSSuffix() string {
if addr.dnsSuffix == nil {
return ""
}
- return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.dnsSuffix)))[:])
+ return windows.UTF16PtrToString(addr.dnsSuffix)
}
// Description method returns description for the adapter.
@@ -633,7 +625,7 @@ func (addr *IPAdapterAddresses) Description() string {
if addr.description == nil {
return ""
}
- return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.description)))[:])
+ return windows.UTF16PtrToString(addr.description)
}
// FriendlyName method returns a user-friendly name for the adapter. For example: "Local Area Connection 1."
@@ -642,7 +634,7 @@ func (addr *IPAdapterAddresses) FriendlyName() string {
if addr.friendlyName == nil {
return ""
}
- return windows.UTF16ToString((*(*[maxIndexCount16]uint16)(unsafe.Pointer(addr.friendlyName)))[:])
+ return windows.UTF16PtrToString(addr.friendlyName)
}
// PhysicalAddress method returns the Media Access Control (MAC) address for the adapter.
@@ -693,9 +685,8 @@ func (row *MibIPInterfaceRow) Set() error {
}
// get method returns all table rows as a Go slice.
-func (tab *mibIPInterfaceTable) get() []MibIPInterfaceRow {
- const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIPInterfaceRow{})
- return (*[maxCount]MibIPInterfaceRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries]
+func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) {
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -731,9 +722,8 @@ func (row *MibIfRow2) get() (ret error) {
}
// get method returns all table rows as a Go slice.
-func (tab *mibIfTable2) get() []MibIfRow2 {
- const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIfRow2{})
- return (*[maxCount]MibIfRow2)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries]
+func (tab *mibIfTable2) get() (s []MibIfRow2) {
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -749,45 +739,82 @@ type RawSockaddrInet struct {
data [26]byte
}
-// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
+func ntohs(i uint16) uint16 {
+ return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:])
+}
+
+func htons(i uint16) uint16 {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, i)
+ return *(*uint16)(unsafe.Pointer(&b[0]))
+}
+
+// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
-func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
- if v4 := ip.To4(); v4 != nil {
+func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error {
+ if addrPort.Addr().Is4() {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
- copy(addr4.Addr[:], v4)
- addr4.Port = port
+ addr4.Addr = addrPort.Addr().As4()
+ addr4.Port = htons(addrPort.Port())
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
return nil
- }
-
- if v6 := ip.To16(); v6 != nil {
+ } else if addrPort.Addr().Is6() {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
- addr6.Port = port
+ addr6.Addr = addrPort.Addr().As16()
+ addr6.Port = htons(addrPort.Port())
addr6.Flowinfo = 0
- copy(addr6.Addr[:], v6)
- addr6.Scope_id = 0
+ scopeId := uint32(0)
+ if z := addrPort.Addr().Zone(); z != "" {
+ if s, err := strconv.ParseUint(z, 10, 32); err == nil {
+ scopeId = uint32(s)
+ }
+ }
+ addr6.Scope_id = scopeId
return nil
}
-
return windows.ERROR_INVALID_PARAMETER
}
-// IP method returns IPv4 or IPv6 address.
-// If the address is neither IPv4 not IPv6 nil is returned.
-func (addr *RawSockaddrInet) IP() net.IP {
+// SetAddr method sets family and address to the given IPv4 or IPv6 address.
+// All other members of the structure are set to zero.
+func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error {
+ return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0))
+}
+
+// AddrPort returns the IP address and port.
+func (addr *RawSockaddrInet) AddrPort() netip.AddrPort {
+ return netip.AddrPortFrom(addr.Addr(), addr.Port())
+}
+
+// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither.
+func (addr *RawSockaddrInet) Addr() netip.Addr {
switch addr.Family {
case windows.AF_INET:
- return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:]
-
+ return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr)
case windows.AF_INET6:
- return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:]
+ raw := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
+ a := netip.AddrFrom16(raw.Addr)
+ if raw.Scope_id != 0 {
+ a = a.WithZone(strconv.FormatUint(uint64(raw.Scope_id), 10))
+ }
+ return a
}
+ return netip.Addr{}
+}
- return nil
+// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
+func (addr *RawSockaddrInet) Port() uint16 {
+ switch addr.Family {
+ case windows.AF_INET:
+ return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
+ case windows.AF_INET6:
+ return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
+ }
+ return 0
}
// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer.
@@ -821,9 +848,8 @@ func (row *MibUnicastIPAddressRow) Delete() error {
}
// get method returns all table rows as a Go slice.
-func (tab *mibUnicastIPAddressTable) get() []MibUnicastIPAddressRow {
- const maxCount = maxIndexCount8 / unsafe.Sizeof(MibUnicastIPAddressRow{})
- return (*[maxCount]MibUnicastIPAddressRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries]
+func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) {
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -851,9 +877,8 @@ func (row *MibAnycastIPAddressRow) Delete() error {
}
// get method returns all table rows as a Go slice.
-func (tab *mibAnycastIPAddressTable) get() []MibAnycastIPAddressRow {
- const maxCount = maxIndexCount8 / unsafe.Sizeof(MibAnycastIPAddressRow{})
- return (*[maxCount]MibAnycastIPAddressRow)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries]
+func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) {
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -865,32 +890,30 @@ func (tab *mibAnycastIPAddressTable) free() {
// IPAddressPrefix structure stores an IP address prefix.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix
type IPAddressPrefix struct {
- Prefix RawSockaddrInet
+ RawPrefix RawSockaddrInet
PrefixLength uint8
_ [2]byte
}
-// SetIPNet method sets IP address prefix using net.IPNet.
-func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error {
- err := prefix.Prefix.SetIP(net.IP, 0)
+// SetPrefix method sets IP address prefix using netip.Prefix.
+func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error {
+ err := prefix.RawPrefix.SetAddr(netPrefix.Addr())
if err != nil {
return err
}
- ones, _ := net.Mask.Size()
- prefix.PrefixLength = uint8(ones)
+ prefix.PrefixLength = uint8(netPrefix.Bits())
return nil
}
-// IPNet method returns IP address prefix as net.IPNet.
-// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately.
-func (prefix *IPAddressPrefix) IPNet() net.IPNet {
- switch prefix.Prefix.Family {
+// Prefix returns IP address prefix as netip.Prefix.
+func (prefix *IPAddressPrefix) Prefix() netip.Prefix {
+ switch prefix.RawPrefix.Family {
case windows.AF_INET:
- return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)}
+ return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
case windows.AF_INET6:
- return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)}
+ return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
}
- return net.IPNet{}
+ return netip.Prefix{}
}
// MibIPforwardRow2 structure stores information about an IP route entry.
@@ -944,9 +967,8 @@ func (row *MibIPforwardRow2) Delete() error {
}
// get method returns all table rows as a Go slice.
-func (tab *mibIPforwardTable2) get() []MibIPforwardRow2 {
- const maxCount = maxIndexCount8 / unsafe.Sizeof(MibIPforwardRow2{})
- return (*[maxCount]MibIPforwardRow2)(unsafe.Pointer(&tab.table[0]))[:tab.numEntries]
+func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) {
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -954,3 +976,43 @@ func (tab *mibIPforwardTable2) get() []MibIPforwardRow2 {
func (tab *mibIPforwardTable2) free() {
freeMibTable(unsafe.Pointer(tab))
}
+
+//
+// DNS API
+//
+
+// DnsInterfaceSettings is meant to be used with SetInterfaceDnsSettings
+type DnsInterfaceSettings struct {
+ Version uint32
+ _ [4]byte
+ Flags uint64
+ Domain *uint16
+ NameServer *uint16
+ SearchList *uint16
+ RegistrationEnabled uint32
+ RegisterAdapterName uint32
+ EnableLLMNR uint32
+ QueryAdapterName uint32
+ ProfileNameServer *uint16
+}
+
+const (
+ DnsInterfaceSettingsVersion1 = 1 // for DnsInterfaceSettings
+ DnsInterfaceSettingsVersion2 = 2 // for DnsInterfaceSettingsEx
+ DnsInterfaceSettingsVersion3 = 3 // for DnsInterfaceSettings3
+
+ DnsInterfaceSettingsFlagIPv6 = 0x0001
+ DnsInterfaceSettingsFlagNameserver = 0x0002
+ DnsInterfaceSettingsFlagSearchList = 0x0004
+ DnsInterfaceSettingsFlagRegistrationEnabled = 0x0008
+ DnsInterfaceSettingsFlagRegisterAdapterName = 0x0010
+ DnsInterfaceSettingsFlagDomain = 0x0020
+ DnsInterfaceSettingsFlagHostname = 0x0040
+ DnsInterfaceSettingsFlagEnableLLMNR = 0x0080
+ DnsInterfaceSettingsFlagQueryAdapterName = 0x0100
+ DnsInterfaceSettingsFlagProfileNameserver = 0x0200
+ DnsInterfaceSettingsFlagDisableUnconstrainedQueries = 0x0400 // v2 only
+ DnsInterfaceSettingsFlagSupplementalSearchList = 0x0800 // v2 only
+ DnsInterfaceSettingsFlagDOH = 0x1000 // v3 only
+ DnsInterfaceSettingsFlagDOHProfile = 0x2000 // v3 only
+)
diff --git a/tunnel/winipcfg/types_386.go b/tunnel/winipcfg/types_32.go
index 3a4b5733..1a8d4443 100644
--- a/tunnel/winipcfg/types_386.go
+++ b/tunnel/winipcfg/types_32.go
@@ -1,6 +1,8 @@
+//go:build 386 || arm
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_amd64.go b/tunnel/winipcfg/types_64.go
index 11242891..3a1fe07f 100644
--- a/tunnel/winipcfg/types_amd64.go
+++ b/tunnel/winipcfg/types_64.go
@@ -1,6 +1,8 @@
+//go:build amd64 || arm64
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_test.go b/tunnel/winipcfg/types_test.go
index c7494e8c..b72d73f5 100644
--- a/tunnel/winipcfg/types_test.go
+++ b/tunnel/winipcfg/types_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
@@ -957,7 +957,6 @@ func TestIPAddressPrefix(t *testing.T) {
offset := uintptr(unsafe.Pointer(&s.PrefixLength)) - sp
if offset != ipAddressPrefixPrefixLengthOffset {
t.Errorf("IPAddressPrefix.PrefixLength offset is %d although %d is expected", offset, ipAddressPrefixPrefixLengthOffset)
-
}
}
diff --git a/tunnel/winipcfg/types_test_386.go b/tunnel/winipcfg/types_test_32.go
index db5f5f86..9e62bfef 100644
--- a/tunnel/winipcfg/types_test_386.go
+++ b/tunnel/winipcfg/types_test_32.go
@@ -1,6 +1,8 @@
+//go:build 386 || arm
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_test_amd64.go b/tunnel/winipcfg/types_test_64.go
index acc74118..8a181575 100644
--- a/tunnel/winipcfg/types_test_amd64.go
+++ b/tunnel/winipcfg/types_test_64.go
@@ -1,6 +1,8 @@
+//go:build amd64 || arm64
+
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/unicast_address_change_handler.go b/tunnel/winipcfg/unicast_address_change_handler.go
index 5f8f2c96..cf4fcb3a 100644
--- a/tunnel/winipcfg/unicast_address_change_handler.go
+++ b/tunnel/winipcfg/unicast_address_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go
index 2fc0c875..e24157b9 100644
--- a/tunnel/winipcfg/winipcfg.go
+++ b/tunnel/winipcfg/winipcfg.go
@@ -1,11 +1,12 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
+ "runtime"
"unsafe"
"golang.org/x/sys/windows"
@@ -29,6 +30,7 @@ import (
//sys getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) = iphlpapi.GetIfTable2Ex
//sys convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceGuidToLuid
+//sys convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) = iphlpapi.ConvertInterfaceIndexToLuid
// GetAdaptersAddresses function retrieves the addresses associated with the adapters on the local computer.
// https://docs.microsoft.com/en-us/windows/desktop/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
@@ -166,3 +168,29 @@ func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) {
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-cancelmibchangenotify2
//sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2
+
+//
+// DNS-related functions
+//
+
+//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+
+// The GUID is passed by value, not by reference, which means different
+// things on different calling conventions. On amd64, this means it's
+// passed by reference anyway, while on arm, arm64, and 386, it's split
+// into words.
+func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error {
+ words := (*[4]uintptr)(unsafe.Pointer(&guid))
+ switch runtime.GOARCH {
+ case "amd64":
+ return setInterfaceDnsSettingsByPtr(&guid, settings)
+ case "arm64":
+ return setInterfaceDnsSettingsByQwords(words[0], words[1], settings)
+ case "arm", "386":
+ return setInterfaceDnsSettingsByDwords(words[0], words[1], words[2], words[3], settings)
+ default:
+ panic("unknown calling convention")
+ }
+}
diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go
index 0251aecf..b49daf33 100644
--- a/tunnel/winipcfg/winipcfg_test.go
+++ b/tunnel/winipcfg/winipcfg_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
/*
@@ -8,8 +8,8 @@
Some tests in this file require:
- A dedicated network adapter
- Any network adapter will do. It may be virtual (Wintun etc.). The adapter name
- must contain string "winipcfg_test".
+ Any network adapter will do. It may be virtual (WireGuardNT, Wintun,
+ etc.). The adapter name must contain string "winipcfg_test".
Tests will add, remove, flush DNS servers, change adapter IP address, manipulate
routes etc.
The adapter will not be returned to previous state, so use an expendable one.
@@ -22,9 +22,9 @@ Some tests in this file require:
package winipcfg
import (
- "bytes"
- "net"
+ "net/netip"
"strings"
+ "syscall"
"testing"
"time"
@@ -37,22 +37,13 @@ const (
// TODO: Add IPv6 tests.
var (
- unexistentIPAddresToAdd = net.IPNet{
- IP: net.IP{172, 16, 1, 114},
- Mask: net.IPMask{255, 255, 255, 0},
- }
- unexistentRouteIPv4ToAdd = RouteData{
- Destination: net.IPNet{
- IP: net.IP{172, 16, 200, 0},
- Mask: net.IPMask{255, 255, 255, 0},
- },
- NextHop: net.IP{172, 16, 1, 2},
- Metric: 0,
- }
- dnsesToSet = []net.IP{
- net.IPv4(8, 8, 8, 8),
- net.IPv4(8, 8, 4, 4),
+ nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24")
+ nonexistentRouteIPv4ToAdd = RouteData{
+ Destination: netip.MustParsePrefix("172.16.200.0/24"),
+ NextHop: netip.MustParseAddr("172.16.1.2"),
+ Metric: 0,
}
+ dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}
)
func runningElevated() bool {
@@ -73,7 +64,7 @@ func getTestInterface() (*IPAdapterAddresses, error) {
marker := strings.ToLower(testInterfaceMarker)
for _, ifc := range ifcs {
- if strings.Index(strings.ToLower(ifc.FriendlyName()), marker) != -1 {
+ if strings.Contains(strings.ToLower(ifc.FriendlyName()), marker) {
return ifc, nil
}
}
@@ -93,7 +84,7 @@ func getTestIPInterface(family AddressFamily) (*MibIPInterfaceRow, error) {
func TestAdaptersAddresses(t *testing.T) {
ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagIncludeAll)
if err != nil {
- t.Errorf("GetAdaptersAddresses() returned error: %v", err)
+ t.Errorf("GetAdaptersAddresses() returned error: %w", err)
} else if ifcs == nil {
t.Errorf("GetAdaptersAddresses() returned nil.")
} else if len(ifcs) == 0 {
@@ -107,7 +98,7 @@ func TestAdaptersAddresses(t *testing.T) {
i.PhysicalAddress()
i.DHCPv6ClientDUID()
for dnsSuffix := i.FirstDNSSuffix; dnsSuffix != nil; dnsSuffix = dnsSuffix.Next {
- dnsSuffix.String()
+ _ = dnsSuffix.String()
}
}
}
@@ -117,7 +108,7 @@ func TestAdaptersAddresses(t *testing.T) {
for _, i := range ifcs {
ifc, err := i.LUID.Interface()
if err != nil {
- t.Errorf("LUID.Interface() returned an error: %v", err)
+ t.Errorf("LUID.Interface() returned an error: %w", err)
continue
} else if ifc == nil {
t.Errorf("LUID.Interface() returned nil.")
@@ -128,7 +119,7 @@ func TestAdaptersAddresses(t *testing.T) {
for _, i := range ifcs {
guid, err := i.LUID.GUID()
if err != nil {
- t.Errorf("LUID.GUID() returned an error: %v", err)
+ t.Errorf("LUID.GUID() returned an error: %w", err)
continue
}
if guid == nil {
@@ -138,7 +129,7 @@ func TestAdaptersAddresses(t *testing.T) {
luid, err := LUIDFromGUID(guid)
if err != nil {
- t.Errorf("LUIDFromGUID() returned an error: %v", err)
+ t.Errorf("LUIDFromGUID() returned an error: %w", err)
continue
}
if luid != i.LUID {
@@ -151,7 +142,7 @@ func TestAdaptersAddresses(t *testing.T) {
func TestIPInterface(t *testing.T) {
ifcs, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
- t.Errorf("GetAdaptersAddresses() returned error: %v", err)
+ t.Errorf("GetAdaptersAddresses() returned error: %w", err)
}
for _, i := range ifcs {
@@ -161,12 +152,12 @@ func TestIPInterface(t *testing.T) {
continue
}
if err != nil {
- t.Errorf("LUID.IPInterface(%s) returned an error: %v", i.FriendlyName(), err)
+ t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err)
}
_, err = i.LUID.IPInterface(windows.AF_INET6)
if err != nil {
- t.Errorf("LUID.IPInterface(%s) returned an error: %v", i.FriendlyName(), err)
+ t.Errorf("LUID.IPInterface(%s) returned an error: %w", i.FriendlyName(), err)
}
}
}
@@ -174,7 +165,7 @@ func TestIPInterface(t *testing.T) {
func TestIPInterfaces(t *testing.T) {
tab, err := GetIPInterfaceTable(windows.AF_UNSPEC)
if err != nil {
- t.Errorf("GetIPInterfaceTable() returned an error: %v", err)
+ t.Errorf("GetIPInterfaceTable() returned an error: %w", err)
return
} else if tab == nil {
t.Error("GetIPInterfaceTable() returned nil.")
@@ -189,7 +180,7 @@ func TestIPInterfaces(t *testing.T) {
func TestIPChangeMetric(t *testing.T) {
ipifc, err := getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -208,13 +199,13 @@ func TestIPChangeMetric(t *testing.T) {
}
})
if err != nil {
- t.Errorf("RegisterInterfaceChangeCallback() returned error: %v", err)
+ t.Errorf("RegisterInterfaceChangeCallback() returned error: %w", err)
return
}
defer func() {
err = cb.Unregister()
if err != nil {
- t.Errorf("UnregisterInterfaceChangeCallback() returned error: %v", err)
+ t.Errorf("UnregisterInterfaceChangeCallback() returned error: %w", err)
}
}()
@@ -230,14 +221,14 @@ func TestIPChangeMetric(t *testing.T) {
ipifc.Metric = newMetric
err = ipifc.Set()
if err != nil {
- t.Errorf("MibIPInterfaceRow.Set() returned an error: %v", err)
+ t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
ipifc, err = getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
return
}
if ipifc.Metric != newMetric {
@@ -255,14 +246,14 @@ func TestIPChangeMetric(t *testing.T) {
ipifc.Metric = metric
err = ipifc.Set()
if err != nil {
- t.Errorf("MibIPInterfaceRow.Set() returned an error: %v", err)
+ t.Errorf("MibIPInterfaceRow.Set() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
ipifc, err = getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
return
}
if ipifc.Metric != metric {
@@ -279,7 +270,7 @@ func TestIPChangeMetric(t *testing.T) {
func TestIPChangeMTU(t *testing.T) {
ipifc, err := getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -292,14 +283,14 @@ func TestIPChangeMTU(t *testing.T) {
ipifc.NLMTU = mtuToSet
err = ipifc.Set()
if err != nil {
- t.Errorf("Interface.Set() returned error: %v", err)
+ t.Errorf("Interface.Set() returned error: %w", err)
}
time.Sleep(500 * time.Millisecond)
ipifc, err = getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
return
}
if ipifc.NLMTU != mtuToSet {
@@ -309,14 +300,14 @@ func TestIPChangeMTU(t *testing.T) {
ipifc.NLMTU = prevMTU
err = ipifc.Set()
if err != nil {
- t.Errorf("Interface.Set() returned error: %v", err)
+ t.Errorf("Interface.Set() returned error: %w", err)
}
time.Sleep(500 * time.Millisecond)
ipifc, err = getTestIPInterface(windows.AF_INET)
if err != nil {
- t.Errorf("getTestIPInterface() returned an error: %v", err)
+ t.Errorf("getTestIPInterface() returned an error: %w", err)
}
if ipifc.NLMTU != prevMTU {
t.Errorf("Interface.NLMTU is %d although %d is expected.", ipifc.NLMTU, prevMTU)
@@ -326,13 +317,13 @@ func TestIPChangeMTU(t *testing.T) {
func TestGetIfRow(t *testing.T) {
ifc, err := getTestInterface()
if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
+ t.Errorf("getTestInterface() returned an error: %w", err)
return
}
row, err := ifc.LUID.Interface()
if err != nil {
- t.Errorf("LUID.Interface() returned an error: %v", err)
+ t.Errorf("LUID.Interface() returned an error: %w", err)
return
}
@@ -345,7 +336,7 @@ func TestGetIfRow(t *testing.T) {
func TestGetIfRows(t *testing.T) {
tab, err := GetIfTable2Ex(MibIfEntryNormal)
if err != nil {
- t.Errorf("GetIfTable2Ex() returned an error: %v", err)
+ t.Errorf("GetIfTable2Ex() returned an error: %w", err)
return
} else if tab == nil {
t.Errorf("GetIfTable2Ex() returned nil")
@@ -363,7 +354,7 @@ func TestGetIfRows(t *testing.T) {
func TestUnicastIPAddress(t *testing.T) {
_, err := GetUnicastIPAddressTable(windows.AF_UNSPEC)
if err != nil {
- t.Errorf("GetUnicastAddresses() returned an error: %v", err)
+ t.Errorf("GetUnicastAddresses() returned an error: %w", err)
return
}
}
@@ -371,7 +362,7 @@ func TestUnicastIPAddress(t *testing.T) {
func TestAddDeleteIPAddress(t *testing.T) {
ifc, err := getTestInterface()
if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
+ t.Errorf("getTestInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -379,12 +370,12 @@ func TestAddDeleteIPAddress(t *testing.T) {
return
}
- addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String())
return
} else if err != windows.ERROR_NOT_FOUND {
- t.Errorf("LUID.IPAddress() returned an error: %v", err)
+ t.Errorf("LUID.IPAddress() returned an error: %w", err)
return
}
@@ -401,7 +392,7 @@ func TestAddDeleteIPAddress(t *testing.T) {
}
})
if err != nil {
- t.Errorf("RegisterUnicastAddressChangeCallback() returned an error: %v", err)
+ t.Errorf("RegisterUnicastAddressChangeCallback() returned an error: %w", err)
} else {
defer cb.Unregister()
}
@@ -409,9 +400,9 @@ func TestAddDeleteIPAddress(t *testing.T) {
for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next {
count--
}
- err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd})
+ err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd})
if err != nil {
- t.Errorf("LUID.AddIPAddresses() returned an error: %v", err)
+ t.Errorf("LUID.AddIPAddresses() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
@@ -423,28 +414,28 @@ func TestAddDeleteIPAddress(t *testing.T) {
if count != 1 {
t.Errorf("After adding there are %d new interface(s).", count)
}
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err != nil {
- t.Errorf("LUID.IPAddress() returned an error: %v", err)
+ t.Errorf("LUID.IPAddress() returned an error: %w", err)
} else if addr == nil {
- t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String())
}
if !created {
t.Errorf("Notification handler has not been called on add.")
}
- err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd)
+ err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd)
if err != nil {
- t.Errorf("LUID.DeleteIPAddress() returned an error: %v", err)
+ t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String())
} else if err != windows.ERROR_NOT_FOUND {
- t.Errorf("LUID.IPAddress() returned an error: %v", err)
+ t.Errorf("LUID.IPAddress() returned an error: %w", err)
}
if !deleted {
t.Errorf("Notification handler has not been called on delete.")
@@ -454,19 +445,18 @@ func TestAddDeleteIPAddress(t *testing.T) {
func TestGetRoutes(t *testing.T) {
_, err := GetIPForwardTable2(windows.AF_UNSPEC)
if err != nil {
- t.Errorf("GetIPForwardTable2() returned error: %v", err)
+ t.Errorf("GetIPForwardTable2() returned error: %w", err)
}
}
func TestAddDeleteRoute(t *testing.T) {
- findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) {
+ findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) {
var family AddressFamily
- switch {
- case dest.IP.To4() != nil:
+ if dest.Addr().Is4() {
family = windows.AF_INET
- case dest.IP.To16() != nil:
+ } else if dest.Addr().Is6() {
family = windows.AF_INET6
- default:
+ } else {
return nil, windows.ERROR_INVALID_PARAMETER
}
r, err := GetIPForwardTable2(family)
@@ -474,9 +464,8 @@ func TestAddDeleteRoute(t *testing.T) {
return nil, err
}
matches := make([]MibIPforwardRow2, 0, len(r))
- ones, _ := dest.Mask.Size()
for _, route := range r {
- if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) {
+ if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() {
matches = append(matches, route)
}
}
@@ -485,7 +474,7 @@ func TestAddDeleteRoute(t *testing.T) {
ifc, err := getTestInterface()
if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
+ t.Errorf("getTestInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -493,20 +482,20 @@ func TestAddDeleteRoute(t *testing.T) {
return
}
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
- t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?")
+ t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?")
return
} else if err != windows.ERROR_NOT_FOUND {
- t.Errorf("LUID.Route() returned an error: %v", err)
+ t.Errorf("LUID.Route() returned an error: %w", err)
return
}
- routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
- t.Errorf("findRoute() returned an error: %v", err)
+ t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
- t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes))
+ t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes))
}
var created, deleted bool
@@ -519,58 +508,58 @@ func TestAddDeleteRoute(t *testing.T) {
}
})
if err != nil {
- t.Errorf("RegisterRouteChangeCallback() returned an error: %v", err)
+ t.Errorf("RegisterRouteChangeCallback() returned an error: %w", err)
} else {
defer cb.Unregister()
}
- err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric)
+ err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric)
if err != nil {
- t.Errorf("LUID.AddRoute() returned an error: %v", err)
+ t.Errorf("LUID.AddRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == windows.ERROR_NOT_FOUND {
t.Error("LUID.Route() returned nil although the route is added successfully.")
} else if err != nil {
- t.Errorf("LUID.Route() returned an error: %v", err)
- } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) {
+ t.Errorf("LUID.Route() returned an error: %w", err)
+ } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop {
t.Error("LUID.Route() returned a wrong route!")
}
if !created {
t.Errorf("Route handler has not been called on add.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
- t.Errorf("findRoute() returned an error: %v", err)
+ t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 1 {
t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1)
- } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) {
- t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String())
+ } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() {
+ t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String())
}
- err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err != nil {
- t.Errorf("LUID.DeleteRoute() returned an error: %v", err)
+ t.Errorf("LUID.DeleteRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
t.Error("LUID.Route() returned a route although it is removed successfully.")
} else if err != windows.ERROR_NOT_FOUND {
- t.Errorf("LUID.Route() returned an error: %v", err)
+ t.Errorf("LUID.Route() returned an error: %w", err)
}
if !deleted {
t.Errorf("Route handler has not been called on delete.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
- t.Errorf("findRoute() returned an error: %v", err)
+ t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
t.Errorf("findRoute() returned %d items although the route is deleted successfully.", len(routes))
}
@@ -579,7 +568,7 @@ func TestAddDeleteRoute(t *testing.T) {
func TestFlushDNS(t *testing.T) {
ifc, err := getTestInterface()
if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
+ t.Errorf("getTestInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -589,12 +578,12 @@ func TestFlushDNS(t *testing.T) {
prevDNSes, err := ifc.LUID.DNS()
if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
+ t.Errorf("LUID.DNS() returned an error: %w", err)
}
- err = ifc.LUID.FlushDNS()
+ err = ifc.LUID.FlushDNS(syscall.AF_INET)
if err != nil {
- t.Errorf("LUID.FlushDNS() returned an error: %v", err)
+ t.Errorf("LUID.FlushDNS() returned an error: %w", err)
}
ifc, _ = getTestInterface()
@@ -602,10 +591,10 @@ func TestFlushDNS(t *testing.T) {
n := 0
dns, err := ifc.LUID.DNS()
if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
+ t.Errorf("LUID.DNS() returned an error: %w", err)
}
for _, a := range dns {
- if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) {
+ if a.Is4() {
n++
}
}
@@ -613,51 +602,7 @@ func TestFlushDNS(t *testing.T) {
t.Errorf("DNSServerAddresses contains %d items, although FlushDNS is executed successfully.", n)
}
- err = ifc.LUID.SetDNS(prevDNSes)
- if err != nil {
- t.Errorf("LUID.SetDNS() returned an error: %v.", err)
- }
-}
-
-func TestAddDNS(t *testing.T) {
- ifc, err := getTestInterface()
- if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
- return
- }
- if !runningElevated() {
- t.Errorf("%s requires elevation", t.Name())
- return
- }
-
- prevDNSes, err := ifc.LUID.DNS()
- if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
- }
- expectedDNSes := append(prevDNSes, dnsesToSet...)
-
- err = ifc.LUID.AddDNS(dnsesToSet)
- if err != nil {
- t.Errorf("LUID.AddDNS() returned an error: %v", err)
- return
- }
-
- ifc, _ = getTestInterface()
-
- newDNSes, err := ifc.LUID.DNS()
- if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
- } else if len(newDNSes) != len(expectedDNSes) {
- t.Errorf("expectedDNSes contains %d items, while DNSServerAddresses contains %d.", len(expectedDNSes), len(newDNSes))
- } else {
- for i := range expectedDNSes {
- if !expectedDNSes[i].Equal(newDNSes[i]) {
- t.Errorf("expectedDNSes[%d] = %s while DNSServerAddresses[%d] = %s.", i, expectedDNSes[i].String(), i, newDNSes[i].String())
- }
- }
- }
-
- err = ifc.LUID.SetDNS(prevDNSes)
+ err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil)
if err != nil {
t.Errorf("LUID.SetDNS() returned an error: %v.", err)
}
@@ -666,7 +611,7 @@ func TestAddDNS(t *testing.T) {
func TestSetDNS(t *testing.T) {
ifc, err := getTestInterface()
if err != nil {
- t.Errorf("getTestInterface() returned an error: %v", err)
+ t.Errorf("getTestInterface() returned an error: %w", err)
return
}
if !runningElevated() {
@@ -676,12 +621,12 @@ func TestSetDNS(t *testing.T) {
prevDNSes, err := ifc.LUID.DNS()
if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
+ t.Errorf("LUID.DNS() returned an error: %w", err)
}
- err = ifc.LUID.SetDNS(dnsesToSet)
+ err = ifc.LUID.SetDNS(windows.AF_INET, dnsesToSet, nil)
if err != nil {
- t.Errorf("LUID.SetDNS() returned an error: %v", err)
+ t.Errorf("LUID.SetDNS() returned an error: %w", err)
return
}
@@ -689,18 +634,18 @@ func TestSetDNS(t *testing.T) {
newDNSes, err := ifc.LUID.DNS()
if err != nil {
- t.Errorf("LUID.DNS() returned an error: %v", err)
+ t.Errorf("LUID.DNS() returned an error: %w", err)
} else if len(newDNSes) != len(dnsesToSet) {
t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes))
} else {
for i := range dnsesToSet {
- if !dnsesToSet[i].Equal(newDNSes[i]) {
+ if dnsesToSet[i] != newDNSes[i] {
t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String())
}
}
}
- err = ifc.LUID.SetDNS(prevDNSes)
+ err = ifc.LUID.SetDNS(windows.AF_INET, prevDNSes, nil)
if err != nil {
t.Errorf("LUID.SetDNS() returned an error: %v.", err)
}
@@ -709,7 +654,7 @@ func TestSetDNS(t *testing.T) {
func TestAnycastIPAddress(t *testing.T) {
_, err := GetAnycastIPAddressTable(windows.AF_UNSPEC)
if err != nil {
- t.Errorf("GetAnycastIPAddressTable() returned an error: %v", err)
+ t.Errorf("GetAnycastIPAddressTable() returned an error: %w", err)
return
}
}
diff --git a/tunnel/winipcfg/zwinipcfg_windows.go b/tunnel/winipcfg/zwinipcfg_windows.go
index 8f37ac26..3a0d8680 100644
--- a/tunnel/winipcfg/zwinipcfg_windows.go
+++ b/tunnel/winipcfg/zwinipcfg_windows.go
@@ -19,6 +19,7 @@ const (
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
+ errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
@@ -26,7 +27,7 @@ var (
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
- return nil
+ return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
@@ -39,182 +40,198 @@ func errnoErr(e syscall.Errno) error {
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
- procFreeMibTable = modiphlpapi.NewProc("FreeMibTable")
- procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry")
- procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable")
- procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry")
- procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry")
- procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2")
- procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex")
- procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
+ procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
procConvertInterfaceGuidToLuid = modiphlpapi.NewProc("ConvertInterfaceGuidToLuid")
- procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable")
- procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry")
- procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry")
- procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry")
+ procConvertInterfaceIndexToLuid = modiphlpapi.NewProc("ConvertInterfaceIndexToLuid")
+ procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
+ procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry")
+ procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
procCreateUnicastIpAddressEntry = modiphlpapi.NewProc("CreateUnicastIpAddressEntry")
+ procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry")
+ procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
procDeleteUnicastIpAddressEntry = modiphlpapi.NewProc("DeleteUnicastIpAddressEntry")
- procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable")
+ procFreeMibTable = modiphlpapi.NewProc("FreeMibTable")
procGetAnycastIpAddressEntry = modiphlpapi.NewProc("GetAnycastIpAddressEntry")
- procCreateAnycastIpAddressEntry = modiphlpapi.NewProc("CreateAnycastIpAddressEntry")
- procDeleteAnycastIpAddressEntry = modiphlpapi.NewProc("DeleteAnycastIpAddressEntry")
+ procGetAnycastIpAddressTable = modiphlpapi.NewProc("GetAnycastIpAddressTable")
+ procGetIfEntry2 = modiphlpapi.NewProc("GetIfEntry2")
+ procGetIfTable2Ex = modiphlpapi.NewProc("GetIfTable2Ex")
+ procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
procGetIpForwardTable2 = modiphlpapi.NewProc("GetIpForwardTable2")
+ procGetIpInterfaceEntry = modiphlpapi.NewProc("GetIpInterfaceEntry")
+ procGetIpInterfaceTable = modiphlpapi.NewProc("GetIpInterfaceTable")
+ procGetUnicastIpAddressEntry = modiphlpapi.NewProc("GetUnicastIpAddressEntry")
+ procGetUnicastIpAddressTable = modiphlpapi.NewProc("GetUnicastIpAddressTable")
procInitializeIpForwardEntry = modiphlpapi.NewProc("InitializeIpForwardEntry")
- procGetIpForwardEntry2 = modiphlpapi.NewProc("GetIpForwardEntry2")
- procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2")
- procCreateIpForwardEntry2 = modiphlpapi.NewProc("CreateIpForwardEntry2")
- procDeleteIpForwardEntry2 = modiphlpapi.NewProc("DeleteIpForwardEntry2")
+ procInitializeIpInterfaceEntry = modiphlpapi.NewProc("InitializeIpInterfaceEntry")
+ procInitializeUnicastIpAddressEntry = modiphlpapi.NewProc("InitializeUnicastIpAddressEntry")
procNotifyIpInterfaceChange = modiphlpapi.NewProc("NotifyIpInterfaceChange")
- procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange")
procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
- procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
+ procNotifyUnicastIpAddressChange = modiphlpapi.NewProc("NotifyUnicastIpAddressChange")
+ procSetInterfaceDnsSettings = modiphlpapi.NewProc("SetInterfaceDnsSettings")
+ procSetIpForwardEntry2 = modiphlpapi.NewProc("SetIpForwardEntry2")
+ procSetIpInterfaceEntry = modiphlpapi.NewProc("SetIpInterfaceEntry")
+ procSetUnicastIpAddressEntry = modiphlpapi.NewProc("SetUnicastIpAddressEntry")
)
-func freeMibTable(memory unsafe.Pointer) {
- syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0)
+func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) {
+ r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
return
}
-func initializeIPInterfaceEntry(row *MibIPInterfaceRow) {
- syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) {
+ r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
return
}
-func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
+func convertInterfaceIndexToLUID(interfaceIndex uint32, interfaceLUID *LUID) (ret error) {
+ r0, _, _ := syscall.Syscall(procConvertInterfaceIndexToLuid.Addr(), 2, uintptr(interfaceIndex), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) {
+ r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
- r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getIfEntry2(row *MibIfRow2) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
+ r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0)
+func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func convertInterfaceLUIDToGUID(interfaceLUID *LUID, interfaceGUID *windows.GUID) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
+func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func convertInterfaceGUIDToLUID(interfaceGUID *windows.GUID, interfaceLUID *LUID) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceGuidToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceGUID)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
+func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
+ r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) {
- r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
+func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) {
- syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func freeMibTable(memory unsafe.Pointer) {
+ syscall.Syscall(procFreeMibTable.Addr(), 1, uintptr(memory), 0, 0)
return
}
-func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func createUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procCreateUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getIfEntry2(row *MibIfRow2) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIfEntry2.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func deleteUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procDeleteUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getIfTable2Ex(level MibIfEntryLevel, table **mibIfTable2) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIfTable2Ex.Addr(), 2, uintptr(level), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getAnycastIPAddressTable(family AddressFamily, table **mibAnycastIPAddressTable) (ret error) {
- r0, _, _ := syscall.Syscall(procGetAnycastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
+func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procGetAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func createAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procCreateAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func deleteAnycastIPAddressEntry(row *MibAnycastIPAddressRow) (ret error) {
- r0, _, _ := syscall.Syscall(procDeleteAnycastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+func getIPInterfaceTable(family AddressFamily, table **mibIPInterfaceTable) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetIpInterfaceTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func getIPForwardTable2(family AddressFamily, table **mibIPforwardTable2) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIpForwardTable2.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
+func getUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
+ }
+ return
+}
+
+func getUnicastIPAddressTable(family AddressFamily, table **mibUnicastIPAddressTable) (ret error) {
+ r0, _, _ := syscall.Syscall(procGetUnicastIpAddressTable.Addr(), 2, uintptr(family), uintptr(unsafe.Pointer(table)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
@@ -226,82 +243,106 @@ func initializeIPForwardEntry(route *MibIPforwardRow2) {
return
}
-func getIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
- r0, _, _ := syscall.Syscall(procGetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
+func initializeIPInterfaceEntry(row *MibIPInterfaceRow) {
+ syscall.Syscall(procInitializeIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+ return
+}
+
+func initializeUnicastIPAddressEntry(row *MibUnicastIPAddressRow) {
+ syscall.Syscall(procInitializeUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
+ return
+}
+
+func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
+ var _p0 uint32
+ if initialNotification {
+ _p0 = 1
+ }
+ r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
- r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
+func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
+ var _p0 uint32
+ if initialNotification {
+ _p0 = 1
+ }
+ r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func createIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
- r0, _, _ := syscall.Syscall(procCreateIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
+func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
+ var _p0 uint32
+ if initialNotification {
+ _p0 = 1
+ }
+ r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func deleteIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
- r0, _, _ := syscall.Syscall(procDeleteIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
+func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) {
+ ret = procSetInterfaceDnsSettings.Find()
+ if ret != nil {
+ return
+ }
+ r0, _, _ := syscall.Syscall6(procSetInterfaceDnsSettings.Addr(), 5, uintptr(guid1), uintptr(guid2), uintptr(guid3), uintptr(guid4), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func notifyIPInterfaceChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
- var _p0 uint32
- if initialNotification {
- _p0 = 1
- } else {
- _p0 = 0
+func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) {
+ ret = procSetInterfaceDnsSettings.Find()
+ if ret != nil {
+ return
}
- r0, _, _ := syscall.Syscall6(procNotifyIpInterfaceChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
+ r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
- var _p0 uint32
- if initialNotification {
- _p0 = 1
- } else {
- _p0 = 0
+func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) {
+ ret = procSetInterfaceDnsSettings.Find()
+ if ret != nil {
+ return
}
- r0, _, _ := syscall.Syscall6(procNotifyUnicastIpAddressChange.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
+ r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func notifyRouteChange2(family AddressFamily, callback uintptr, callerContext uintptr, initialNotification bool, notificationHandle *windows.Handle) (ret error) {
- var _p0 uint32
- if initialNotification {
- _p0 = 1
- } else {
- _p0 = 0
+func setIPForwardEntry2(route *MibIPforwardRow2) (ret error) {
+ r0, _, _ := syscall.Syscall(procSetIpForwardEntry2.Addr(), 1, uintptr(unsafe.Pointer(route)), 0, 0)
+ if r0 != 0 {
+ ret = syscall.Errno(r0)
}
- r0, _, _ := syscall.Syscall6(procNotifyRouteChange2.Addr(), 5, uintptr(family), uintptr(callback), uintptr(callerContext), uintptr(_p0), uintptr(unsafe.Pointer(notificationHandle)), 0)
+ return
+}
+
+func setIPInterfaceEntry(row *MibIPInterfaceRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procSetIpInterfaceEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) {
- r0, _, _ := syscall.Syscall(procCancelMibChangeNotify2.Addr(), 1, uintptr(notificationHandle), 0, 0)
+func setUnicastIPAddressEntry(row *MibUnicastIPAddressRow) (ret error) {
+ r0, _, _ := syscall.Syscall(procSetUnicastIpAddressEntry.Addr(), 1, uintptr(unsafe.Pointer(row)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go
deleted file mode 100644
index 11d7ab2c..00000000
--- a/tunnel/wintun_test.go
+++ /dev/null
@@ -1,202 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel_test
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "testing"
- "time"
-
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun"
-
- "golang.zx2c4.com/wireguard/windows/elevate"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-func TestWintunOrdering(t *testing.T) {
- var tunDevice tun.Device
- err := elevate.DoAsSystem(func() error {
- var err error
- tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500)
- return err
- })
- if err != nil {
- t.Fatal(err)
- }
- defer tunDevice.Close()
- nativeTunDevice := tunDevice.(*tun.NativeTun)
- luid := winipcfg.LUID(nativeTunDevice.LUID())
- ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24")
- err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}})
- if err != nil {
- t.Fatal(err)
- }
- err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}})
- if err != nil {
- t.Fatal(err)
- }
- var token [32]byte
- _, err = rand.Read(token[:])
- if err != nil {
- t.Fatal(err)
- }
- var sockWrite net.Conn
- for i := 0; i < 1000; i++ {
- sockWrite, err = net.Dial("udp", "10.82.31.5:9999")
- if err == nil {
- defer sockWrite.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var sockRead *net.UDPConn
- for i := 0; i < 1000; i++ {
- var listenAddress *net.UDPAddr
- listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999")
- if err != nil {
- continue
- }
- sockRead, err = net.ListenUDP("udp", listenAddress)
- if err == nil {
- defer sockRead.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var wait sync.WaitGroup
- wait.Add(4)
- doneSockWrite := false
- doneTunWrite := false
- fatalErrors := make(chan error, 2)
- errors := make(chan error, 2)
- go func() {
- defer wait.Done()
- buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0)
- for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ {
- binary.LittleEndian.PutUint64(buffer[32:], sendingIndex)
- _, err := sockWrite.Write(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- }
- }()
- go func() {
- defer wait.Done()
- packet := [20 + 8 + 32 + 8]byte{
- 0x45, 0, 0, 20 + 8 + 32 + 8,
- 0, 0, 0, 0,
- 0x80, 0x11, 0, 0,
- 10, 82, 31, 5,
- 10, 82, 31, 4,
- 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0,
- }
- copy(packet[28:], token[:])
- for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ {
- binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex))
- var checksum uint32
- for i := 0; i < 20; i += 2 {
- if i != 10 {
- checksum += uint32(binary.BigEndian.Uint16(packet[i:]))
- }
- }
- binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff)))
- binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex)
- n, err := tunDevice.Write(packet[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if n == 0 {
- time.Sleep(time.Millisecond * 300)
- }
- }
- }()
- const packetsPerTest = 1 << 21
- go func() {
- defer func() {
- doneSockWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := tunDevice.Read(buffer[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- tokenPos := bytes.Index(packet, token[:])
- if tokenPos == -1 || tokenPos+32+8 > len(packet) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- go func() {
- defer func() {
- doneTunWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := sockRead.Read(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- done := make(chan bool, 2)
- doneFunc := func() {
- wait.Wait()
- done <- true
- }
- defer doneFunc()
- go doneFunc()
- for {
- select {
- case err := <-fatalErrors:
- t.Fatal(err)
- case err := <-errors:
- t.Error(err)
- case <-done:
- return
- }
- }
-}