aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/tunnel
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel')
-rw-r--r--tunnel/addressconfig.go187
-rw-r--r--tunnel/deterministicguid.go16
-rw-r--r--tunnel/firewall/blocker.go6
-rw-r--r--tunnel/firewall/helpers.go2
-rw-r--r--tunnel/firewall/mksyscall.go2
-rw-r--r--tunnel/firewall/rules.go17
-rw-r--r--tunnel/firewall/syscall_windows.go2
-rw-r--r--tunnel/firewall/types_windows.go8
-rw-r--r--tunnel/firewall/types_windows_32.go4
-rw-r--r--tunnel/firewall/types_windows_64.go4
-rw-r--r--tunnel/firewall/types_windows_test.go29
-rw-r--r--tunnel/interfacewatcher.go116
-rw-r--r--tunnel/ipcpermissions.go63
-rw-r--r--tunnel/mtumonitor.go (renamed from tunnel/defaultroutemonitor.go)59
-rw-r--r--tunnel/pitfalls.go177
-rw-r--r--tunnel/scriptrunner.go2
-rw-r--r--tunnel/service.go149
-rw-r--r--tunnel/winipcfg/interface_change_handler.go2
-rw-r--r--tunnel/winipcfg/luid.go93
-rw-r--r--tunnel/winipcfg/mksyscall.go2
-rw-r--r--tunnel/winipcfg/netsh.go14
-rw-r--r--tunnel/winipcfg/route_change_handler.go2
-rw-r--r--tunnel/winipcfg/types.go186
-rw-r--r--tunnel/winipcfg/types_32.go4
-rw-r--r--tunnel/winipcfg/types_64.go4
-rw-r--r--tunnel/winipcfg/types_test.go3
-rw-r--r--tunnel/winipcfg/types_test_32.go4
-rw-r--r--tunnel/winipcfg/types_test_64.go4
-rw-r--r--tunnel/winipcfg/unicast_address_change_handler.go2
-rw-r--r--tunnel/winipcfg/winipcfg.go12
-rw-r--r--tunnel/winipcfg/winipcfg_test.go90
-rw-r--r--tunnel/winipcfg/zwinipcfg_windows.go10
-rw-r--r--tunnel/wintun_test.go202
33 files changed, 638 insertions, 839 deletions
diff --git a/tunnel/addressconfig.go b/tunnel/addressconfig.go
index 44bfd8ae..a3ce6295 100644
--- a/tunnel/addressconfig.go
+++ b/tunnel/addressconfig.go
@@ -1,42 +1,30 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "bytes"
+ "fmt"
"log"
- "net"
- "sort"
+ "net/netip"
+ "time"
"golang.org/x/sys/windows"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []net.IPNet) {
+func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, addresses []netip.Prefix) {
if len(addresses) == 0 {
return
}
- includedInAddresses := func(a net.IPNet) bool {
- // TODO: this makes the whole algorithm O(n^2). But we can't stick net.IPNet in a Go hashmap. Bummer!
- for _, addr := range addresses {
- ip := addr.IP
- if ip4 := ip.To4(); ip4 != nil {
- ip = ip4
- }
- mA, _ := addr.Mask.Size()
- mB, _ := a.Mask.Size()
- if bytes.Equal(ip, a.IP) && mA == mB {
- return true
- }
- }
- return false
+ addrHash := make(map[netip.Addr]bool, len(addresses))
+ for i := range addresses {
+ addrHash[addresses[i].Addr()] = true
}
interfaces, err := winipcfg.GetAdaptersAddresses(family, winipcfg.GAAFlagDefault)
if err != nil {
@@ -47,147 +35,124 @@ func cleanupAddressesOnDisconnectedInterfaces(family winipcfg.AddressFamily, add
continue
}
for address := iface.FirstUnicastAddress; address != nil; address = address.Next {
- ip := address.Address.IP()
- ipnet := net.IPNet{IP: ip, Mask: net.CIDRMask(int(address.OnLinkPrefixLength), 8*len(ip))}
- if includedInAddresses(ipnet) {
- log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", ipnet.String(), iface.FriendlyName())
- iface.LUID.DeleteIPAddress(ipnet)
+ if ip, _ := netip.AddrFromSlice(address.Address.IP()); addrHash[ip] {
+ prefix := netip.PrefixFrom(ip, int(address.OnLinkPrefixLength))
+ log.Printf("Cleaning up stale address %s from interface ā€˜%sā€™", prefix.String(), iface.FriendlyName())
+ iface.LUID.DeleteIPAddress(prefix)
}
}
}
}
-func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, tun *tun.NativeTun) error {
- luid := winipcfg.LUID(tun.LUID())
+func configureInterface(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) error {
+ retryOnFailure := services.StartedAtBoot()
+ tryTimes := 0
+startOver:
+ var err error
+ if tryTimes > 0 {
+ log.Printf("Retrying interface configuration after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ time.Sleep(time.Second)
+ retryOnFailure = retryOnFailure && tryTimes < 15
+ }
+ tryTimes++
estimatedRouteCount := 0
for _, peer := range conf.Peers {
estimatedRouteCount += len(peer.AllowedIPs)
}
- routes := make([]winipcfg.RouteData, 0, estimatedRouteCount)
- addresses := make([]net.IPNet, len(conf.Interface.Addresses))
- var haveV4Address, haveV6Address bool
- for i, addr := range conf.Interface.Addresses {
- addresses[i] = addr.IPNet()
- if addr.Bits() == 32 {
- haveV4Address = true
- } else if addr.Bits() == 128 {
- haveV6Address = true
- }
- }
+ routes := make(map[winipcfg.RouteData]bool, estimatedRouteCount)
foundDefault4 := false
foundDefault6 := false
for _, peer := range conf.Peers {
for _, allowedip := range peer.AllowedIPs {
- allowedip.MaskSelf()
- if (allowedip.Bits() == 32 && !haveV4Address) || (allowedip.Bits() == 128 && !haveV6Address) {
- continue
- }
route := winipcfg.RouteData{
- Destination: allowedip.IPNet(),
+ Destination: allowedip.Masked(),
Metric: 0,
}
- if allowedip.Bits() == 32 {
- if allowedip.Cidr == 0 {
+ if allowedip.Addr().Is4() {
+ if allowedip.Bits() == 0 {
foundDefault4 = true
}
- route.NextHop = net.IPv4zero
- } else if allowedip.Bits() == 128 {
- if allowedip.Cidr == 0 {
+ route.NextHop = netip.IPv4Unspecified()
+ } else if allowedip.Addr().Is6() {
+ if allowedip.Bits() == 0 {
foundDefault6 = true
}
- route.NextHop = net.IPv6zero
+ route.NextHop = netip.IPv6Unspecified()
}
- routes = append(routes, route)
+ routes[route] = true
}
}
- err := luid.SetIPAddressesForFamily(family, addresses)
- if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
- cleanupAddressesOnDisconnectedInterfaces(family, addresses)
- err = luid.SetIPAddressesForFamily(family, addresses)
- }
- if err != nil {
- return err
+ deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
+ for route := range routes {
+ r := route
+ deduplicatedRoutes = append(deduplicatedRoutes, &r)
}
- deduplicatedRoutes := make([]*winipcfg.RouteData, 0, len(routes))
- sort.Slice(routes, func(i, j int) bool {
- if routes[i].Metric != routes[j].Metric {
- return routes[i].Metric < routes[j].Metric
- }
- if c := bytes.Compare(routes[i].NextHop, routes[j].NextHop); c != 0 {
- return c < 0
+ if !conf.Interface.TableOff {
+ err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set routes: %w", err)
}
- if c := bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP); c != 0 {
- return c < 0
- }
- if c := bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask); c != 0 {
- return c < 0
- }
- return false
- })
- for i := 0; i < len(routes); i++ {
- if i > 0 && routes[i].Metric == routes[i-1].Metric &&
- bytes.Equal(routes[i].NextHop, routes[i-1].NextHop) &&
- bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) &&
- bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) {
- continue
- }
- deduplicatedRoutes = append(deduplicatedRoutes, &routes[i])
}
- err = luid.SetRoutesForFamily(family, deduplicatedRoutes)
- if err != nil {
- return err
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ if err == windows.ERROR_OBJECT_ALREADY_EXISTS {
+ cleanupAddressesOnDisconnectedInterfaces(family, conf.Interface.Addresses)
+ err = luid.SetIPAddressesForFamily(family, conf.Interface.Addresses)
+ }
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set ips: %w", err)
}
- ipif, err := luid.IPInterface(family)
+ var ipif *winipcfg.MibIPInterfaceRow
+ ipif, err = luid.IPInterface(family)
if err != nil {
return err
}
+ ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ ipif.DadTransmits = 0
+ ipif.ManagedAddressConfigurationSupported = false
+ ipif.OtherStatefulConfigurationSupported = false
if conf.Interface.MTU > 0 {
ipif.NLMTU = uint32(conf.Interface.MTU)
- tun.ForceMTU(int(ipif.NLMTU))
}
- if family == windows.AF_INET {
- if foundDefault4 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- } else if family == windows.AF_INET6 {
- if foundDefault6 {
- ipif.UseAutomaticMetric = false
- ipif.Metric = 0
- }
- ipif.DadTransmits = 0
- ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled
+ if (family == windows.AF_INET && foundDefault4) || (family == windows.AF_INET6 && foundDefault6) {
+ ipif.UseAutomaticMetric = false
+ ipif.Metric = 0
}
err = ipif.Set()
- if err != nil {
- return err
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set metric and MTU: %w", err)
}
- return luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
+ err = luid.SetDNS(family, conf.Interface.DNS, conf.Interface.DNSSearch)
+ if err == windows.ERROR_NOT_FOUND && retryOnFailure {
+ goto startOver
+ } else if err != nil {
+ return fmt.Errorf("unable to set DNS: %w", err)
+ }
+ return nil
}
-func enableFirewall(conf *conf.Config, tun *tun.NativeTun) error {
+func enableFirewall(conf *conf.Config, luid winipcfg.LUID) error {
doNotRestrict := true
- if len(conf.Peers) == 1 {
- nextallowedip:
+ if len(conf.Peers) == 1 && !conf.Interface.TableOff {
for _, allowedip := range conf.Peers[0].AllowedIPs {
- if allowedip.Cidr == 0 {
- for _, b := range allowedip.IP {
- if b != 0 {
- continue nextallowedip
- }
- }
+ if allowedip.Bits() == 0 && allowedip == allowedip.Masked() {
doNotRestrict = false
break
}
}
}
log.Println("Enabling firewall rules")
- return firewall.EnableFirewall(tun.LUID(), doNotRestrict, conf.Interface.DNS)
+ return firewall.EnableFirewall(uint64(luid), doNotRestrict, conf.Interface.DNS)
}
diff --git a/tunnel/deterministicguid.go b/tunnel/deterministicguid.go
index 455deaeb..405d33a3 100644
--- a/tunnel/deterministicguid.go
+++ b/tunnel/deterministicguid.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
@@ -18,8 +18,10 @@ import (
"golang.zx2c4.com/wireguard/windows/conf"
)
-const deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com"
-const fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com"
+const (
+ deterministicGUIDLabel = "Deterministic WireGuard Windows GUID v1 jason@zx2c4.com"
+ fixedGUIDLabel = "Fixed WireGuard Windows GUID v1 jason@zx2c4.com"
+)
// Escape hatch for external consumers, not us.
var UseFixedGUIDInsteadOfDeterministic = false
@@ -80,13 +82,13 @@ func deterministicGUID(c *conf.Config) *windows.GUID {
b2Number(len(peer.AllowedIPs))
sortedAllowedIPs := peer.AllowedIPs
sort.Slice(sortedAllowedIPs, func(i, j int) bool {
- if bi, bj := sortedAllowedIPs[i].Bits(), sortedAllowedIPs[j].Bits(); bi != bj {
+ if bi, bj := sortedAllowedIPs[i].Addr().BitLen(), sortedAllowedIPs[j].Addr().BitLen(); bi != bj {
return bi < bj
}
- if sortedAllowedIPs[i].Cidr != sortedAllowedIPs[j].Cidr {
- return sortedAllowedIPs[i].Cidr < sortedAllowedIPs[j].Cidr
+ if sortedAllowedIPs[i].Bits() != sortedAllowedIPs[j].Bits() {
+ return sortedAllowedIPs[i].Bits() < sortedAllowedIPs[j].Bits()
}
- return bytes.Compare(sortedAllowedIPs[i].IP[:], sortedAllowedIPs[j].IP[:]) < 0
+ return sortedAllowedIPs[i].Addr().Compare(sortedAllowedIPs[j].Addr()) < 0
})
for _, allowedip := range sortedAllowedIPs {
b2String(allowedip.String())
diff --git a/tunnel/firewall/blocker.go b/tunnel/firewall/blocker.go
index 2cb2c7f2..8a4967ba 100644
--- a/tunnel/firewall/blocker.go
+++ b/tunnel/firewall/blocker.go
@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
import (
"errors"
- "net"
+ "net/netip"
"unsafe"
"golang.org/x/sys/windows"
@@ -101,7 +101,7 @@ func registerBaseObjects(session uintptr) (*baseObjects, error) {
return bo, nil
}
-func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []net.IP) error {
+func EnableFirewall(luid uint64, doNotRestrict bool, restrictToDNSServers []netip.Addr) error {
if wfpSession != 0 {
return errors.New("The firewall has already been enabled")
}
diff --git a/tunnel/firewall/helpers.go b/tunnel/firewall/helpers.go
index e3e4eac6..46e43aa5 100644
--- a/tunnel/firewall/helpers.go
+++ b/tunnel/firewall/helpers.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/mksyscall.go b/tunnel/firewall/mksyscall.go
index d5ff98aa..fc108007 100644
--- a/tunnel/firewall/mksyscall.go
+++ b/tunnel/firewall/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/rules.go b/tunnel/firewall/rules.go
index c4488a31..41632f98 100644
--- a/tunnel/firewall/rules.go
+++ b/tunnel/firewall/rules.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -8,7 +8,7 @@ package firewall
import (
"encoding/binary"
"errors"
- "net"
+ "net/netip"
"runtime"
"unsafe"
@@ -582,7 +582,6 @@ func permitDHCPIPv6(session uintptr, baseObjects *baseObjects, weight uint8) err
}
func permitNdp(session uintptr, baseObjects *baseObjects, weight uint8) error {
-
/* TODO: actually handle the hop limit somehow! The rules should vaguely be:
* - icmpv6 133: must be outgoing, dst must be FF02::2/128, hop limit must be 255
* - icmpv6 134: must be incoming, src must be FE80::/10, hop limit must be 255
@@ -985,7 +984,7 @@ func blockAll(session uintptr, baseObjects *baseObjects, weight uint8) error {
}
// Block all DNS traffic except towards specified DNS servers.
-func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weightAllow uint8, weightDeny uint8) error {
+func blockDNS(except []netip.Addr, session uintptr, baseObjects *baseObjects, weightAllow, weightDeny uint8) error {
if weightDeny >= weightAllow {
return errors.New("The allow weight must be greater than the deny weight")
}
@@ -1106,8 +1105,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV4 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV4 = append(allowConditionsV4, denyConditions...)
for _, ip := range except {
- ip4 := ip.To4()
- if ip4 == nil {
+ if !ip.Is4() {
continue
}
allowConditionsV4 = append(allowConditionsV4, wtFwpmFilterCondition0{
@@ -1115,7 +1113,7 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
matchType: cFWP_MATCH_EQUAL,
conditionValue: wtFwpConditionValue0{
_type: cFWP_UINT32,
- value: uintptr(binary.BigEndian.Uint32(ip4)),
+ value: uintptr(binary.BigEndian.Uint32(ip.AsSlice())),
},
})
}
@@ -1124,11 +1122,10 @@ func blockDNS(except []net.IP, session uintptr, baseObjects *baseObjects, weight
allowConditionsV6 := make([]wtFwpmFilterCondition0, 0, len(denyConditions)+len(except))
allowConditionsV6 = append(allowConditionsV6, denyConditions...)
for _, ip := range except {
- if ip.To4() != nil {
+ if !ip.Is6() {
continue
}
- var address wtFwpByteArray16
- copy(address.byteArray16[:], ip)
+ address := wtFwpByteArray16{byteArray16: ip.As16()}
allowConditionsV6 = append(allowConditionsV6, wtFwpmFilterCondition0{
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
matchType: cFWP_MATCH_EQUAL,
diff --git a/tunnel/firewall/syscall_windows.go b/tunnel/firewall/syscall_windows.go
index 661527d9..4d8eea42 100644
--- a/tunnel/firewall/syscall_windows.go
+++ b/tunnel/firewall/syscall_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows.go b/tunnel/firewall/types_windows.go
index 075daae4..54e2aad7 100644
--- a/tunnel/firewall/types_windows.go
+++ b/tunnel/firewall/types_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -148,8 +148,10 @@ var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
}
-var cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
-var cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
+var (
+ cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
+ cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
+)
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
diff --git a/tunnel/firewall/types_windows_32.go b/tunnel/firewall/types_windows_32.go
index 83aced6e..29ae553a 100644
--- a/tunnel/firewall/types_windows_32.go
+++ b/tunnel/firewall/types_windows_32.go
@@ -1,8 +1,8 @@
-// +build 386 arm
+//go:build 386 || arm
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows_64.go b/tunnel/firewall/types_windows_64.go
index 6e60aa5b..a476a745 100644
--- a/tunnel/firewall/types_windows_64.go
+++ b/tunnel/firewall/types_windows_64.go
@@ -1,8 +1,8 @@
-// +build amd64 arm64
+//go:build amd64 || arm64
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
diff --git a/tunnel/firewall/types_windows_test.go b/tunnel/firewall/types_windows_test.go
index 683e52ea..afa1988f 100644
--- a/tunnel/firewall/types_windows_test.go
+++ b/tunnel/firewall/types_windows_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package firewall
@@ -11,7 +11,6 @@ import (
)
func TestWtFwpByteBlobSize(t *testing.T) {
-
const actualWtFwpByteBlobSize = unsafe.Sizeof(wtFwpByteBlob{})
if actualWtFwpByteBlobSize != wtFwpByteBlob_Size {
@@ -21,7 +20,6 @@ func TestWtFwpByteBlobSize(t *testing.T) {
}
func TestWtFwpByteBlobOffsets(t *testing.T) {
-
s := wtFwpByteBlob{}
sp := uintptr(unsafe.Pointer(&s))
@@ -34,7 +32,6 @@ func TestWtFwpByteBlobOffsets(t *testing.T) {
}
func TestWtFwpmAction0Size(t *testing.T) {
-
const actualWtFwpmAction0Size = unsafe.Sizeof(wtFwpmAction0{})
if actualWtFwpmAction0Size != wtFwpmAction0_Size {
@@ -44,7 +41,6 @@ func TestWtFwpmAction0Size(t *testing.T) {
}
func TestWtFwpmAction0Offsets(t *testing.T) {
-
s := wtFwpmAction0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -58,7 +54,6 @@ func TestWtFwpmAction0Offsets(t *testing.T) {
}
func TestWtFwpBitmapArray64Size(t *testing.T) {
-
const actualWtFwpBitmapArray64Size = unsafe.Sizeof(wtFwpBitmapArray64{})
if actualWtFwpBitmapArray64Size != wtFwpBitmapArray64_Size {
@@ -68,7 +63,6 @@ func TestWtFwpBitmapArray64Size(t *testing.T) {
}
func TestWtFwpByteArray6Size(t *testing.T) {
-
const actualWtFwpByteArray6Size = unsafe.Sizeof(wtFwpByteArray6{})
if actualWtFwpByteArray6Size != wtFwpByteArray6_Size {
@@ -78,7 +72,6 @@ func TestWtFwpByteArray6Size(t *testing.T) {
}
func TestWtFwpByteArray16Size(t *testing.T) {
-
const actualWtFwpByteArray16Size = unsafe.Sizeof(wtFwpByteArray16{})
if actualWtFwpByteArray16Size != wtFwpByteArray16_Size {
@@ -88,7 +81,6 @@ func TestWtFwpByteArray16Size(t *testing.T) {
}
func TestWtFwpConditionValue0Size(t *testing.T) {
-
const actualWtFwpConditionValue0Size = unsafe.Sizeof(wtFwpConditionValue0{})
if actualWtFwpConditionValue0Size != wtFwpConditionValue0_Size {
@@ -98,7 +90,6 @@ func TestWtFwpConditionValue0Size(t *testing.T) {
}
func TestWtFwpConditionValue0Offsets(t *testing.T) {
-
s := wtFwpConditionValue0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -111,7 +102,6 @@ func TestWtFwpConditionValue0Offsets(t *testing.T) {
}
func TestWtFwpV4AddrAndMaskSize(t *testing.T) {
-
const actualWtFwpV4AddrAndMaskSize = unsafe.Sizeof(wtFwpV4AddrAndMask{})
if actualWtFwpV4AddrAndMaskSize != wtFwpV4AddrAndMask_Size {
@@ -121,7 +111,6 @@ func TestWtFwpV4AddrAndMaskSize(t *testing.T) {
}
func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) {
-
s := wtFwpV4AddrAndMask{}
sp := uintptr(unsafe.Pointer(&s))
@@ -135,7 +124,6 @@ func TestWtFwpV4AddrAndMaskOffsets(t *testing.T) {
}
func TestWtFwpV6AddrAndMaskSize(t *testing.T) {
-
const actualWtFwpV6AddrAndMaskSize = unsafe.Sizeof(wtFwpV6AddrAndMask{})
if actualWtFwpV6AddrAndMaskSize != wtFwpV6AddrAndMask_Size {
@@ -145,7 +133,6 @@ func TestWtFwpV6AddrAndMaskSize(t *testing.T) {
}
func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) {
-
s := wtFwpV6AddrAndMask{}
sp := uintptr(unsafe.Pointer(&s))
@@ -159,7 +146,6 @@ func TestWtFwpV6AddrAndMaskOffsets(t *testing.T) {
}
func TestWtFwpValue0Size(t *testing.T) {
-
const actualWtFwpValue0Size = unsafe.Sizeof(wtFwpValue0{})
if actualWtFwpValue0Size != wtFwpValue0_Size {
@@ -168,7 +154,6 @@ func TestWtFwpValue0Size(t *testing.T) {
}
func TestWtFwpValue0Offsets(t *testing.T) {
-
s := wtFwpValue0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -181,7 +166,6 @@ func TestWtFwpValue0Offsets(t *testing.T) {
}
func TestWtFwpmDisplayData0Size(t *testing.T) {
-
const actualWtFwpmDisplayData0Size = unsafe.Sizeof(wtFwpmDisplayData0{})
if actualWtFwpmDisplayData0Size != wtFwpmDisplayData0_Size {
@@ -191,7 +175,6 @@ func TestWtFwpmDisplayData0Size(t *testing.T) {
}
func TestWtFwpmDisplayData0Offsets(t *testing.T) {
-
s := wtFwpmDisplayData0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -205,7 +188,6 @@ func TestWtFwpmDisplayData0Offsets(t *testing.T) {
}
func TestWtFwpmFilterCondition0Size(t *testing.T) {
-
const actualWtFwpmFilterCondition0Size = unsafe.Sizeof(wtFwpmFilterCondition0{})
if actualWtFwpmFilterCondition0Size != wtFwpmFilterCondition0_Size {
@@ -215,7 +197,6 @@ func TestWtFwpmFilterCondition0Size(t *testing.T) {
}
func TestWtFwpmFilterCondition0Offsets(t *testing.T) {
-
s := wtFwpmFilterCondition0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -237,7 +218,6 @@ func TestWtFwpmFilterCondition0Offsets(t *testing.T) {
}
func TestWtFwpmFilter0Size(t *testing.T) {
-
const actualWtFwpmFilter0Size = unsafe.Sizeof(wtFwpmFilter0{})
if actualWtFwpmFilter0Size != wtFwpmFilter0_Size {
@@ -247,7 +227,6 @@ func TestWtFwpmFilter0Size(t *testing.T) {
}
func TestWtFwpmFilter0Offsets(t *testing.T) {
-
s := wtFwpmFilter0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -364,7 +343,6 @@ func TestWtFwpmFilter0Offsets(t *testing.T) {
}
func TestWtFwpProvider0Size(t *testing.T) {
-
const actualWtFwpProvider0Size = unsafe.Sizeof(wtFwpProvider0{})
if actualWtFwpProvider0Size != wtFwpProvider0_Size {
@@ -374,7 +352,6 @@ func TestWtFwpProvider0Size(t *testing.T) {
}
func TestWtFwpProvider0Offsets(t *testing.T) {
-
s := wtFwpProvider0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -412,7 +389,6 @@ func TestWtFwpProvider0Offsets(t *testing.T) {
}
func TestWtFwpmSession0Size(t *testing.T) {
-
const actualWtFwpmSession0Size = unsafe.Sizeof(wtFwpmSession0{})
if actualWtFwpmSession0Size != wtFwpmSession0_Size {
@@ -422,7 +398,6 @@ func TestWtFwpmSession0Size(t *testing.T) {
}
func TestWtFwpmSession0Offsets(t *testing.T) {
-
s := wtFwpmSession0{}
sp := uintptr(unsafe.Pointer(&s))
@@ -482,7 +457,6 @@ func TestWtFwpmSession0Offsets(t *testing.T) {
}
func TestWtFwpmSublayer0Size(t *testing.T) {
-
const actualWtFwpmSublayer0Size = unsafe.Sizeof(wtFwpmSublayer0{})
if actualWtFwpmSublayer0Size != wtFwpmSublayer0_Size {
@@ -492,7 +466,6 @@ func TestWtFwpmSublayer0Size(t *testing.T) {
}
func TestWtFwpmSublayer0Offsets(t *testing.T) {
-
s := wtFwpmSublayer0{}
sp := uintptr(unsafe.Pointer(&s))
diff --git a/tunnel/interfacewatcher.go b/tunnel/interfacewatcher.go
index e12e5929..a831d06e 100644
--- a/tunnel/interfacewatcher.go
+++ b/tunnel/interfacewatcher.go
@@ -1,20 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
+ "errors"
+ "fmt"
"log"
"sync"
+ "time"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/services"
"golang.zx2c4.com/wireguard/windows/tunnel/firewall"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -24,63 +24,30 @@ type interfaceWatcherError struct {
serviceError services.Error
err error
}
+
type interfaceWatcherEvent struct {
luid winipcfg.LUID
family winipcfg.AddressFamily
}
+
type interfaceWatcher struct {
- errors chan interfaceWatcherError
+ errors chan interfaceWatcherError
+ started chan winipcfg.AddressFamily
- binder conn.BindSocketToInterface
- conf *conf.Config
- tun *tun.NativeTun
+ conf *conf.Config
+ adapter *driver.Adapter
+ luid winipcfg.LUID
setupMutex sync.Mutex
interfaceChangeCallback winipcfg.ChangeCallback
changeCallbacks4 []winipcfg.ChangeCallback
changeCallbacks6 []winipcfg.ChangeCallback
storedEvents []interfaceWatcherEvent
-}
-
-func hasDefaultRoute(family winipcfg.AddressFamily, peers []conf.Peer) bool {
- var (
- foundV401 bool
- foundV41281 bool
- foundV600001 bool
- foundV680001 bool
- foundV400 bool
- foundV600 bool
- v40 = [4]byte{}
- v60 = [16]byte{}
- v48 = [4]byte{0x80}
- v68 = [16]byte{0x80}
- )
- for _, peer := range peers {
- for _, allowedip := range peer.AllowedIPs {
- if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) {
- foundV680001 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV401 = true
- } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) {
- foundV41281 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) {
- foundV600 = true
- } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) {
- foundV400 = true
- }
- }
- }
- if family == windows.AF_INET {
- return foundV400 || (foundV401 && foundV41281)
- } else if family == windows.AF_INET6 {
- return foundV600 || (foundV600001 && foundV680001)
- }
- return false
+ watchdog *time.Timer
}
func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
+ iw.watchdog.Stop()
var changeCallbacks *[]winipcfg.ChangeCallback
var ipversion string
if family == windows.AF_INET {
@@ -100,25 +67,35 @@ func (iw *interfaceWatcher) setup(family winipcfg.AddressFamily) {
}
var err error
- log.Printf("Monitoring default %s routes", ipversion)
- *changeCallbacks, err = monitorDefaultRoutes(family, iw.binder, iw.conf.Interface.MTU == 0, hasDefaultRoute(family, iw.conf.Peers), iw.tun)
- if err != nil {
- iw.errors <- interfaceWatcherError{services.ErrorBindSocketsToDefaultRoutes, err}
- return
+ if iw.conf.Interface.MTU == 0 {
+ log.Printf("Monitoring MTU of default %s routes", ipversion)
+ *changeCallbacks, err = monitorMTU(family, iw.luid)
+ if err != nil {
+ iw.errors <- interfaceWatcherError{services.ErrorMonitorMTUChanges, err}
+ return
+ }
}
log.Printf("Setting device %s addresses", ipversion)
- err = configureInterface(family, iw.conf, iw.tun)
+ err = configureInterface(family, iw.conf, iw.luid)
if err != nil {
iw.errors <- interfaceWatcherError{services.ErrorSetNetConfig, err}
return
}
+ evaluateDynamicPitfalls(family, iw.conf, iw.luid)
+
+ iw.started <- family
}
func watchInterface() (*interfaceWatcher, error) {
iw := &interfaceWatcher{
- errors: make(chan interfaceWatcherError, 2),
+ errors: make(chan interfaceWatcherError, 2),
+ started: make(chan winipcfg.AddressFamily, 4),
}
+ iw.watchdog = time.AfterFunc(time.Duration(1<<63-1), func() {
+ iw.errors <- interfaceWatcherError{services.ErrorCreateNetworkAdapter, errors.New("TCP/IP interface for adapter did not appear after one minute")}
+ })
+ iw.watchdog.Stop()
var err error
iw.interfaceChangeCallback, err = winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
iw.setupMutex.Lock()
@@ -127,28 +104,41 @@ func watchInterface() (*interfaceWatcher, error) {
if notificationType != winipcfg.MibAddInstance {
return
}
- if iw.tun == nil {
+ if iw.luid == 0 {
iw.storedEvents = append(iw.storedEvents, interfaceWatcherEvent{iface.InterfaceLUID, iface.Family})
return
}
- if iface.InterfaceLUID != winipcfg.LUID(iw.tun.LUID()) {
+ if iface.InterfaceLUID != iw.luid {
return
}
iw.setup(iface.Family)
+
+ if state, err := iw.adapter.AdapterState(); err == nil && state == driver.AdapterStateDown {
+ log.Println("Reinitializing adapter configuration")
+ err = iw.adapter.SetConfiguration(iw.conf.ToDriverConfiguration())
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceSetConfig, err))
+ }
+ err = iw.adapter.SetAdapterState(driver.AdapterStateUp)
+ if err != nil {
+ log.Println(fmt.Errorf("%v: %w", services.ErrorDeviceBringUp, err))
+ }
+ }
})
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("unable to register interface change callback: %w", err)
}
return iw, nil
}
-func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *conf.Config, tun *tun.NativeTun) {
+func (iw *interfaceWatcher) Configure(adapter *driver.Adapter, conf *conf.Config, luid winipcfg.LUID) {
iw.setupMutex.Lock()
defer iw.setupMutex.Unlock()
+ iw.watchdog.Reset(time.Minute)
- iw.binder, iw.conf, iw.tun = binder, conf, tun
+ iw.adapter, iw.conf, iw.luid = adapter, conf, luid
for _, event := range iw.storedEvents {
- if event.luid == winipcfg.LUID(iw.tun.LUID()) {
+ if event.luid == luid {
iw.setup(event.family)
}
}
@@ -157,10 +147,11 @@ func (iw *interfaceWatcher) Configure(binder conn.BindSocketToInterface, conf *c
func (iw *interfaceWatcher) Destroy() {
iw.setupMutex.Lock()
+ iw.watchdog.Stop()
changeCallbacks4 := iw.changeCallbacks4
changeCallbacks6 := iw.changeCallbacks6
interfaceChangeCallback := iw.interfaceChangeCallback
- tun := iw.tun
+ luid := iw.luid
iw.setupMutex.Unlock()
if interfaceChangeCallback != nil {
@@ -186,10 +177,9 @@ func (iw *interfaceWatcher) Destroy() {
changeCallbacks6 = changeCallbacks6[1:]
}
firewall.DisableFirewall()
- if tun != nil && iw.tun == tun {
+ if luid != 0 && iw.luid == luid {
// It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active
// routes, so to be certain, just remove everything before destroying.
- luid := winipcfg.LUID(tun.LUID())
luid.FlushRoutes(windows.AF_INET)
luid.FlushIPAddresses(windows.AF_INET)
luid.FlushDNS(windows.AF_INET)
diff --git a/tunnel/ipcpermissions.go b/tunnel/ipcpermissions.go
deleted file mode 100644
index 3a676e4b..00000000
--- a/tunnel/ipcpermissions.go
+++ /dev/null
@@ -1,63 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel
-
-import (
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/ipc"
-
- "golang.zx2c4.com/wireguard/windows/conf"
-)
-
-func CopyConfigOwnerToIPCSecurityDescriptor(filename string) error {
- if conf.PathIsEncrypted(filename) {
- return nil
- }
-
- fileSd, err := windows.GetNamedSecurityInfo(filename, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
- if err != nil {
- return err
- }
- fileOwner, _, err := fileSd.Owner()
- if err != nil {
- return err
- }
- if fileOwner.IsWellKnown(windows.WinLocalSystemSid) {
- return nil
- }
- additionalEntries := []windows.EXPLICIT_ACCESS{{
- AccessPermissions: windows.GENERIC_ALL,
- AccessMode: windows.GRANT_ACCESS,
- Trustee: windows.TRUSTEE{
- TrusteeForm: windows.TRUSTEE_IS_SID,
- TrusteeType: windows.TRUSTEE_IS_USER,
- TrusteeValue: windows.TrusteeValueFromSID(fileOwner),
- },
- }}
-
- sd, err := ipc.UAPISecurityDescriptor.ToAbsolute()
- if err != nil {
- return err
- }
- dacl, defaulted, _ := sd.DACL()
-
- newDacl, err := windows.ACLFromEntries(additionalEntries, dacl)
- if err != nil {
- return err
- }
- err = sd.SetDACL(newDacl, true, defaulted)
- if err != nil {
- return err
- }
- sd, err = sd.ToSelfRelative()
- if err != nil {
- return err
- }
- ipc.UAPISecurityDescriptor = sd
-
- return nil
-}
diff --git a/tunnel/defaultroutemonitor.go b/tunnel/mtumonitor.go
index aa0db675..c07823a2 100644
--- a/tunnel/defaultroutemonitor.go
+++ b/tunnel/mtumonitor.go
@@ -1,30 +1,23 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
import (
- "log"
- "sync"
- "time"
-
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
-func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32, blackholeWhenLoop bool) error {
+func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error {
r, err := winipcfg.GetIPForwardTable2(family)
if err != nil {
return err
}
lowestMetric := ^uint32(0)
- index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want.
- luid := winipcfg.LUID(0) // Hopefully luid zero is unspecified, but hard to find docs saying so.
+ index := uint32(0)
+ luid := winipcfg.LUID(0)
for i := range r {
if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
continue
@@ -50,36 +43,24 @@ func bindSocketRoute(family winipcfg.AddressFamily, binder conn.BindSocketToInte
}
*lastLUID = luid
*lastIndex = index
- blackhole := blackholeWhenLoop && index == 0
- if family == windows.AF_INET {
- log.Printf("Binding v4 socket to interface %d (blackhole=%v)", index, blackhole)
- return binder.BindSocketToInterface4(index, blackhole)
- } else if family == windows.AF_INET6 {
- log.Printf("Binding v6 socket to interface %d (blackhole=%v)", index, blackhole)
- return binder.BindSocketToInterface6(index, blackhole)
- }
return nil
}
-func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketToInterface, autoMTU bool, blackholeWhenLoop bool, tun *tun.NativeTun) ([]winipcfg.ChangeCallback, error) {
+func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
var minMTU uint32
if family == windows.AF_INET {
minMTU = 576
} else if family == windows.AF_INET6 {
minMTU = 1280
}
- ourLUID := winipcfg.LUID(tun.LUID())
lastLUID := winipcfg.LUID(0)
lastIndex := ^uint32(0)
lastMTU := uint32(0)
doIt := func() error {
- err := bindSocketRoute(family, binder, ourLUID, &lastLUID, &lastIndex, blackholeWhenLoop)
+ err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex)
if err != nil {
return err
}
- if !autoMTU {
- return nil
- }
mtu := uint32(0)
if lastLUID != 0 {
iface, err := lastLUID.Interface()
@@ -103,7 +84,6 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT
if err != nil {
return err
}
- tun.ForceMTU(int(iface.NLMTU)) // TODO: having one MTU for both v4 and v6 kind of breaks the windows model, so right now this just gets the second one which is... bad.
lastMTU = mtu
}
return nil
@@ -112,32 +92,9 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT
if err != nil {
return nil, err
}
-
- firstBurst := time.Time{}
- burstMutex := sync.Mutex{}
- burstTimer := time.AfterFunc(time.Hour*200, func() {
- burstMutex.Lock()
- firstBurst = time.Time{}
- doIt()
- burstMutex.Unlock()
- })
- burstTimer.Stop()
- bump := func() {
- burstMutex.Lock()
- burstTimer.Reset(time.Millisecond * 150)
- if firstBurst.IsZero() {
- firstBurst = time.Now()
- } else if time.Since(firstBurst) > time.Second*2 {
- firstBurst = time.Time{}
- burstTimer.Stop()
- doIt()
- }
- burstMutex.Unlock()
- }
-
cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
if route != nil && route.DestinationPrefix.PrefixLength == 0 {
- bump()
+ doIt()
}
})
if err != nil {
@@ -145,7 +102,7 @@ func monitorDefaultRoutes(family winipcfg.AddressFamily, binder conn.BindSocketT
}
cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
if notificationType == winipcfg.MibParameterNotification {
- bump()
+ doIt()
}
})
if err != nil {
diff --git a/tunnel/pitfalls.go b/tunnel/pitfalls.go
new file mode 100644
index 00000000..fdef6eb2
--- /dev/null
+++ b/tunnel/pitfalls.go
@@ -0,0 +1,177 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
+ */
+
+package tunnel
+
+import (
+ "log"
+ "net/netip"
+ "strings"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+ "golang.org/x/sys/windows/svc/mgr"
+ "golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
+)
+
+func evaluateStaticPitfalls() {
+ go func() {
+ pitfallDnsCacheDisabled()
+ pitfallVirtioNetworkDriver()
+ }()
+}
+
+func evaluateDynamicPitfalls(family winipcfg.AddressFamily, conf *conf.Config, luid winipcfg.LUID) {
+ go func() {
+ pitfallWeakHostSend(family, conf, luid)
+ }()
+}
+
+func pitfallDnsCacheDisabled() {
+ scm, err := mgr.Connect()
+ if err != nil {
+ return
+ }
+ defer scm.Disconnect()
+ svc := mgr.Service{Name: "dnscache"}
+ svc.Handle, err = windows.OpenService(scm.Handle, windows.StringToUTF16Ptr(svc.Name), windows.SERVICE_QUERY_CONFIG)
+ if err != nil {
+ return
+ }
+ defer svc.Close()
+ cfg, err := svc.Config()
+ if err != nil {
+ return
+ }
+ if cfg.StartType != mgr.StartDisabled {
+ return
+ }
+
+ log.Printf("Warning: the %q (dnscache) service is disabled; please re-enable it", cfg.DisplayName)
+}
+
+func pitfallVirtioNetworkDriver() {
+ var modules []windows.RTL_PROCESS_MODULE_INFORMATION
+ for bufferSize := uint32(128 * 1024); ; {
+ moduleBuffer := make([]byte, bufferSize)
+ err := windows.NtQuerySystemInformation(windows.SystemModuleInformation, unsafe.Pointer(&moduleBuffer[0]), bufferSize, &bufferSize)
+ switch err {
+ case windows.STATUS_INFO_LENGTH_MISMATCH:
+ continue
+ case nil:
+ break
+ default:
+ return
+ }
+ mods := (*windows.RTL_PROCESS_MODULES)(unsafe.Pointer(&moduleBuffer[0]))
+ modules = unsafe.Slice(&mods.Modules[0], mods.NumberOfModules)
+ break
+ }
+ for i := range modules {
+ if !strings.EqualFold(windows.ByteSliceToString(modules[i].FullPathName[modules[i].OffsetToFileName:]), "netkvm.sys") {
+ continue
+ }
+ driverPath := `\\?\GLOBALROOT` + windows.ByteSliceToString(modules[i].FullPathName[:])
+ var zero windows.Handle
+ infoSize, err := windows.GetFileVersionInfoSize(driverPath, &zero)
+ if err != nil {
+ return
+ }
+ versionInfo := make([]byte, infoSize)
+ err = windows.GetFileVersionInfo(driverPath, 0, infoSize, unsafe.Pointer(&versionInfo[0]))
+ if err != nil {
+ return
+ }
+ var fixedInfo *windows.VS_FIXEDFILEINFO
+ fixedInfoLen := uint32(unsafe.Sizeof(*fixedInfo))
+ err = windows.VerQueryValue(unsafe.Pointer(&versionInfo[0]), `\`, unsafe.Pointer(&fixedInfo), &fixedInfoLen)
+ if err != nil {
+ return
+ }
+ const minimumPlausibleVersion = 40 << 48
+ const minimumGoodVersion = (100 << 48) | (85 << 32) | (104 << 16) | (20800 << 0)
+ version := (uint64(fixedInfo.FileVersionMS) << 32) | uint64(fixedInfo.FileVersionLS)
+ if version >= minimumGoodVersion || version < minimumPlausibleVersion {
+ return
+ }
+ log.Println("Warning: the VirtIO network driver (NetKVM) is out of date and may cause known problems; please update to v100.85.104.20800 or later")
+ return
+ }
+}
+
+func pitfallWeakHostSend(family winipcfg.AddressFamily, conf *conf.Config, ourLUID winipcfg.LUID) {
+ routingTable, err := winipcfg.GetIPForwardTable2(family)
+ if err != nil {
+ return
+ }
+ type endpointRoute struct {
+ addr netip.Addr
+ name string
+ lowestMetric uint32
+ highestCIDR uint8
+ weakHostSend bool
+ finalIsOurs bool
+ }
+ endpoints := make([]endpointRoute, 0, len(conf.Peers))
+ for _, peer := range conf.Peers {
+ addr, err := netip.ParseAddr(peer.Endpoint.Host)
+ if err != nil || (addr.Is4() && family != windows.AF_INET) || (addr.Is6() && family != windows.AF_INET6) {
+ continue
+ }
+ endpoints = append(endpoints, endpointRoute{addr: addr, lowestMetric: ^uint32(0)})
+ }
+ for i := range routingTable {
+ var (
+ ifrow *winipcfg.MibIfRow2
+ ifacerow *winipcfg.MibIPInterfaceRow
+ metric uint32
+ )
+ for j := range endpoints {
+ r, e := &routingTable[i], &endpoints[j]
+ if r.DestinationPrefix.PrefixLength < e.highestCIDR {
+ continue
+ }
+ if !r.DestinationPrefix.Prefix().Contains(e.addr) {
+ continue
+ }
+ if ifrow == nil {
+ ifrow, err = r.InterfaceLUID.Interface()
+ if err != nil {
+ continue
+ }
+ }
+ if ifrow.OperStatus != winipcfg.IfOperStatusUp {
+ continue
+ }
+ if ifacerow == nil {
+ ifacerow, err = r.InterfaceLUID.IPInterface(family)
+ if err != nil {
+ continue
+ }
+ metric = r.Metric + ifacerow.Metric
+ }
+ if r.DestinationPrefix.PrefixLength == e.highestCIDR && metric > e.lowestMetric {
+ continue
+ }
+ e.lowestMetric = metric
+ e.highestCIDR = r.DestinationPrefix.PrefixLength
+ e.finalIsOurs = r.InterfaceLUID == ourLUID
+ if !e.finalIsOurs {
+ e.name = ifrow.Alias()
+ e.weakHostSend = ifacerow.ForwardingEnabled || ifacerow.WeakHostSend
+ }
+ }
+ }
+ problematicInterfaces := make(map[string]bool, len(endpoints))
+ for _, e := range endpoints {
+ if e.weakHostSend && e.finalIsOurs {
+ problematicInterfaces[e.name] = true
+ }
+ }
+ for iface := range problematicInterfaces {
+ log.Printf("Warning: the %q interface has Forwarding/WeakHostSend enabled, which will cause routing loops", iface)
+ }
+}
diff --git a/tunnel/scriptrunner.go b/tunnel/scriptrunner.go
index 450d8e21..eb97d98d 100644
--- a/tunnel/scriptrunner.go
+++ b/tunnel/scriptrunner.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
diff --git a/tunnel/service.go b/tunnel/service.go
index 63cd243f..a56ed1f3 100644
--- a/tunnel/service.go
+++ b/tunnel/service.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package tunnel
@@ -9,7 +9,6 @@ import (
"bytes"
"fmt"
"log"
- "net"
"os"
"runtime"
"time"
@@ -17,16 +16,12 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
"golang.org/x/sys/windows/svc/mgr"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/ipc"
- "golang.zx2c4.com/wireguard/tun"
-
"golang.zx2c4.com/wireguard/windows/conf"
+ "golang.zx2c4.com/wireguard/windows/driver"
"golang.zx2c4.com/wireguard/windows/elevate"
"golang.zx2c4.com/wireguard/windows/ringlogger"
"golang.zx2c4.com/wireguard/windows/services"
- "golang.zx2c4.com/wireguard/windows/version"
+ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
type tunnelService struct {
@@ -34,12 +29,12 @@ type tunnelService struct {
}
func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (svcSpecificEC bool, exitCode uint32) {
- changes <- svc.Status{State: svc.StartPending}
+ serviceState := svc.StartPending
+ changes <- svc.Status{State: serviceState}
- var dev *device.Device
- var uapi net.Listener
var watcher *interfaceWatcher
- var nativeTun *tun.NativeTun
+ var adapter *driver.Adapter
+ var luid winipcfg.LUID
var config *conf.Config
var err error
serviceError := services.ErrorSuccess
@@ -50,7 +45,8 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
if logErr != nil {
log.Println(logErr)
}
- changes <- svc.Status{State: svc.StopPending}
+ serviceState = svc.StopPending
+ changes <- svc.Status{State: serviceState}
stopIt := make(chan bool, 1)
go func() {
@@ -84,60 +80,63 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
}()
- if logErr == nil && dev != nil && config != nil {
+ if logErr == nil && adapter != nil && config != nil {
logErr = runScriptCommand(config.Interface.PreDown, config.Name)
}
if watcher != nil {
watcher.Destroy()
}
- if uapi != nil {
- uapi.Close()
- }
- if dev != nil {
- dev.Close()
+ if adapter != nil {
+ adapter.Close()
}
- if logErr == nil && dev != nil && config != nil {
+ if logErr == nil && adapter != nil && config != nil {
_ = runScriptCommand(config.Interface.PostDown, config.Name)
}
stopIt <- true
log.Println("Shutting down")
}()
- err = ringlogger.InitGlobalLogger("TUN")
+ var logFile string
+ logFile, err = conf.LogFile(true)
if err != nil {
serviceError = services.ErrorRingloggerOpen
return
}
-
- config, err = conf.LoadFromPath(service.Path)
+ err = ringlogger.InitGlobalLogger(logFile, "TUN")
if err != nil {
- serviceError = services.ErrorLoadConfiguration
+ serviceError = services.ErrorRingloggerOpen
return
}
- config.DeduplicateNetworkEntries()
- err = CopyConfigOwnerToIPCSecurityDescriptor(service.Path)
+
+ config, err = conf.LoadFromPath(service.Path)
if err != nil {
serviceError = services.ErrorLoadConfiguration
return
}
+ config.DeduplicateNetworkEntries()
log.SetPrefix(fmt.Sprintf("[%s] ", config.Name))
- log.Println("Starting", version.UserAgent())
+ services.PrintStarting()
- if m, err := mgr.Connect(); err == nil {
- if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
- /* If we don't do this, then the Wintun installation will block forever, because
- * installing a Wintun device starts a service too. Apparently at boot time, Windows
- * 8.1 locks the SCM for each service start, creating a deadlock if we don't announce
- * that we're running before starting additional services.
- */
- log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
- changes <- svc.Status{State: svc.Running}
+ if services.StartedAtBoot() {
+ if m, err := mgr.Connect(); err == nil {
+ if lockStatus, err := m.LockStatus(); err == nil && lockStatus.IsLocked {
+ /* If we don't do this, then the driver installation will block forever, because
+ * installing a network adapter starts the driver service too. Apparently at boot time,
+ * Windows 8.1 locks the SCM for each service start, creating a deadlock if we don't
+ * announce that we're running before starting additional services.
+ */
+ log.Printf("SCM locked for %v by %s, marking service as started", lockStatus.Age, lockStatus.Owner)
+ serviceState = svc.Running
+ changes <- svc.Status{State: serviceState}
+ }
+ m.Disconnect()
}
- m.Disconnect()
}
+ evaluateStaticPitfalls()
+
log.Println("Watching network interfaces")
watcher, err = watchInterface()
if err != nil {
@@ -146,34 +145,40 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
}
log.Println("Resolving DNS names")
- uapiConf, err := config.ToUAPI()
+ err = config.ResolveEndpoints()
if err != nil {
serviceError = services.ErrorDNSLookup
return
}
- log.Println("Creating Wintun interface")
- var wintun tun.Device
- for i := 0; i < 5; i++ {
+ log.Println("Creating network adapter")
+ for i := 0; i < 15; i++ {
if i > 0 {
time.Sleep(time.Second)
- log.Printf("Retrying Wintun creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
+ log.Printf("Retrying adapter creation after failure because system just booted (T+%v): %v", windows.DurationSinceBoot(), err)
}
- wintun, err = tun.CreateTUNWithRequestedGUID(config.Name, deterministicGUID(config), 0)
- if err == nil || windows.DurationSinceBoot() > time.Minute*4 {
+ adapter, err = driver.CreateAdapter(config.Name, "WireGuard", deterministicGUID(config))
+ if err == nil || !services.StartedAtBoot() {
break
}
}
if err != nil {
- serviceError = services.ErrorCreateWintun
+ err = fmt.Errorf("Error creating adapter: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
return
}
- nativeTun = wintun.(*tun.NativeTun)
- wintunVersion, err := nativeTun.RunningVersion()
+ luid = adapter.LUID()
+ driverVersion, err := driver.RunningVersion()
if err != nil {
- log.Printf("Warning: unable to determine Wintun version: %v", err)
+ log.Printf("Warning: unable to determine driver version: %v", err)
} else {
- log.Printf("Using Wintun/%d.%d", (wintunVersion>>16)&0xffff, wintunVersion&0xffff)
+ log.Printf("Using WireGuardNT/%d.%d", (driverVersion>>16)&0xffff, driverVersion&0xffff)
+ }
+ err = adapter.SetLogging(driver.AdapterLogOn)
+ if err != nil {
+ err = fmt.Errorf("Error enabling adapter logging: %w", err)
+ serviceError = services.ErrorCreateNetworkAdapter
+ return
}
err = runScriptCommand(config.Interface.PreUp, config.Name)
@@ -182,7 +187,7 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- err = enableFirewall(config, nativeTun)
+ err = enableFirewall(config, luid)
if err != nil {
serviceError = services.ErrorFirewall
return
@@ -195,37 +200,18 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- log.Println("Creating interface instance")
- bind := conn.NewDefaultBind()
- dev = device.NewDevice(wintun, bind, &device.Logger{log.Printf, log.Printf})
-
log.Println("Setting interface configuration")
- uapi, err = ipc.UAPIListen(config.Name)
+ err = adapter.SetConfiguration(config.ToDriverConfiguration())
if err != nil {
- serviceError = services.ErrorUAPIListen
+ serviceError = services.ErrorDeviceSetConfig
return
}
- err = dev.IpcSet(uapiConf)
+ err = adapter.SetAdapterState(driver.AdapterStateUp)
if err != nil {
- serviceError = services.ErrorDeviceSetConfig
+ serviceError = services.ErrorDeviceBringUp
return
}
-
- log.Println("Bringing peers up")
- dev.Up()
-
- watcher.Configure(bind.(conn.BindSocketToInterface), config, nativeTun)
-
- log.Println("Listening for UAPI requests")
- go func() {
- for {
- conn, err := uapi.Accept()
- if err != nil {
- continue
- }
- go dev.IpcHandle(conn)
- }
- }()
+ watcher.Configure(adapter, config, luid)
err = runScriptCommand(config.Interface.PostUp, config.Name)
if err != nil {
@@ -233,9 +219,9 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
return
}
- changes <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown}
- log.Println("Startup complete")
+ changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
+ var started bool
for {
select {
case c := <-r:
@@ -247,8 +233,13 @@ func (service *tunnelService) Execute(args []string, r <-chan svc.ChangeRequest,
default:
log.Printf("Unexpected service control request #%d\n", c)
}
- case <-dev.Wait():
- return
+ case <-watcher.started:
+ if !started {
+ serviceState = svc.Running
+ changes <- svc.Status{State: serviceState, Accepts: svc.AcceptStop | svc.AcceptShutdown}
+ log.Println("Startup complete")
+ started = true
+ }
case e := <-watcher.errors:
serviceError, err = e.serviceError, e.err
return
@@ -261,7 +252,7 @@ func Run(confPath string) error {
if err != nil {
return err
}
- serviceName, err := services.ServiceNameOfTunnel(name)
+ serviceName, err := conf.ServiceNameOfTunnel(name)
if err != nil {
return err
}
diff --git a/tunnel/winipcfg/interface_change_handler.go b/tunnel/winipcfg/interface_change_handler.go
index 4d229e78..af29801a 100644
--- a/tunnel/winipcfg/interface_change_handler.go
+++ b/tunnel/winipcfg/interface_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/luid.go b/tunnel/winipcfg/luid.go
index 02ba65b4..0c898b89 100644
--- a/tunnel/winipcfg/luid.go
+++ b/tunnel/winipcfg/luid.go
@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
"errors"
- "net"
+ "net/netip"
"strings"
"golang.org/x/sys/windows"
@@ -76,10 +76,10 @@ func LUIDFromIndex(index uint32) (LUID, error) {
// IPAddress method returns MibUnicastIPAddressRow struct that matches to provided 'ip' argument. Corresponds to GetUnicastIpAddressEntry
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getunicastipaddressentry)
-func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
+func (luid LUID) IPAddress(addr netip.Addr) (*MibUnicastIPAddressRow, error) {
row := &MibUnicastIPAddressRow{InterfaceLUID: luid}
- err := row.Address.SetIP(ip, 0)
+ err := row.Address.SetAddr(addr)
if err != nil {
return nil, err
}
@@ -94,22 +94,24 @@ func (luid LUID) IPAddress(ip net.IP) (*MibUnicastIPAddressRow, error) {
// AddIPAddress method adds new unicast IP address to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddress(address net.IPNet) error {
+func (luid LUID) AddIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
- err := row.Address.SetIP(address.IP, 0)
+ row.DadState = DadStatePreferred
+ row.ValidLifetime = 0xffffffff
+ row.PreferredLifetime = 0xffffffff
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Create()
}
// AddIPAddresses method adds multiple new unicast IP addresses to the interface. Corresponds to CreateUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createunicastipaddressentry).
-func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) AddIPAddresses(addresses []netip.Prefix) error {
for i := range addresses {
err := luid.AddIPAddress(addresses[i])
if err != nil {
@@ -120,7 +122,7 @@ func (luid LUID) AddIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddresses method sets new unicast IP addresses to the interface.
-func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
+func (luid LUID) SetIPAddresses(addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(windows.AF_UNSPEC)
if err != nil {
return err
@@ -129,16 +131,15 @@ func (luid LUID) SetIPAddresses(addresses []net.IPNet) error {
}
// SetIPAddressesForFamily method sets new unicast IP addresses for a specific family to the interface.
-func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.IPNet) error {
+func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []netip.Prefix) error {
err := luid.FlushIPAddresses(family)
if err != nil {
return err
}
for i := range addresses {
- asV4 := addresses[i].IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !addresses[i].Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !addresses[i].Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddIPAddress(addresses[i])
@@ -151,17 +152,16 @@ func (luid LUID) SetIPAddressesForFamily(family AddressFamily, addresses []net.I
// DeleteIPAddress method deletes interface's unicast IP address. Corresponds to DeleteUnicastIpAddressEntry function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteunicastipaddressentry).
-func (luid LUID) DeleteIPAddress(address net.IPNet) error {
+func (luid LUID) DeleteIPAddress(address netip.Prefix) error {
row := &MibUnicastIPAddressRow{}
row.Init()
row.InterfaceLUID = luid
- err := row.Address.SetIP(address.IP, 0)
+ err := row.Address.SetAddr(address.Addr())
if err != nil {
return err
}
// Note: OnLinkPrefixLength member is ignored by DeleteUnicastIpAddressEntry().
- ones, _ := address.Mask.Size()
- row.OnLinkPrefixLength = uint8(ones)
+ row.OnLinkPrefixLength = uint8(address.Bits())
return row.Delete()
}
@@ -185,15 +185,17 @@ func (luid LUID) FlushIPAddresses(family AddressFamily) error {
// Route method returns route determined with the input arguments. Corresponds to GetIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-getipforwardentry2).
// NOTE: If the corresponding route isn't found, the method will return error.
-func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2, error) {
+func (luid LUID) Route(destination netip.Prefix, nextHop netip.Addr) (*MibIPforwardRow2, error) {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ row.ValidLifetime = 0xffffffff
+ row.PreferredLifetime = 0xffffffff
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return nil, err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return nil, err
}
@@ -207,15 +209,15 @@ func (luid LUID) Route(destination net.IPNet, nextHop net.IP) (*MibIPforwardRow2
// AddRoute method adds a route to the interface. Corresponds to CreateIpForwardEntry2 function, with added splitDefault feature.
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-createipforwardentry2)
-func (luid LUID) AddRoute(destination net.IPNet, nextHop net.IP, metric uint32) error {
+func (luid LUID) AddRoute(destination netip.Prefix, nextHop netip.Addr, metric uint32) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -250,10 +252,9 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
return err
}
for _, rd := range routesData {
- asV4 := rd.Destination.IP.To4()
- if asV4 == nil && family == windows.AF_INET {
+ if !rd.Destination.Addr().Is4() && family == windows.AF_INET {
continue
- } else if asV4 != nil && family == windows.AF_INET6 {
+ } else if !rd.Destination.Addr().Is6() && family == windows.AF_INET6 {
continue
}
err := luid.AddRoute(rd.Destination, rd.NextHop, rd.Metric)
@@ -266,15 +267,15 @@ func (luid LUID) SetRoutesForFamily(family AddressFamily, routesData []*RouteDat
// DeleteRoute method deletes a route that matches the criteria. Corresponds to DeleteIpForwardEntry2 function
// (https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/nf-netioapi-deleteipforwardentry2).
-func (luid LUID) DeleteRoute(destination net.IPNet, nextHop net.IP) error {
+func (luid LUID) DeleteRoute(destination netip.Prefix, nextHop netip.Addr) error {
row := &MibIPforwardRow2{}
row.Init()
row.InterfaceLUID = luid
- err := row.DestinationPrefix.SetIPNet(destination)
+ err := row.DestinationPrefix.SetPrefix(destination)
if err != nil {
return err
}
- err = row.NextHop.SetIP(nextHop, 0)
+ err = row.NextHop.SetAddr(nextHop)
if err != nil {
return err
}
@@ -307,17 +308,19 @@ func (luid LUID) FlushRoutes(family AddressFamily) error {
}
// DNS method returns all DNS server addresses associated with the adapter.
-func (luid LUID) DNS() ([]net.IP, error) {
+func (luid LUID) DNS() ([]netip.Addr, error) {
addresses, err := GetAdaptersAddresses(windows.AF_UNSPEC, GAAFlagDefault)
if err != nil {
return nil, err
}
- r := make([]net.IP, 0, len(addresses))
+ r := make([]netip.Addr, 0, len(addresses))
for _, addr := range addresses {
if addr.LUID == luid {
for dns := addr.FirstDNSServerAddress; dns != nil; dns = dns.Next {
if ip := dns.Address.IP(); ip != nil {
- r = append(r, ip)
+ if a, ok := netip.AddrFromSlice(ip); ok {
+ r = append(r, a)
+ }
} else {
return nil, windows.ERROR_INVALID_PARAMETER
}
@@ -328,17 +331,15 @@ func (luid LUID) DNS() ([]net.IP, error) {
}
// SetDNS method clears previous and associates new DNS servers and search domains with the adapter for a specific family.
-func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string) error {
+func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []string) error {
if family != windows.AF_INET && family != windows.AF_INET6 {
return windows.ERROR_PROTOCOL_UNREACHABLE
}
var filteredServers []string
for _, server := range servers {
- if v4 := server.To4(); v4 != nil && family == windows.AF_INET {
- filteredServers = append(filteredServers, v4.String())
- } else if v6 := server.To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
- filteredServers = append(filteredServers, v6.String())
+ if (server.Is4() && family == windows.AF_INET) || (server.Is6() && family == windows.AF_INET6) {
+ filteredServers = append(filteredServers, server.String())
}
}
servers16, err := windows.UTF16PtrFromString(strings.Join(filteredServers, ","))
@@ -353,17 +354,17 @@ func (luid LUID) SetDNS(family AddressFamily, servers []net.IP, domains []string
if err != nil {
return err
}
- var maybeV6 uint64
+ dnsInterfaceSettings := &DnsInterfaceSettings{
+ Version: DnsInterfaceSettingsVersion1,
+ Flags: DnsInterfaceSettingsFlagNameserver | DnsInterfaceSettingsFlagSearchList,
+ NameServer: servers16,
+ SearchList: domains16,
+ }
if family == windows.AF_INET6 {
- maybeV6 = disFlagsIPv6
+ dnsInterfaceSettings.Flags |= DnsInterfaceSettingsFlagIPv6
}
// For >= Windows 10 1809
- err = setInterfaceDnsSettings(*guid, &dnsInterfaceSettings{
- Version: disVersion1,
- Flags: disFlagsNameServer | disFlagsSearchList | maybeV6,
- NameServer: servers16,
- SearchList: domains16,
- })
+ err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
diff --git a/tunnel/winipcfg/mksyscall.go b/tunnel/winipcfg/mksyscall.go
index e9e06676..d62d38df 100644
--- a/tunnel/winipcfg/mksyscall.go
+++ b/tunnel/winipcfg/mksyscall.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/netsh.go b/tunnel/winipcfg/netsh.go
index 1f3d12d0..4f8e5b13 100644
--- a/tunnel/winipcfg/netsh.go
+++ b/tunnel/winipcfg/netsh.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
@@ -10,7 +10,7 @@ import (
"errors"
"fmt"
"io"
- "net"
+ "net/netip"
"os/exec"
"path/filepath"
"strings"
@@ -57,7 +57,7 @@ const (
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
)
-func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) error {
+func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
var templateFlush string
if family == windows.AF_INET {
templateFlush = netshCmdTemplateFlush4
@@ -72,10 +72,10 @@ func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []net.IP) e
}
cmds = append(cmds, fmt.Sprintf(templateFlush, ipif.InterfaceIndex))
for i := 0; i < len(dnses); i++ {
- if v4 := dnses[i].To4(); v4 != nil && family == windows.AF_INET {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, v4.String()))
- } else if v6 := dnses[i].To16(); v4 == nil && v6 != nil && family == windows.AF_INET6 {
- cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, v6.String()))
+ if dnses[i].Is4() && family == windows.AF_INET {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd4, ipif.InterfaceIndex, dnses[i].String()))
+ } else if dnses[i].Is6() && family == windows.AF_INET6 {
+ cmds = append(cmds, fmt.Sprintf(netshCmdTemplateAdd6, ipif.InterfaceIndex, dnses[i].String()))
}
}
return runNetsh(cmds)
diff --git a/tunnel/winipcfg/route_change_handler.go b/tunnel/winipcfg/route_change_handler.go
index 75dcf82c..4b78331e 100644
--- a/tunnel/winipcfg/route_change_handler.go
+++ b/tunnel/winipcfg/route_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types.go b/tunnel/winipcfg/types.go
index 02f7f788..8e8f4a59 100644
--- a/tunnel/winipcfg/types.go
+++ b/tunnel/winipcfg/types.go
@@ -1,12 +1,15 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
import (
- "net"
+ "encoding/binary"
+ "fmt"
+ "net/netip"
+ "strconv"
"unsafe"
"golang.org/x/sys/windows"
@@ -582,11 +585,15 @@ const (
// RouteData structure describes a route to add
type RouteData struct {
- Destination net.IPNet
- NextHop net.IP
+ Destination netip.Prefix
+ NextHop netip.Addr
Metric uint32
}
+func (routeData *RouteData) String() string {
+ return fmt.Sprintf("%+v", *routeData)
+}
+
// IPAdapterDNSSuffix structure stores a DNS suffix in a linked list of DNS suffixes for a particular adapter.
// https://docs.microsoft.com/en-us/windows/desktop/api/iptypes/ns-iptypes-_ip_adapter_dns_suffix
type IPAdapterDNSSuffix struct {
@@ -679,8 +686,7 @@ func (row *MibIPInterfaceRow) Set() error {
// get method returns all table rows as a Go slice.
func (tab *mibIPInterfaceTable) get() (s []MibIPInterfaceRow) {
- unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
- return
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -717,8 +723,7 @@ func (row *MibIfRow2) get() (ret error) {
// get method returns all table rows as a Go slice.
func (tab *mibIfTable2) get() (s []MibIfRow2) {
- unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
- return
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -734,45 +739,82 @@ type RawSockaddrInet struct {
data [26]byte
}
-// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port.
+func ntohs(i uint16) uint16 {
+ return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&i))[:])
+}
+
+func htons(i uint16) uint16 {
+ b := make([]byte, 2)
+ binary.BigEndian.PutUint16(b, i)
+ return *(*uint16)(unsafe.Pointer(&b[0]))
+}
+
+// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port.
// All other members of the structure are set to zero.
-func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error {
- if v4 := ip.To4(); v4 != nil {
+func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error {
+ if addrPort.Addr().Is4() {
addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr))
addr4.Family = windows.AF_INET
- copy(addr4.Addr[:], v4)
- addr4.Port = port
+ addr4.Addr = addrPort.Addr().As4()
+ addr4.Port = htons(addrPort.Port())
for i := 0; i < 8; i++ {
addr4.Zero[i] = 0
}
return nil
- }
-
- if v6 := ip.To16(); v6 != nil {
+ } else if addrPort.Addr().Is6() {
addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
addr6.Family = windows.AF_INET6
- addr6.Port = port
+ addr6.Addr = addrPort.Addr().As16()
+ addr6.Port = htons(addrPort.Port())
addr6.Flowinfo = 0
- copy(addr6.Addr[:], v6)
- addr6.Scope_id = 0
+ scopeId := uint32(0)
+ if z := addrPort.Addr().Zone(); z != "" {
+ if s, err := strconv.ParseUint(z, 10, 32); err == nil {
+ scopeId = uint32(s)
+ }
+ }
+ addr6.Scope_id = scopeId
return nil
}
-
return windows.ERROR_INVALID_PARAMETER
}
-// IP method returns IPv4 or IPv6 address.
-// If the address is neither IPv4 not IPv6 nil is returned.
-func (addr *RawSockaddrInet) IP() net.IP {
+// SetAddr method sets family and address to the given IPv4 or IPv6 address.
+// All other members of the structure are set to zero.
+func (addr *RawSockaddrInet) SetAddr(netAddr netip.Addr) error {
+ return addr.SetAddrPort(netip.AddrPortFrom(netAddr, 0))
+}
+
+// AddrPort returns the IP address and port.
+func (addr *RawSockaddrInet) AddrPort() netip.AddrPort {
+ return netip.AddrPortFrom(addr.Addr(), addr.Port())
+}
+
+// Addr returns IPv4 or IPv6 address, or an invalid address if the address is neither.
+func (addr *RawSockaddrInet) Addr() netip.Addr {
switch addr.Family {
case windows.AF_INET:
- return (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr[:]
-
+ return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Addr)
case windows.AF_INET6:
- return (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Addr[:]
+ raw := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr))
+ a := netip.AddrFrom16(raw.Addr)
+ if raw.Scope_id != 0 {
+ a = a.WithZone(strconv.FormatUint(uint64(raw.Scope_id), 10))
+ }
+ return a
}
+ return netip.Addr{}
+}
- return nil
+// Port returns the port if the address if IPv4 or IPv6, or 0 if neither.
+func (addr *RawSockaddrInet) Port() uint16 {
+ switch addr.Family {
+ case windows.AF_INET:
+ return ntohs((*windows.RawSockaddrInet4)(unsafe.Pointer(addr)).Port)
+ case windows.AF_INET6:
+ return ntohs((*windows.RawSockaddrInet6)(unsafe.Pointer(addr)).Port)
+ }
+ return 0
}
// Init method initializes a MibUnicastIPAddressRow structure with default values for a unicast IP address entry on the local computer.
@@ -807,8 +849,7 @@ func (row *MibUnicastIPAddressRow) Delete() error {
// get method returns all table rows as a Go slice.
func (tab *mibUnicastIPAddressTable) get() (s []MibUnicastIPAddressRow) {
- unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
- return
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -837,8 +878,7 @@ func (row *MibAnycastIPAddressRow) Delete() error {
// get method returns all table rows as a Go slice.
func (tab *mibAnycastIPAddressTable) get() (s []MibAnycastIPAddressRow) {
- unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
- return
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -850,32 +890,30 @@ func (tab *mibAnycastIPAddressTable) free() {
// IPAddressPrefix structure stores an IP address prefix.
// https://docs.microsoft.com/en-us/windows/desktop/api/netioapi/ns-netioapi-_ip_address_prefix
type IPAddressPrefix struct {
- Prefix RawSockaddrInet
+ RawPrefix RawSockaddrInet
PrefixLength uint8
_ [2]byte
}
-// SetIPNet method sets IP address prefix using net.IPNet.
-func (prefix *IPAddressPrefix) SetIPNet(net net.IPNet) error {
- err := prefix.Prefix.SetIP(net.IP, 0)
+// SetPrefix method sets IP address prefix using netip.Prefix.
+func (prefix *IPAddressPrefix) SetPrefix(netPrefix netip.Prefix) error {
+ err := prefix.RawPrefix.SetAddr(netPrefix.Addr())
if err != nil {
return err
}
- ones, _ := net.Mask.Size()
- prefix.PrefixLength = uint8(ones)
+ prefix.PrefixLength = uint8(netPrefix.Bits())
return nil
}
-// IPNet method returns IP address prefix as net.IPNet.
-// If the address is neither IPv4 not IPv6 an empty net.IPNet is returned. The resulting net.IPNet should be checked appropriately.
-func (prefix *IPAddressPrefix) IPNet() net.IPNet {
- switch prefix.Prefix.Family {
+// Prefix returns IP address prefix as netip.Prefix.
+func (prefix *IPAddressPrefix) Prefix() netip.Prefix {
+ switch prefix.RawPrefix.Family {
case windows.AF_INET:
- return net.IPNet{IP: (*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv4len)}
+ return netip.PrefixFrom(netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
case windows.AF_INET6:
- return net.IPNet{IP: (*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.Prefix)).Addr[:], Mask: net.CIDRMask(int(prefix.PrefixLength), 8*net.IPv6len)}
+ return netip.PrefixFrom(netip.AddrFrom16((*windows.RawSockaddrInet6)(unsafe.Pointer(&prefix.RawPrefix)).Addr), int(prefix.PrefixLength))
}
- return net.IPNet{}
+ return netip.Prefix{}
}
// MibIPforwardRow2 structure stores information about an IP route entry.
@@ -930,8 +968,7 @@ func (row *MibIPforwardRow2) Delete() error {
// get method returns all table rows as a Go slice.
func (tab *mibIPforwardTable2) get() (s []MibIPforwardRow2) {
- unsafeSlice(unsafe.Pointer(&s), unsafe.Pointer(&tab.table[0]), int(tab.numEntries))
- return
+ return unsafe.Slice(&tab.table[0], tab.numEntries)
}
// free method frees the buffer allocated by the functions that return tables of network interfaces, addresses, and routes.
@@ -941,11 +978,11 @@ func (tab *mibIPforwardTable2) free() {
}
//
-// Undocumented DNS API
+// DNS API
//
-// dnsInterfaceSettings is mean to be used with setInterfaceDnsSettings
-type dnsInterfaceSettings struct {
+// DnsInterfaceSettings is meant to be used with SetInterfaceDnsSettings
+type DnsInterfaceSettings struct {
Version uint32
_ [4]byte
Flags uint64
@@ -960,37 +997,22 @@ type dnsInterfaceSettings struct {
}
const (
- disVersion1 = 1
- disVersion2 = 2
-
- disFlagsIPv6 = 0x1
- disFlagsNameServer = 0x2
- disFlagsSearchList = 0x4
- disFlagsRegistrationEnabled = 0x8
- disFlagsRegisterAdapterName = 0x10
- disFlagsDomain = 0x20
- disFlagsHostname = 0x40 // ??
- disFlagsEnableLLMNR = 0x80
- disFlagsQueryAdapterName = 0x100
- disFlagsProfileNameServer = 0x200
- disFlagsVersion2 = 0x400 // ?? - v2 only
- disFlagsMoreFlags = 0x800 // ?? - v2 only
+ DnsInterfaceSettingsVersion1 = 1 // for DnsInterfaceSettings
+ DnsInterfaceSettingsVersion2 = 2 // for DnsInterfaceSettingsEx
+ DnsInterfaceSettingsVersion3 = 3 // for DnsInterfaceSettings3
+
+ DnsInterfaceSettingsFlagIPv6 = 0x0001
+ DnsInterfaceSettingsFlagNameserver = 0x0002
+ DnsInterfaceSettingsFlagSearchList = 0x0004
+ DnsInterfaceSettingsFlagRegistrationEnabled = 0x0008
+ DnsInterfaceSettingsFlagRegisterAdapterName = 0x0010
+ DnsInterfaceSettingsFlagDomain = 0x0020
+ DnsInterfaceSettingsFlagHostname = 0x0040
+ DnsInterfaceSettingsFlagEnableLLMNR = 0x0080
+ DnsInterfaceSettingsFlagQueryAdapterName = 0x0100
+ DnsInterfaceSettingsFlagProfileNameserver = 0x0200
+ DnsInterfaceSettingsFlagDisableUnconstrainedQueries = 0x0400 // v2 only
+ DnsInterfaceSettingsFlagSupplementalSearchList = 0x0800 // v2 only
+ DnsInterfaceSettingsFlagDOH = 0x1000 // v3 only
+ DnsInterfaceSettingsFlagDOHProfile = 0x2000 // v3 only
)
-
-// unsafeSlice updates the slice slicePtr to be a slice
-// referencing the provided data with its length & capacity set to
-// lenCap.
-//
-// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
-// update callers to use unsafe.Slice instead of this.
-func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
- type sliceHeader struct {
- Data unsafe.Pointer
- Len int
- Cap int
- }
- h := (*sliceHeader)(slicePtr)
- h.Data = data
- h.Len = lenCap
- h.Cap = lenCap
-}
diff --git a/tunnel/winipcfg/types_32.go b/tunnel/winipcfg/types_32.go
index 51a8d31c..1a8d4443 100644
--- a/tunnel/winipcfg/types_32.go
+++ b/tunnel/winipcfg/types_32.go
@@ -1,8 +1,8 @@
-// +build 386 arm
+//go:build 386 || arm
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_64.go b/tunnel/winipcfg/types_64.go
index 6623ce54..3a1fe07f 100644
--- a/tunnel/winipcfg/types_64.go
+++ b/tunnel/winipcfg/types_64.go
@@ -1,8 +1,8 @@
-// +build amd64 arm64
+//go:build amd64 || arm64
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_test.go b/tunnel/winipcfg/types_test.go
index 26268dbe..b72d73f5 100644
--- a/tunnel/winipcfg/types_test.go
+++ b/tunnel/winipcfg/types_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
@@ -957,7 +957,6 @@ func TestIPAddressPrefix(t *testing.T) {
offset := uintptr(unsafe.Pointer(&s.PrefixLength)) - sp
if offset != ipAddressPrefixPrefixLengthOffset {
t.Errorf("IPAddressPrefix.PrefixLength offset is %d although %d is expected", offset, ipAddressPrefixPrefixLengthOffset)
-
}
}
diff --git a/tunnel/winipcfg/types_test_32.go b/tunnel/winipcfg/types_test_32.go
index fa3f91d5..9e62bfef 100644
--- a/tunnel/winipcfg/types_test_32.go
+++ b/tunnel/winipcfg/types_test_32.go
@@ -1,8 +1,8 @@
-// +build 386 arm
+//go:build 386 || arm
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/types_test_64.go b/tunnel/winipcfg/types_test_64.go
index d20cdb30..8a181575 100644
--- a/tunnel/winipcfg/types_test_64.go
+++ b/tunnel/winipcfg/types_test_64.go
@@ -1,8 +1,8 @@
-// +build amd64 arm64
+//go:build amd64 || arm64
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/unicast_address_change_handler.go b/tunnel/winipcfg/unicast_address_change_handler.go
index 517ff037..cf4fcb3a 100644
--- a/tunnel/winipcfg/unicast_address_change_handler.go
+++ b/tunnel/winipcfg/unicast_address_change_handler.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
diff --git a/tunnel/winipcfg/winipcfg.go b/tunnel/winipcfg/winipcfg.go
index dd09d3a0..e24157b9 100644
--- a/tunnel/winipcfg/winipcfg.go
+++ b/tunnel/winipcfg/winipcfg.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
package winipcfg
@@ -170,18 +170,18 @@ func GetIPForwardTable2(family AddressFamily) ([]MibIPforwardRow2, error) {
//sys cancelMibChangeNotify2(notificationHandle windows.Handle) (ret error) = iphlpapi.CancelMibChangeNotify2
//
-// Undocumented DNS API
+// DNS-related functions
//
-//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
-//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
-//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+//sys setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+//sys setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
+//sys setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) = iphlpapi.SetInterfaceDnsSettings?
// The GUID is passed by value, not by reference, which means different
// things on different calling conventions. On amd64, this means it's
// passed by reference anyway, while on arm, arm64, and 386, it's split
// into words.
-func setInterfaceDnsSettings(guid windows.GUID, settings *dnsInterfaceSettings) error {
+func SetInterfaceDnsSettings(guid windows.GUID, settings *DnsInterfaceSettings) error {
words := (*[4]uintptr)(unsafe.Pointer(&guid))
switch runtime.GOARCH {
case "amd64":
diff --git a/tunnel/winipcfg/winipcfg_test.go b/tunnel/winipcfg/winipcfg_test.go
index 7689a0c1..b49daf33 100644
--- a/tunnel/winipcfg/winipcfg_test.go
+++ b/tunnel/winipcfg/winipcfg_test.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
*/
/*
@@ -8,8 +8,8 @@
Some tests in this file require:
- A dedicated network adapter
- Any network adapter will do. It may be virtual (Wintun etc.). The adapter name
- must contain string "winipcfg_test".
+ Any network adapter will do. It may be virtual (WireGuardNT, Wintun,
+ etc.). The adapter name must contain string "winipcfg_test".
Tests will add, remove, flush DNS servers, change adapter IP address, manipulate
routes etc.
The adapter will not be returned to previous state, so use an expendable one.
@@ -22,8 +22,7 @@ Some tests in this file require:
package winipcfg
import (
- "bytes"
- "net"
+ "net/netip"
"strings"
"syscall"
"testing"
@@ -38,22 +37,13 @@ const (
// TODO: Add IPv6 tests.
var (
- unexistentIPAddresToAdd = net.IPNet{
- IP: net.IP{172, 16, 1, 114},
- Mask: net.IPMask{255, 255, 255, 0},
- }
- unexistentRouteIPv4ToAdd = RouteData{
- Destination: net.IPNet{
- IP: net.IP{172, 16, 200, 0},
- Mask: net.IPMask{255, 255, 255, 0},
- },
- NextHop: net.IP{172, 16, 1, 2},
- Metric: 0,
- }
- dnsesToSet = []net.IP{
- net.IPv4(8, 8, 8, 8),
- net.IPv4(8, 8, 4, 4),
+ nonexistantIPv4ToAdd = netip.MustParsePrefix("172.16.1.114/24")
+ nonexistentRouteIPv4ToAdd = RouteData{
+ Destination: netip.MustParsePrefix("172.16.200.0/24"),
+ NextHop: netip.MustParseAddr("172.16.1.2"),
+ Metric: 0,
}
+ dnsesToSet = []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}
)
func runningElevated() bool {
@@ -74,7 +64,7 @@ func getTestInterface() (*IPAdapterAddresses, error) {
marker := strings.ToLower(testInterfaceMarker)
for _, ifc := range ifcs {
- if strings.Index(strings.ToLower(ifc.FriendlyName()), marker) != -1 {
+ if strings.Contains(strings.ToLower(ifc.FriendlyName()), marker) {
return ifc, nil
}
}
@@ -380,9 +370,9 @@ func TestAddDeleteIPAddress(t *testing.T) {
return
}
- addr, err := ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err := ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s already exists. Please set unexistentIPAddresToAdd appropriately.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s already exists. Please set nonexistantIPv4ToAdd appropriately.", nonexistantIPv4ToAdd.Addr().String())
return
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
@@ -410,7 +400,7 @@ func TestAddDeleteIPAddress(t *testing.T) {
for addr := ifc.FirstUnicastAddress; addr != nil; addr = addr.Next {
count--
}
- err = ifc.LUID.AddIPAddresses([]net.IPNet{unexistentIPAddresToAdd})
+ err = ifc.LUID.AddIPAddresses([]netip.Prefix{nonexistantIPv4ToAdd})
if err != nil {
t.Errorf("LUID.AddIPAddresses() returned an error: %w", err)
}
@@ -424,26 +414,26 @@ func TestAddDeleteIPAddress(t *testing.T) {
if count != 1 {
t.Errorf("After adding there are %d new interface(s).", count)
}
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err != nil {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
} else if addr == nil {
- t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still doesn't exist, although it's added successfully.", nonexistantIPv4ToAdd.Addr().String())
}
if !created {
t.Errorf("Notification handler has not been called on add.")
}
- err = ifc.LUID.DeleteIPAddress(unexistentIPAddresToAdd)
+ err = ifc.LUID.DeleteIPAddress(nonexistantIPv4ToAdd)
if err != nil {
t.Errorf("LUID.DeleteIPAddress() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- addr, err = ifc.LUID.IPAddress(unexistentIPAddresToAdd.IP)
+ addr, err = ifc.LUID.IPAddress(nonexistantIPv4ToAdd.Addr())
if err == nil {
- t.Errorf("Unicast address %s still exists, although it's deleted successfully.", unexistentIPAddresToAdd.IP.String())
+ t.Errorf("Unicast address %s still exists, although it's deleted successfully.", nonexistantIPv4ToAdd.Addr().String())
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.IPAddress() returned an error: %w", err)
}
@@ -460,14 +450,13 @@ func TestGetRoutes(t *testing.T) {
}
func TestAddDeleteRoute(t *testing.T) {
- findRoute := func(luid LUID, dest net.IPNet) ([]MibIPforwardRow2, error) {
+ findRoute := func(luid LUID, dest netip.Prefix) ([]MibIPforwardRow2, error) {
var family AddressFamily
- switch {
- case dest.IP.To4() != nil:
+ if dest.Addr().Is4() {
family = windows.AF_INET
- case dest.IP.To16() != nil:
+ } else if dest.Addr().Is6() {
family = windows.AF_INET6
- default:
+ } else {
return nil, windows.ERROR_INVALID_PARAMETER
}
r, err := GetIPForwardTable2(family)
@@ -475,9 +464,8 @@ func TestAddDeleteRoute(t *testing.T) {
return nil, err
}
matches := make([]MibIPforwardRow2, 0, len(r))
- ones, _ := dest.Mask.Size()
for _, route := range r {
- if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(ones) && route.DestinationPrefix.Prefix.Family == family && route.DestinationPrefix.Prefix.IP().Equal(dest.IP) {
+ if route.InterfaceLUID == luid && route.DestinationPrefix.PrefixLength == uint8(dest.Bits()) && route.DestinationPrefix.RawPrefix.Family == family && route.DestinationPrefix.RawPrefix.Addr() == dest.Addr() {
matches = append(matches, route)
}
}
@@ -494,20 +482,20 @@ func TestAddDeleteRoute(t *testing.T) {
return
}
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
- t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?")
+ t.Error("LUID.Route() returned a route although it isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?")
return
} else if err != windows.ERROR_NOT_FOUND {
t.Errorf("LUID.Route() returned an error: %w", err)
return
}
- routes, err := findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err := findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
- t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set unexistentRouteIPv4ToAdd appropriately?", len(routes))
+ t.Errorf("findRoute() returned %d items although the route isn't added yet. Have you forgot to set nonexistentRouteIPv4ToAdd appropriately?", len(routes))
}
var created, deleted bool
@@ -524,42 +512,42 @@ func TestAddDeleteRoute(t *testing.T) {
} else {
defer cb.Unregister()
}
- err = ifc.LUID.AddRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop, unexistentRouteIPv4ToAdd.Metric)
+ err = ifc.LUID.AddRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop, nonexistentRouteIPv4ToAdd.Metric)
if err != nil {
t.Errorf("LUID.AddRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- route, err := ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ route, err := ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == windows.ERROR_NOT_FOUND {
t.Error("LUID.Route() returned nil although the route is added successfully.")
} else if err != nil {
t.Errorf("LUID.Route() returned an error: %w", err)
- } else if !route.DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) || !route.NextHop.IP().Equal(unexistentRouteIPv4ToAdd.NextHop) {
+ } else if route.DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() || route.NextHop.Addr() != nonexistentRouteIPv4ToAdd.NextHop {
t.Error("LUID.Route() returned a wrong route!")
}
if !created {
t.Errorf("Route handler has not been called on add.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 1 {
t.Errorf("findRoute() returned %d items although %d is expected.", len(routes), 1)
- } else if !routes[0].DestinationPrefix.Prefix.IP().Equal(unexistentRouteIPv4ToAdd.Destination.IP) {
- t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.Prefix.IP().String(), unexistentRouteIPv4ToAdd.Destination.IP.String())
+ } else if routes[0].DestinationPrefix.RawPrefix.Addr() != nonexistentRouteIPv4ToAdd.Destination.Addr() {
+ t.Errorf("findRoute() returned a wrong route. Dest: %s; expected: %s.", routes[0].DestinationPrefix.RawPrefix.Addr().String(), nonexistentRouteIPv4ToAdd.Destination.Addr().String())
}
- err = ifc.LUID.DeleteRoute(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ err = ifc.LUID.DeleteRoute(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err != nil {
t.Errorf("LUID.DeleteRoute() returned an error: %w", err)
}
time.Sleep(500 * time.Millisecond)
- _, err = ifc.LUID.Route(unexistentRouteIPv4ToAdd.Destination, unexistentRouteIPv4ToAdd.NextHop)
+ _, err = ifc.LUID.Route(nonexistentRouteIPv4ToAdd.Destination, nonexistentRouteIPv4ToAdd.NextHop)
if err == nil {
t.Error("LUID.Route() returned a route although it is removed successfully.")
} else if err != windows.ERROR_NOT_FOUND {
@@ -569,7 +557,7 @@ func TestAddDeleteRoute(t *testing.T) {
t.Errorf("Route handler has not been called on delete.")
}
- routes, err = findRoute(ifc.LUID, unexistentRouteIPv4ToAdd.Destination)
+ routes, err = findRoute(ifc.LUID, nonexistentRouteIPv4ToAdd.Destination)
if err != nil {
t.Errorf("findRoute() returned an error: %w", err)
} else if len(routes) != 0 {
@@ -606,7 +594,7 @@ func TestFlushDNS(t *testing.T) {
t.Errorf("LUID.DNS() returned an error: %w", err)
}
for _, a := range dns {
- if len(a) != 16 || a.To4() != nil || !((a[15] == 1 || a[15] == 2 || a[15] == 3) && bytes.HasPrefix(a, []byte{0xfe, 0xc0, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})) {
+ if a.Is4() {
n++
}
}
@@ -651,7 +639,7 @@ func TestSetDNS(t *testing.T) {
t.Errorf("dnsesToSet contains %d items, while DNSServerAddresses contains %d.", len(dnsesToSet), len(newDNSes))
} else {
for i := range dnsesToSet {
- if !dnsesToSet[i].Equal(newDNSes[i]) {
+ if dnsesToSet[i] != newDNSes[i] {
t.Errorf("dnsesToSet[%d] = %s while DNSServerAddresses[%d] = %s.", i, dnsesToSet[i].String(), i, newDNSes[i].String())
}
}
diff --git a/tunnel/winipcfg/zwinipcfg_windows.go b/tunnel/winipcfg/zwinipcfg_windows.go
index ac89fec1..3a0d8680 100644
--- a/tunnel/winipcfg/zwinipcfg_windows.go
+++ b/tunnel/winipcfg/zwinipcfg_windows.go
@@ -289,7 +289,7 @@ func notifyUnicastIPAddressChange(family AddressFamily, callback uintptr, caller
return
}
-func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *dnsInterfaceSettings) (ret error) {
+func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr, guid4 uintptr, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
@@ -301,24 +301,24 @@ func setInterfaceDnsSettingsByDwords(guid1 uintptr, guid2 uintptr, guid3 uintptr
return
}
-func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *dnsInterfaceSettings) (ret error) {
+func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
- r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0)
+ r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings)))
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
-func setInterfaceDnsSettingsByQwords(guid1 uintptr, guid2 uintptr, settings *dnsInterfaceSettings) (ret error) {
+func setInterfaceDnsSettingsByPtr(guid *windows.GUID, settings *DnsInterfaceSettings) (ret error) {
ret = procSetInterfaceDnsSettings.Find()
if ret != nil {
return
}
- r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 3, uintptr(guid1), uintptr(guid2), uintptr(unsafe.Pointer(settings)))
+ r0, _, _ := syscall.Syscall(procSetInterfaceDnsSettings.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(settings)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
diff --git a/tunnel/wintun_test.go b/tunnel/wintun_test.go
deleted file mode 100644
index 4e56ff65..00000000
--- a/tunnel/wintun_test.go
+++ /dev/null
@@ -1,202 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
- */
-
-package tunnel_test
-
-import (
- "bytes"
- "crypto/rand"
- "encoding/binary"
- "fmt"
- "net"
- "sync"
- "testing"
- "time"
-
- "golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun"
-
- "golang.zx2c4.com/wireguard/windows/elevate"
- "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
-)
-
-func TestWintunOrdering(t *testing.T) {
- var tunDevice tun.Device
- err := elevate.DoAsSystem(func() error {
- var err error
- tunDevice, err = tun.CreateTUNWithRequestedGUID("tunordertest", &windows.GUID{12, 12, 12, [8]byte{12, 12, 12, 12, 12, 12, 12, 12}}, 1500)
- return err
- })
- if err != nil {
- t.Fatal(err)
- }
- defer tunDevice.Close()
- nativeTunDevice := tunDevice.(*tun.NativeTun)
- luid := winipcfg.LUID(nativeTunDevice.LUID())
- ip, ipnet, _ := net.ParseCIDR("10.82.31.4/24")
- err = luid.SetIPAddresses([]net.IPNet{{ip, ipnet.Mask}})
- if err != nil {
- t.Fatal(err)
- }
- err = luid.SetRoutes([]*winipcfg.RouteData{{*ipnet, ipnet.IP, 0}})
- if err != nil {
- t.Fatal(err)
- }
- var token [32]byte
- _, err = rand.Read(token[:])
- if err != nil {
- t.Fatal(err)
- }
- var sockWrite net.Conn
- for i := 0; i < 1000; i++ {
- sockWrite, err = net.Dial("udp", "10.82.31.5:9999")
- if err == nil {
- defer sockWrite.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var sockRead *net.UDPConn
- for i := 0; i < 1000; i++ {
- var listenAddress *net.UDPAddr
- listenAddress, err = net.ResolveUDPAddr("udp", "10.82.31.4:9999")
- if err != nil {
- continue
- }
- sockRead, err = net.ListenUDP("udp", listenAddress)
- if err == nil {
- defer sockRead.Close()
- break
- }
- time.Sleep(time.Millisecond * 100)
- }
- if err != nil {
- t.Fatal(err)
- }
- var wait sync.WaitGroup
- wait.Add(4)
- doneSockWrite := false
- doneTunWrite := false
- fatalErrors := make(chan error, 2)
- errors := make(chan error, 2)
- go func() {
- defer wait.Done()
- buffer := append(token[:], 0, 0, 0, 0, 0, 0, 0, 0)
- for sendingIndex := uint64(0); !doneSockWrite; sendingIndex++ {
- binary.LittleEndian.PutUint64(buffer[32:], sendingIndex)
- _, err := sockWrite.Write(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- }
- }()
- go func() {
- defer wait.Done()
- packet := [20 + 8 + 32 + 8]byte{
- 0x45, 0, 0, 20 + 8 + 32 + 8,
- 0, 0, 0, 0,
- 0x80, 0x11, 0, 0,
- 10, 82, 31, 5,
- 10, 82, 31, 4,
- 8888 >> 8, 8888 & 0xff, 9999 >> 8, 9999 & 0xff, 0, 8 + 32 + 8, 0, 0,
- }
- copy(packet[28:], token[:])
- for sendingIndex := uint64(0); !doneTunWrite; sendingIndex++ {
- binary.BigEndian.PutUint16(packet[4:], uint16(sendingIndex))
- var checksum uint32
- for i := 0; i < 20; i += 2 {
- if i != 10 {
- checksum += uint32(binary.BigEndian.Uint16(packet[i:]))
- }
- }
- binary.BigEndian.PutUint16(packet[10:], ^(uint16(checksum>>16) + uint16(checksum&0xffff)))
- binary.LittleEndian.PutUint64(packet[20+8+32:], sendingIndex)
- n, err := tunDevice.Write(packet[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if n == 0 {
- time.Sleep(time.Millisecond * 300)
- }
- }
- }()
- const packetsPerTest = 1 << 21
- go func() {
- defer func() {
- doneSockWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := tunDevice.Read(buffer[:], 0)
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- tokenPos := bytes.Index(packet, token[:])
- if tokenPos == -1 || tokenPos+32+8 > len(packet) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[tokenPos+32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Sock write, tun read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- go func() {
- defer func() {
- doneTunWrite = true
- wait.Done()
- }()
- var expectedIndex uint64
- for i := uint64(0); i < packetsPerTest; {
- var buffer [(1 << 16) - 1]byte
- bytesRead, err := sockRead.Read(buffer[:])
- if err != nil {
- fatalErrors <- err
- }
- if bytesRead < 0 || bytesRead > len(buffer) {
- continue
- }
- packet := buffer[:bytesRead]
- if len(packet) != 32+8 || !bytes.HasPrefix(packet, token[:]) {
- continue
- }
- foundIndex := binary.LittleEndian.Uint64(packet[32:])
- if foundIndex < expectedIndex {
- errors <- fmt.Errorf("Tun write, sock read: expected packet %d, received packet %d", expectedIndex, foundIndex)
- }
- expectedIndex = foundIndex + 1
- i++
- }
- }()
- done := make(chan bool, 2)
- doneFunc := func() {
- wait.Wait()
- done <- true
- }
- defer doneFunc()
- go doneFunc()
- for {
- select {
- case err := <-fatalErrors:
- t.Fatal(err)
- case err := <-errors:
- t.Error(err)
- case <-done:
- return
- }
- }
-}