aboutsummaryrefslogtreecommitdiffstats
path: root/tun
diff options
context:
space:
mode:
Diffstat (limited to 'tun')
-rw-r--r--tun/alignment_windows_test.go67
-rw-r--r--tun/checksum.go118
-rw-r--r--tun/checksum_test.go35
-rw-r--r--tun/errors.go12
-rw-r--r--tun/netstack/examples/http_client.go54
-rw-r--r--tun/netstack/examples/http_server.go51
-rw-r--r--tun/netstack/examples/ping_client.go75
-rw-r--r--tun/netstack/tun.go1055
-rw-r--r--tun/offload_linux.go993
-rw-r--r--tun/offload_linux_test.go752
-rw-r--r--tun/operateonfd.go4
-rw-r--r--tun/tun.go42
-rw-r--r--tun/tun_darwin.go259
-rw-r--r--tun/tun_freebsd.go398
-rw-r--r--tun/tun_linux.go415
-rw-r--r--tun/tun_openbsd.go117
-rw-r--r--tun/tun_windows.go265
-rw-r--r--tun/tuntest/tuntest.go155
-rw-r--r--tun/wintun/iphlpapi/conversion_windows.go25
-rw-r--r--tun/wintun/iphlpapi/mksyscall.go8
-rw-r--r--tun/wintun/iphlpapi/zsyscall_windows.go60
-rw-r--r--tun/wintun/namespace_windows.go98
-rw-r--r--tun/wintun/namespaceapi/mksyscall.go8
-rw-r--r--tun/wintun/namespaceapi/namespaceapi_windows.go83
-rw-r--r--tun/wintun/namespaceapi/zsyscall_windows.go116
-rw-r--r--tun/wintun/nci/mksyscall.go8
-rw-r--r--tun/wintun/nci/nci_windows.go28
-rw-r--r--tun/wintun/nci/zsyscall_windows.go60
-rw-r--r--tun/wintun/registry/mksyscall.go8
-rw-r--r--tun/wintun/registry/registry_windows.go272
-rw-r--r--tun/wintun/registry/registry_windows_test.go103
-rw-r--r--tun/wintun/registry/zregistry_windows.go63
-rw-r--r--tun/wintun/ring_windows.go97
-rw-r--r--tun/wintun/setupapi/mksyscall.go8
-rw-r--r--tun/wintun/setupapi/setupapi_windows.go506
-rw-r--r--tun/wintun/setupapi/setupapi_windows_test.go488
-rw-r--r--tun/wintun/setupapi/types_windows.go568
-rw-r--r--tun/wintun/setupapi/types_windows_386.go11
-rw-r--r--tun/wintun/setupapi/types_windows_amd64.go11
-rw-r--r--tun/wintun/setupapi/zsetupapi_windows.go398
-rw-r--r--tun/wintun/setupapi/zsetupapi_windows_test.go20
-rw-r--r--tun/wintun/wintun_windows.go803
42 files changed, 4127 insertions, 4590 deletions
diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go
new file mode 100644
index 0000000..67a785e
--- /dev/null
+++ b/tun/alignment_windows_test.go
@@ -0,0 +1,67 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "reflect"
+ "testing"
+ "unsafe"
+)
+
+func checkAlignment(t *testing.T, name string, offset uintptr) {
+ t.Helper()
+ if offset%8 != 0 {
+ t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
+ }
+}
+
+// TestRateJugglerAlignment checks that atomically-accessed fields are
+// aligned to 64-bit boundaries, as required by the atomic package.
+//
+// Unfortunately, violating this rule on 32-bit platforms results in a
+// hard segfault at runtime.
+func TestRateJugglerAlignment(t *testing.T) {
+ var r rateJuggler
+
+ typ := reflect.TypeOf(&r).Elem()
+ t.Logf("Peer type size: %d, with fields:", typ.Size())
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+ t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
+ field.Name,
+ field.Offset,
+ field.Type.Size(),
+ field.Type.Align(),
+ )
+ }
+
+ checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current))
+ checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount))
+ checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime))
+}
+
+// TestNativeTunAlignment checks that atomically-accessed fields are
+// aligned to 64-bit boundaries, as required by the atomic package.
+//
+// Unfortunately, violating this rule on 32-bit platforms results in a
+// hard segfault at runtime.
+func TestNativeTunAlignment(t *testing.T) {
+ var tun NativeTun
+
+ typ := reflect.TypeOf(&tun).Elem()
+ t.Logf("Peer type size: %d, with fields:", typ.Size())
+ for i := 0; i < typ.NumField(); i++ {
+ field := typ.Field(i)
+ t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
+ field.Name,
+ field.Offset,
+ field.Type.Size(),
+ field.Type.Align(),
+ )
+ }
+
+ checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate))
+}
diff --git a/tun/checksum.go b/tun/checksum.go
new file mode 100644
index 0000000..29a8fc8
--- /dev/null
+++ b/tun/checksum.go
@@ -0,0 +1,118 @@
+package tun
+
+import "encoding/binary"
+
+// TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
+func checksumNoFold(b []byte, initial uint64) uint64 {
+ ac := initial
+
+ for len(b) >= 128 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+ ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+ ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+ ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+ ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+ ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+ ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+ ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+ ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+ ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+ ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+ ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+ ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+ ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+ ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+ ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ b = b[128:]
+ }
+ if len(b) >= 64 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ b = b[64:]
+ }
+ if len(b) >= 32 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ b = b[32:]
+ }
+ if len(b) >= 16 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ b = b[16:]
+ }
+ if len(b) >= 8 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ b = b[8:]
+ }
+ if len(b) >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b))
+ b = b[4:]
+ }
+ if len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
+ return ac
+}
+
+func checksum(b []byte, initial uint64) uint16 {
+ ac := checksumNoFold(b, initial)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ ac = (ac >> 16) + (ac & 0xffff)
+ return uint16(ac)
+}
+
+func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
+ sum := checksumNoFold(srcAddr, 0)
+ sum = checksumNoFold(dstAddr, sum)
+ sum = checksumNoFold([]byte{0, protocol}, sum)
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ return checksumNoFold(tmp, sum)
+}
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644
index 0000000..c1ccff5
--- /dev/null
+++ b/tun/checksum_test.go
@@ -0,0 +1,35 @@
+package tun
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) {
+ lengths := []int{
+ 64,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1500,
+ 2048,
+ 4096,
+ 8192,
+ 9000,
+ 9001,
+ }
+
+ for _, length := range lengths {
+ b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ checksum(buf, 0)
+ }
+ })
+ }
+}
diff --git a/tun/errors.go b/tun/errors.go
new file mode 100644
index 0000000..75ae3a4
--- /dev/null
+++ b/tun/errors.go
@@ -0,0 +1,12 @@
+package tun
+
+import (
+ "errors"
+)
+
+var (
+ // ErrTooManySegments is returned by Device.Read() when segmentation
+ // overflows the length of supplied buffers. This error should not cause
+ // reads to cease.
+ ErrTooManySegments = errors.New("too many segments")
+)
diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go
new file mode 100644
index 0000000..ccd32ed
--- /dev/null
+++ b/tun/netstack/examples/http_client.go
@@ -0,0 +1,54 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "io"
+ "log"
+ "net/http"
+ "net/netip"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.28")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+ 1420)
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379
+public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28
+allowed_ip=0.0.0.0/0
+endpoint=127.0.0.1:58120
+`)
+ err = dev.Up()
+ if err != nil {
+ log.Panic(err)
+ }
+
+ client := http.Client{
+ Transport: &http.Transport{
+ DialContext: tnet.DialContext,
+ },
+ }
+ resp, err := client.Get("http://192.168.4.29/")
+ if err != nil {
+ log.Panic(err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Panic(err)
+ }
+ log.Println(string(body))
+}
diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go
new file mode 100644
index 0000000..f5b7a8f
--- /dev/null
+++ b/tun/netstack/examples/http_server.go
@@ -0,0 +1,51 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "io"
+ "log"
+ "net"
+ "net/http"
+ "net/netip"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
+ 1420,
+ )
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641
+listen_port=58120
+public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c
+allowed_ip=192.168.4.28/32
+persistent_keepalive_interval=25
+`)
+ dev.Up()
+ listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80})
+ if err != nil {
+ log.Panicln(err)
+ }
+ http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) {
+ log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent())
+ io.WriteString(writer, "Hello from userspace TCP!")
+ })
+ err = http.Serve(listener, nil)
+ if err != nil {
+ log.Panicln(err)
+ }
+}
diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go
new file mode 100644
index 0000000..2eef0fb
--- /dev/null
+++ b/tun/netstack/examples/ping_client.go
@@ -0,0 +1,75 @@
+//go:build ignore
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "bytes"
+ "log"
+ "math/rand"
+ "net/netip"
+ "time"
+
+ "golang.org/x/net/icmp"
+ "golang.org/x/net/ipv4"
+
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+)
+
+func main() {
+ tun, tnet, err := netstack.CreateNetTUN(
+ []netip.Addr{netip.MustParseAddr("192.168.4.29")},
+ []netip.Addr{netip.MustParseAddr("8.8.8.8")},
+ 1420)
+ if err != nil {
+ log.Panic(err)
+ }
+ dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, ""))
+ dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f
+public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b
+endpoint=163.172.161.0:12912
+allowed_ip=0.0.0.0/0
+`)
+ err = dev.Up()
+ if err != nil {
+ log.Panic(err)
+ }
+
+ socket, err := tnet.Dial("ping4", "zx2c4.com")
+ if err != nil {
+ log.Panic(err)
+ }
+ requestPing := icmp.Echo{
+ Seq: rand.Intn(1 << 16),
+ Data: []byte("gopher burrow"),
+ }
+ icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
+ socket.SetReadDeadline(time.Now().Add(time.Second * 10))
+ start := time.Now()
+ _, err = socket.Write(icmpBytes)
+ if err != nil {
+ log.Panic(err)
+ }
+ n, err := socket.Read(icmpBytes[:])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
+ if err != nil {
+ log.Panic(err)
+ }
+ replyPing, ok := replyPacket.Body.(*icmp.Echo)
+ if !ok {
+ log.Panicf("invalid reply type: %v", replyPacket)
+ }
+ if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
+ log.Panicf("invalid ping reply: %v", replyPing)
+ }
+ log.Printf("Ping latency: %v", time.Since(start))
+}
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
new file mode 100644
index 0000000..2b73054
--- /dev/null
+++ b/tun/netstack/tun.go
@@ -0,0 +1,1055 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package netstack
+
+import (
+ "bytes"
+ "context"
+ "crypto/rand"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/netip"
+ "os"
+ "regexp"
+ "strconv"
+ "strings"
+ "syscall"
+ "time"
+
+ "golang.zx2c4.com/wireguard/tun"
+
+ "golang.org/x/net/dns/dnsmessage"
+ "gvisor.dev/gvisor/pkg/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type netTun struct {
+ ep *channel.Endpoint
+ stack *stack.Stack
+ events chan tun.Event
+ incomingPacket chan *buffer.View
+ mtu int
+ dnsServers []netip.Addr
+ hasV4, hasV6 bool
+}
+
+type Net netTun
+
+func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
+ HandleLocal: true,
+ }
+ dev := &netTun{
+ ep: channel.New(1024, uint32(mtu), ""),
+ stack: stack.New(opts),
+ events: make(chan tun.Event, 10),
+ incomingPacket: make(chan *buffer.View),
+ dnsServers: dnsServers,
+ mtu: mtu,
+ }
+ sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
+ tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
+ }
+ dev.ep.AddNotify(dev)
+ tcpipErr = dev.stack.CreateNIC(1, dev.ep)
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
+ }
+ for _, ip := range localAddresses {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if ip.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else if ip.Is6() {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(),
+ }
+ tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
+ if tcpipErr != nil {
+ return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
+ }
+ if ip.Is4() {
+ dev.hasV4 = true
+ } else if ip.Is6() {
+ dev.hasV6 = true
+ }
+ }
+ if dev.hasV4 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1})
+ }
+ if dev.hasV6 {
+ dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1})
+ }
+
+ dev.events <- tun.EventUp
+ return dev, (*Net)(dev), nil
+}
+
+func (tun *netTun) Name() (string, error) {
+ return "go", nil
+}
+
+func (tun *netTun) File() *os.File {
+ return nil
+}
+
+func (tun *netTun) Events() <-chan tun.Event {
+ return tun.events
+}
+
+func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
+ view, ok := <-tun.incomingPacket
+ if !ok {
+ return 0, os.ErrClosed
+ }
+
+ n, err := view.Read(buf[0][offset:])
+ if err != nil {
+ return 0, err
+ }
+ sizes[0] = n
+ return 1, nil
+}
+
+func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
+ for _, buf := range buf {
+ packet := buf[offset:]
+ if len(packet) == 0 {
+ continue
+ }
+
+ pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)})
+ switch packet[0] >> 4 {
+ case 4:
+ tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb)
+ case 6:
+ tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb)
+ default:
+ return 0, syscall.EAFNOSUPPORT
+ }
+ }
+ return len(buf), nil
+}
+
+func (tun *netTun) WriteNotify() {
+ pkt := tun.ep.Read()
+ if pkt.IsNil() {
+ return
+ }
+
+ view := pkt.ToView()
+ pkt.DecRef()
+
+ tun.incomingPacket <- view
+}
+
+func (tun *netTun) Close() error {
+ tun.stack.RemoveNIC(1)
+
+ if tun.events != nil {
+ close(tun.events)
+ }
+
+ tun.ep.Close()
+
+ if tun.incomingPacket != nil {
+ close(tun.incomingPacket)
+ }
+
+ return nil
+}
+
+func (tun *netTun) MTU() (int, error) {
+ return tun.mtu, nil
+}
+
+func (tun *netTun) BatchSize() int {
+ return 1
+}
+
+func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if endpoint.Addr().Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ return tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()),
+ Port: endpoint.Port(),
+ }, protoNumber
+}
+
+func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialContextTCP(ctx, net.stack, fa, pn)
+}
+
+func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.DialTCP(net.stack, fa, pn)
+}
+
+func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
+ if addr == nil {
+ return net.DialTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
+ fa, pn := convertToFullAddr(addr)
+ return gonet.ListenTCP(net.stack, fa, pn)
+}
+
+func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
+ if addr == nil {
+ return net.ListenTCPAddrPort(netip.AddrPort{})
+ }
+ ip, _ := netip.AddrFromSlice(addr.IP)
+ return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port)))
+}
+
+func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
+ var lfa, rfa *tcpip.FullAddress
+ var pn tcpip.NetworkProtocolNumber
+ if laddr.IsValid() || laddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(laddr)
+ lfa = &addr
+ }
+ if raddr.IsValid() || raddr.Port() > 0 {
+ var addr tcpip.FullAddress
+ addr, pn = convertToFullAddr(raddr)
+ rfa = &addr
+ }
+ return gonet.DialUDP(net.stack, lfa, rfa, pn)
+}
+
+func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) {
+ return net.DialUDPAddrPort(laddr, netip.AddrPort{})
+}
+
+func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ var la, ra netip.AddrPort
+ if laddr != nil {
+ ip, _ := netip.AddrFromSlice(laddr.IP)
+ la = netip.AddrPortFrom(ip, uint16(laddr.Port))
+ }
+ if raddr != nil {
+ ip, _ := netip.AddrFromSlice(raddr.IP)
+ ra = netip.AddrPortFrom(ip, uint16(raddr.Port))
+ }
+ return net.DialUDPAddrPort(la, ra)
+}
+
+func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
+ return net.DialUDP(laddr, nil)
+}
+
+type PingConn struct {
+ laddr PingAddr
+ raddr PingAddr
+ wq waiter.Queue
+ ep tcpip.Endpoint
+ deadline *time.Timer
+}
+
+type PingAddr struct{ addr netip.Addr }
+
+func (ia PingAddr) String() string {
+ return ia.addr.String()
+}
+
+func (ia PingAddr) Network() string {
+ if ia.addr.Is4() {
+ return "ping4"
+ } else if ia.addr.Is6() {
+ return "ping6"
+ }
+ return "ping"
+}
+
+func (ia PingAddr) Addr() netip.Addr {
+ return ia.addr
+}
+
+func PingAddrFromAddr(addr netip.Addr) *PingAddr {
+ return &PingAddr{addr}
+}
+
+func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) {
+ if !laddr.IsValid() && !raddr.IsValid() {
+ return nil, errors.New("ping dial: invalid address")
+ }
+ v6 := laddr.Is6() || raddr.Is6()
+ bind := laddr.IsValid()
+ if !bind {
+ if v6 {
+ laddr = netip.IPv6Unspecified()
+ } else {
+ laddr = netip.IPv4Unspecified()
+ }
+ }
+
+ tn := icmp.ProtocolNumber4
+ pn := ipv4.ProtocolNumber
+ if v6 {
+ tn = icmp.ProtocolNumber6
+ pn = ipv6.ProtocolNumber
+ }
+
+ pc := &PingConn{
+ laddr: PingAddr{laddr},
+ deadline: time.NewTimer(time.Hour << 10),
+ }
+ pc.deadline.Stop()
+
+ ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq)
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr)
+ }
+ pc.ep = ep
+
+ if bind {
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0))
+ if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping bind: %s", tcpipErr)
+ }
+ }
+
+ if raddr.IsValid() {
+ pc.raddr = PingAddr{raddr}
+ fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0))
+ if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil {
+ return nil, fmt.Errorf("ping connect: %s", tcpipErr)
+ }
+ }
+
+ return pc, nil
+}
+
+func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) {
+ return net.DialPingAddr(laddr, netip.Addr{})
+}
+
+func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) {
+ var la, ra netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ if raddr != nil {
+ ra = raddr.addr
+ }
+ return net.DialPingAddr(la, ra)
+}
+
+func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) {
+ var la netip.Addr
+ if laddr != nil {
+ la = laddr.addr
+ }
+ return net.ListenPingAddr(la)
+}
+
+func (pc *PingConn) LocalAddr() net.Addr {
+ return pc.laddr
+}
+
+func (pc *PingConn) RemoteAddr() net.Addr {
+ return pc.raddr
+}
+
+func (pc *PingConn) Close() error {
+ pc.deadline.Reset(0)
+ pc.ep.Close()
+ return nil
+}
+
+func (pc *PingConn) SetWriteDeadline(t time.Time) error {
+ return errors.New("not implemented")
+}
+
+func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
+ var na netip.Addr
+ switch v := addr.(type) {
+ case *PingAddr:
+ na = v.addr
+ case *net.IPAddr:
+ na, _ = netip.AddrFromSlice(v.IP)
+ default:
+ return 0, fmt.Errorf("ping write: wrong net.Addr type")
+ }
+ if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) {
+ return 0, fmt.Errorf("ping write: mismatched protocols")
+ }
+
+ buf := bytes.NewReader(p)
+ rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0))
+ // won't block, no deadlines
+ n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{
+ To: &rfa,
+ })
+ if tcpipErr != nil {
+ return int(n64), fmt.Errorf("ping write: %s", tcpipErr)
+ }
+
+ return int(n64), nil
+}
+
+func (pc *PingConn) Write(p []byte) (n int, err error) {
+ return pc.WriteTo(p, &pc.raddr)
+}
+
+func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
+ e, notifyCh := waiter.NewChannelEntry(waiter.EventIn)
+ pc.wq.EventRegister(&e)
+ defer pc.wq.EventUnregister(&e)
+
+ select {
+ case <-pc.deadline.C:
+ return 0, nil, os.ErrDeadlineExceeded
+ case <-notifyCh:
+ }
+
+ w := tcpip.SliceWriter(p)
+
+ res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{
+ NeedRemoteAddr: true,
+ })
+ if tcpipErr != nil {
+ return 0, nil, fmt.Errorf("ping read: %s", tcpipErr)
+ }
+
+ remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice())
+ return res.Count, &PingAddr{remoteAddr}, nil
+}
+
+func (pc *PingConn) Read(p []byte) (n int, err error) {
+ n, _, err = pc.ReadFrom(p)
+ return
+}
+
+func (pc *PingConn) SetDeadline(t time.Time) error {
+ // pc.SetWriteDeadline is unimplemented
+
+ return pc.SetReadDeadline(t)
+}
+
+func (pc *PingConn) SetReadDeadline(t time.Time) error {
+ pc.deadline.Reset(time.Until(t))
+ return nil
+}
+
+var (
+ errNoSuchHost = errors.New("no such host")
+ errLameReferral = errors.New("lame referral")
+ errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")
+ errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")
+ errServerMisbehaving = errors.New("server misbehaving")
+ errInvalidDNSResponse = errors.New("invalid DNS response")
+ errNoAnswerFromDNSServer = errors.New("no answer from DNS server")
+ errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+ errCanceled = errors.New("operation was canceled")
+ errTimeout = errors.New("i/o timeout")
+ errNumericPort = errors.New("port must be numeric")
+ errNoSuitableAddress = errors.New("no suitable address found")
+ errMissingAddress = errors.New("missing address")
+)
+
+func (net *Net) LookupHost(host string) (addrs []string, err error) {
+ return net.LookupContextHost(context.Background(), host)
+}
+
+func isDomainName(s string) bool {
+ l := len(s)
+ if l == 0 || l > 254 || l == 254 && s[l-1] != '.' {
+ return false
+ }
+ last := byte('.')
+ nonNumeric := false
+ partlen := 0
+ for i := 0; i < len(s); i++ {
+ c := s[i]
+ switch {
+ default:
+ return false
+ case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_':
+ nonNumeric = true
+ partlen++
+ case '0' <= c && c <= '9':
+ partlen++
+ case c == '-':
+ if last == '.' {
+ return false
+ }
+ partlen++
+ nonNumeric = true
+ case c == '.':
+ if last == '.' || last == '-' {
+ return false
+ }
+ if partlen > 63 || partlen == 0 {
+ return false
+ }
+ partlen = 0
+ }
+ last = c
+ }
+ if last == '-' || partlen > 63 {
+ return false
+ }
+ return nonNumeric
+}
+
+func randU16() uint16 {
+ var b [2]byte
+ _, err := rand.Read(b[:])
+ if err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint16(b[:])
+}
+
+func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
+ id = randU16()
+ b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
+ b.EnableCompression()
+ if err := b.StartQuestions(); err != nil {
+ return 0, nil, nil, err
+ }
+ if err := b.Question(q); err != nil {
+ return 0, nil, nil, err
+ }
+ tcpReq, err = b.Finish()
+ udpReq = tcpReq[2:]
+ l := len(tcpReq) - 2
+ tcpReq[0] = byte(l >> 8)
+ tcpReq[1] = byte(l)
+ return id, udpReq, tcpReq, err
+}
+
+func equalASCIIName(x, y dnsmessage.Name) bool {
+ if x.Length != y.Length {
+ return false
+ }
+ for i := 0; i < int(x.Length); i++ {
+ a := x.Data[i]
+ b := y.Data[i]
+ if 'A' <= a && a <= 'Z' {
+ a += 0x20
+ }
+ if 'A' <= b && b <= 'Z' {
+ b += 0x20
+ }
+ if a != b {
+ return false
+ }
+ }
+ return true
+}
+
+func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
+ if !respHdr.Response {
+ return false
+ }
+ if reqID != respHdr.ID {
+ return false
+ }
+ if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
+ return false
+ }
+ return true
+}
+
+func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 512)
+ for {
+ n, err := c.Read(b)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ continue
+ }
+ q, err := p.Question()
+ if err != nil || !checkResponse(id, query, h, q) {
+ continue
+ }
+ return p, h, nil
+ }
+}
+
+func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
+ if _, err := c.Write(b); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ b = make([]byte, 1280)
+ if _, err := io.ReadFull(c, b[:2]); err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ l := int(b[0])<<8 | int(b[1])
+ if l > len(b) {
+ b = make([]byte, l)
+ }
+ n, err := io.ReadFull(c, b[:l])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ var p dnsmessage.Parser
+ h, err := p.Start(b[:n])
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ q, err := p.Question()
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
+ }
+ if !checkResponse(id, query, h, q) {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ return p, h, nil
+}
+
+func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
+ q.Class = dnsmessage.ClassINET
+ id, udpReq, tcpReq, err := newRequest(q)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
+ }
+
+ for _, useUDP := range []bool{true, false} {
+ ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
+ defer cancel()
+
+ var c net.Conn
+ var err error
+ if useUDP {
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
+ } else {
+ c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
+ }
+
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if d, ok := ctx.Deadline(); ok && !d.IsZero() {
+ err := c.SetDeadline(d)
+ if err != nil {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ }
+ var p dnsmessage.Parser
+ var h dnsmessage.Header
+ if useUDP {
+ p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+ } else {
+ p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
+ }
+ c.Close()
+ if err != nil {
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, err
+ }
+ if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
+ }
+ if h.Truncated {
+ continue
+ }
+ return p, h, nil
+ }
+ return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
+}
+
+func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
+ if h.RCode == dnsmessage.RCodeNameError {
+ return errNoSuchHost
+ }
+ _, err := p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
+ return errLameReferral
+ }
+ if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
+ if h.RCode == dnsmessage.RCodeServerFailure {
+ return errServerTemporarilyMisbehaving
+ }
+ return errServerMisbehaving
+ }
+ return nil
+}
+
+func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
+ for {
+ h, err := p.AnswerHeader()
+ if err == dnsmessage.ErrSectionDone {
+ return errNoSuchHost
+ }
+ if err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ if h.Type == qtype {
+ return nil
+ }
+ if err := p.SkipAnswer(); err != nil {
+ return errCannotUnmarshalDNSMessage
+ }
+ }
+}
+
+func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
+ var lastErr error
+
+ n, err := dnsmessage.NewName(name)
+ if err != nil {
+ return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+ }
+ q := dnsmessage.Question{
+ Name: n,
+ Type: qtype,
+ Class: dnsmessage.ClassINET,
+ }
+
+ for i := 0; i < 2; i++ {
+ for _, server := range tnet.dnsServers {
+ p, h, err := tnet.exchange(ctx, server, q, time.Second*5)
+ if err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ dnsErr.IsTimeout = true
+ }
+ if _, ok := err.(*net.OpError); ok {
+ dnsErr.IsTemporary = true
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ if err := checkHeader(&p, h); err != nil {
+ dnsErr := &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errServerTemporarilyMisbehaving {
+ dnsErr.IsTemporary = true
+ }
+ if err == errNoSuchHost {
+ dnsErr.IsNotFound = true
+ return p, server.String(), dnsErr
+ }
+ lastErr = dnsErr
+ continue
+ }
+
+ err = skipToAnswer(&p, qtype)
+ if err == nil {
+ return p, server.String(), nil
+ }
+ lastErr = &net.DNSError{
+ Err: err.Error(),
+ Name: name,
+ Server: server.String(),
+ }
+ if err == errNoSuchHost {
+ lastErr.(*net.DNSError).IsNotFound = true
+ return p, server.String(), lastErr
+ }
+ }
+ }
+ return dnsmessage.Parser{}, "", lastErr
+}
+
+func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) {
+ if host == "" || (!tnet.hasV6 && !tnet.hasV4) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ zlen := len(host)
+ if strings.IndexByte(host, ':') != -1 {
+ if zidx := strings.LastIndexByte(host, '%'); zidx != -1 {
+ zlen = zidx
+ }
+ }
+ if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
+ return []string{ip.String()}, nil
+ }
+
+ if !isDomainName(host) {
+ return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+ }
+ type result struct {
+ p dnsmessage.Parser
+ server string
+ error
+ }
+ var addrsV4, addrsV6 []netip.Addr
+ lanes := 0
+ if tnet.hasV4 {
+ lanes++
+ }
+ if tnet.hasV6 {
+ lanes++
+ }
+ lane := make(chan result, lanes)
+ var lastErr error
+ if tnet.hasV4 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA)
+ lane <- result{p, server, err}
+ }()
+ }
+ if tnet.hasV6 {
+ go func() {
+ p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA)
+ lane <- result{p, server, err}
+ }()
+ }
+ for l := 0; l < lanes; l++ {
+ result := <-lane
+ if result.error != nil {
+ if lastErr == nil {
+ lastErr = result.error
+ }
+ continue
+ }
+
+ loop:
+ for {
+ h, err := result.p.AnswerHeader()
+ if err != nil && err != dnsmessage.ErrSectionDone {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ }
+ if err != nil {
+ break
+ }
+ switch h.Type {
+ case dnsmessage.TypeA:
+ a, err := result.p.AResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
+
+ case dnsmessage.TypeAAAA:
+ aaaa, err := result.p.AAAAResource()
+ if err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
+
+ default:
+ if err := result.p.SkipAnswer(); err != nil {
+ lastErr = &net.DNSError{
+ Err: errCannotMarshalDNSMessage.Error(),
+ Name: host,
+ Server: result.server,
+ }
+ break loop
+ }
+ continue
+ }
+ }
+ }
+ // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled
+ var addrs []netip.Addr
+ if tnet.hasV6 {
+ addrs = append(addrsV6, addrsV4...)
+ } else {
+ addrs = append(addrsV4, addrsV6...)
+ }
+
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
+ saddrs := make([]string, 0, len(addrs))
+ for _, ip := range addrs {
+ saddrs = append(saddrs, ip.String())
+ }
+ return saddrs, nil
+}
+
+func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
+ if deadline.IsZero() {
+ return deadline, nil
+ }
+ timeRemaining := deadline.Sub(now)
+ if timeRemaining <= 0 {
+ return time.Time{}, errTimeout
+ }
+ timeout := timeRemaining / time.Duration(addrsRemaining)
+ const saneMinimum = 2 * time.Second
+ if timeout < saneMinimum {
+ if timeRemaining < saneMinimum {
+ timeout = timeRemaining
+ } else {
+ timeout = saneMinimum
+ }
+ }
+ return now.Add(timeout), nil
+}
+
+var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`)
+
+func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ if ctx == nil {
+ panic("nil context")
+ }
+ var acceptV4, acceptV6 bool
+ matches := protoSplitter.FindStringSubmatch(network)
+ if matches == nil {
+ return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
+ } else if len(matches[2]) == 0 {
+ acceptV4 = true
+ acceptV6 = true
+ } else {
+ acceptV4 = matches[2][0] == '4'
+ acceptV6 = !acceptV4
+ }
+ var host string
+ var port int
+ if matches[1] == "ping" {
+ host = address
+ } else {
+ var sport string
+ var err error
+ host, sport, err = net.SplitHostPort(address)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ port, err = strconv.Atoi(sport)
+ if err != nil || port < 0 || port > 65535 {
+ return nil, &net.OpError{Op: "dial", Err: errNumericPort}
+ }
+ }
+ allAddr, err := tnet.LookupContextHost(ctx, host)
+ if err != nil {
+ return nil, &net.OpError{Op: "dial", Err: err}
+ }
+ var addrs []netip.AddrPort
+ for _, addr := range allAddr {
+ ip, err := netip.ParseAddr(addr)
+ if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
+ addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
+ }
+ }
+ if len(addrs) == 0 && len(allAddr) != 0 {
+ return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
+ }
+
+ var firstErr error
+ for i, addr := range addrs {
+ select {
+ case <-ctx.Done():
+ err := ctx.Err()
+ if err == context.Canceled {
+ err = errCanceled
+ } else if err == context.DeadlineExceeded {
+ err = errTimeout
+ }
+ return nil, &net.OpError{Op: "dial", Err: err}
+ default:
+ }
+
+ dialCtx := ctx
+ if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
+ partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
+ if err != nil {
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: err}
+ }
+ break
+ }
+ if partialDeadline.Before(deadline) {
+ var cancel context.CancelFunc
+ dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
+ defer cancel()
+ }
+ }
+
+ var c net.Conn
+ switch matches[1] {
+ case "tcp":
+ c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
+ case "udp":
+ c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
+ case "ping":
+ c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr())
+ }
+ if err == nil {
+ return c, nil
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ if firstErr == nil {
+ firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
+ }
+ return nil, firstErr
+}
+
+func (tnet *Net) Dial(network, address string) (net.Conn, error) {
+ return tnet.DialContext(context.Background(), network, address)
+}
diff --git a/tun/offload_linux.go b/tun/offload_linux.go
new file mode 100644
index 0000000..9ff7fea
--- /dev/null
+++ b/tun/offload_linux.go
@@ -0,0 +1,993 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "bytes"
+ "encoding/binary"
+ "errors"
+ "io"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+)
+
+const tcpFlagsOffset = 13
+
+const (
+ tcpFlagFIN uint8 = 0x01
+ tcpFlagPSH uint8 = 0x08
+ tcpFlagACK uint8 = 0x10
+)
+
+// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
+// kernel symbol is virtio_net_hdr.
+type virtioNetHdr struct {
+ flags uint8
+ gsoType uint8
+ hdrLen uint16
+ gsoSize uint16
+ csumStart uint16
+ csumOffset uint16
+}
+
+func (v *virtioNetHdr) decode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
+ return nil
+}
+
+func (v *virtioNetHdr) encode(b []byte) error {
+ if len(b) < virtioNetHdrLen {
+ return io.ErrShortBuffer
+ }
+ copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
+ return nil
+}
+
+const (
+ // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
+ // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
+ virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
+)
+
+// tcpFlowKey represents the key for a TCP flow.
+type tcpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
+ isV6 bool
+}
+
+// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO.
+type tcpGROTable struct {
+ itemsByFlow map[tcpFlowKey][]tcpGROItem
+ itemsPool [][]tcpGROItem
+}
+
+func newTCPGROTable() *tcpGROTable {
+ t := &tcpGROTable{
+ itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]tcpGROItem, conn.IdealBatchSize),
+ }
+ for i := range t.itemsPool {
+ t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize)
+ }
+ return t
+}
+
+func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey {
+ key := tcpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
+ key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ items, ok := t.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
+ key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
+ item := tcpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
+ iphLen: uint8(tcphOffset),
+ tcphLen: uint8(tcphLen),
+ sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
+ pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
+ }
+ items, ok := t.itemsByFlow[key]
+ if !ok {
+ items = t.newItems()
+ }
+ items = append(items, item)
+ t.itemsByFlow[key] = items
+}
+
+func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
+ items, _ := t.itemsByFlow[item.key]
+ items[i] = item
+}
+
+func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) {
+ items, _ := t.itemsByFlow[key]
+ items = append(items[:i], items[i+1:]...)
+ t.itemsByFlow[key] = items
+}
+
+// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type tcpGROItem struct {
+ key tcpFlowKey
+ sentSeq uint32 // the sequence number
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ tcphLen uint8 // tcp header len
+ pshSet bool // psh flag is set
+}
+
+func (t *tcpGROTable) newItems() []tcpGROItem {
+ var items []tcpGROItem
+ items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
+ return items
+}
+
+func (t *tcpGROTable) reset() {
+ for k, items := range t.itemsByFlow {
+ items = items[:0]
+ t.itemsPool = append(t.itemsPool, items)
+ delete(t.itemsByFlow, k)
+ }
+}
+
+// udpFlowKey represents the key for a UDP flow.
+type udpFlowKey struct {
+ srcAddr, dstAddr [16]byte
+ srcPort, dstPort uint16
+ isV6 bool
+}
+
+// udpGROTable holds flow and coalescing information for the purposes of UDP GRO.
+type udpGROTable struct {
+ itemsByFlow map[udpFlowKey][]udpGROItem
+ itemsPool [][]udpGROItem
+}
+
+func newUDPGROTable() *udpGROTable {
+ u := &udpGROTable{
+ itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize),
+ itemsPool: make([][]udpGROItem, conn.IdealBatchSize),
+ }
+ for i := range u.itemsPool {
+ u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize)
+ }
+ return u
+}
+
+func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey {
+ key := udpFlowKey{}
+ addrSize := dstAddrOffset - srcAddrOffset
+ copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset])
+ copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize])
+ key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:])
+ key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:])
+ key.isV6 = addrSize == 16
+ return key
+}
+
+// lookupOrInsert looks up a flow for the provided packet and metadata,
+// returning the packets found for the flow, or inserting a new one if none
+// is found.
+func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ items, ok := u.itemsByFlow[key]
+ if ok {
+ return items, ok
+ }
+ // TODO: insert() performs another map lookup. This could be rearranged to avoid.
+ u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false)
+ return nil, false
+}
+
+// insert an item in the table for the provided packet and packet metadata.
+func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) {
+ key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset)
+ item := udpGROItem{
+ key: key,
+ bufsIndex: uint16(bufsIndex),
+ gsoSize: uint16(len(pkt[udphOffset+udphLen:])),
+ iphLen: uint8(udphOffset),
+ cSumKnownInvalid: cSumKnownInvalid,
+ }
+ items, ok := u.itemsByFlow[key]
+ if !ok {
+ items = u.newItems()
+ }
+ items = append(items, item)
+ u.itemsByFlow[key] = items
+}
+
+func (u *udpGROTable) updateAt(item udpGROItem, i int) {
+ items, _ := u.itemsByFlow[item.key]
+ items[i] = item
+}
+
+// udpGROItem represents bookkeeping data for a UDP packet during the lifetime
+// of a GRO evaluation across a vector of packets.
+type udpGROItem struct {
+ key udpFlowKey
+ bufsIndex uint16 // the index into the original bufs slice
+ numMerged uint16 // the number of packets merged into this item
+ gsoSize uint16 // payload size
+ iphLen uint8 // ip header len
+ cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown.
+}
+
+func (u *udpGROTable) newItems() []udpGROItem {
+ var items []udpGROItem
+ items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1]
+ return items
+}
+
+func (u *udpGROTable) reset() {
+ for k, items := range u.itemsByFlow {
+ items = items[:0]
+ u.itemsPool = append(u.itemsPool, items)
+ delete(u.itemsByFlow, k)
+ }
+}
+
+// canCoalesce represents the outcome of checking if two TCP packets are
+// candidates for coalescing.
+type canCoalesce int
+
+const (
+ coalescePrepend canCoalesce = -1
+ coalesceUnavailable canCoalesce = 0
+ coalesceAppend canCoalesce = 1
+)
+
+// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB
+// meet all requirements to be merged as part of a GRO operation, otherwise it
+// returns false.
+func ipHeadersCanCoalesce(pktA, pktB []byte) bool {
+ if len(pktA) < 9 || len(pktB) < 9 {
+ return false
+ }
+ if pktA[0]>>4 == 6 {
+ if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 {
+ // cannot coalesce with unequal Traffic class values
+ return false
+ }
+ if pktA[7] != pktB[7] {
+ // cannot coalesce with unequal Hop limit values
+ return false
+ }
+ } else {
+ if pktA[1] != pktB[1] {
+ // cannot coalesce with unequal ToS values
+ return false
+ }
+ if pktA[6]>>5 != pktB[6]>>5 {
+ // cannot coalesce with unequal DF or reserved bits. MF is checked
+ // further up the stack.
+ return false
+ }
+ if pktA[8] != pktB[8] {
+ // cannot coalesce with unequal TTL values
+ return false
+ }
+ }
+ return true
+}
+
+// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. iphLen and gsoSize describe pkt. bufs is the vector of
+// packets involved in the current GRO evaluation. bufsOffset is the offset at
+// which packet data begins within bufs.
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+}
+
+// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
+// described by item. This function makes considerations that match the kernel's
+// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
+func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
+ pktTarget := bufs[item.bufsIndex][bufsOffset:]
+ if tcphLen != item.tcphLen {
+ // cannot coalesce with unequal tcp options len
+ return coalesceUnavailable
+ }
+ if tcphLen > 20 {
+ if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
+ // cannot coalesce with unequal tcp options
+ return coalesceUnavailable
+ }
+ }
+ if !ipHeadersCanCoalesce(pkt, pktTarget) {
+ return coalesceUnavailable
+ }
+ // seq adjacency
+ lhsLen := item.gsoSize
+ lhsLen += item.numMerged * item.gsoSize
+ if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
+ if item.pshSet {
+ // We cannot append to a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
+ // A smaller than gsoSize packet has been appended previously.
+ // Nothing can come after a smaller packet on the end.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ return coalesceAppend
+ } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
+ if pshSet {
+ // We cannot prepend with a segment that has the PSH flag set, PSH
+ // can only be set on the final segment in a reassembled group.
+ return coalesceUnavailable
+ }
+ if gsoSize < item.gsoSize {
+ // We cannot have a larger packet following a smaller one.
+ return coalesceUnavailable
+ }
+ if gsoSize > item.gsoSize && item.numMerged > 0 {
+ // There's at least one previous merge, and we're larger than all
+ // previous. This would put multiple smaller packets on the end.
+ return coalesceUnavailable
+ }
+ return coalescePrepend
+ }
+ return coalesceUnavailable
+}
+
+func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool {
+ srcAddrAt := ipv4SrcAddrOffset
+ addrSize := 4
+ if isV6 {
+ srcAddrAt = ipv6SrcAddrOffset
+ addrSize = 16
+ }
+ lenForPseudo := uint16(len(pkt) - int(iphLen))
+ cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo)
+ return ^checksum(pkt[iphLen:], cSum) == 0
+}
+
+// coalesceResult represents the result of attempting to coalesce two TCP
+// packets.
+type coalesceResult int
+
+const (
+ coalesceInsufficientCap coalesceResult = iota
+ coalescePSHEnding
+ coalesceItemInvalidCSum
+ coalescePktInvalidCSum
+ coalesceSuccess
+)
+
+// coalesceUDPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome.
+func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front
+ headersLen := item.iphLen + udphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+// coalesceTCPPackets attempts to coalesce pkt with the packet described by
+// item, and returns the outcome. This function may swap bufs elements in the
+// event of a prepend as item's bufs index is already being tracked for writing
+// to a Device.
+func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
+ var pktHead []byte // the packet that will end up at the front
+ headersLen := item.iphLen + item.tcphLen
+ coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
+
+ // Copy data
+ if mode == coalescePrepend {
+ pktHead = pkt
+ if cap(pkt)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if pshSet {
+ return coalescePSHEnding
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ item.sentSeq = seq
+ extendBy := coalescedLen - len(pktHead)
+ bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
+ copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
+ // Flip the slice headers in bufs as part of prepend. The index of item
+ // is already being tracked for writing.
+ bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
+ } else {
+ pktHead = bufs[item.bufsIndex][bufsOffset:]
+ if cap(pktHead)-bufsOffset < coalescedLen {
+ // We don't want to allocate a new underlying array if capacity is
+ // too small.
+ return coalesceInsufficientCap
+ }
+ if item.numMerged == 0 {
+ if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalesceItemInvalidCSum
+ }
+ }
+ if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) {
+ return coalescePktInvalidCSum
+ }
+ if pshSet {
+ // We are appending a segment with PSH set.
+ item.pshSet = pshSet
+ pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
+ }
+ extendBy := len(pkt) - int(headersLen)
+ bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
+ copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
+ }
+
+ if gsoSize > item.gsoSize {
+ item.gsoSize = gsoSize
+ }
+
+ item.numMerged++
+ return coalesceSuccess
+}
+
+const (
+ ipv4FlagMoreFragments uint8 = 0x20
+)
+
+const (
+ ipv4SrcAddrOffset = 12
+ ipv6SrcAddrOffset = 8
+ maxUint16 = 1<<16 - 1
+)
+
+type groResult int
+
+const (
+ groResultNoop groResult = iota
+ groResultTableInsert
+ groResultCoalesced
+)
+
+// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ tcphLen := int((pkt[iphLen+12] >> 4) * 4)
+ if tcphLen < 20 || tcphLen > 60 {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+tcphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ tcpFlags := pkt[iphLen+tcpFlagsOffset]
+ var pshSet bool
+ // not a candidate if any non-ACK flags (except PSH+ACK) are set
+ if tcpFlags != tcpFlagACK {
+ if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
+ return groResultNoop
+ }
+ pshSet = true
+ }
+ gsoSize := uint16(len(pkt) - tcphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ for i := len(items) - 1; i >= 0; i-- {
+ // In the best case of packets arriving in order iterating in reverse is
+ // more efficient if there are multiple items for a given flow. This
+ // also enables a natural table.deleteAt() in the
+ // coalesceItemInvalidCSum case without the need for index tracking.
+ // This algorithm makes a best effort to coalesce in the event of
+ // unordered packets, where pkt may land anywhere in items from a
+ // sequence number perspective, however once an item is inserted into
+ // the table it is never compared across other items later.
+ item := items[i]
+ can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
+ if can != coalesceUnavailable {
+ result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, i)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // delete the item with an invalid csum
+ table.deleteAt(item.key, i)
+ case coalescePktInvalidCSum:
+ // no point in inserting an item that we can't coalesce
+ return groResultNoop
+ default:
+ }
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
+ return groResultTableInsert
+}
+
+// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + item.tcphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 16,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ if item.key.isV6 {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Calculate the pseudo header checksum and place it at the TCP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the tcp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the
+// metadata found in table.
+func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error {
+ for _, items := range table.itemsByFlow {
+ for _, item := range items {
+ if item.numMerged > 0 {
+ hdr := virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
+ hdrLen: uint16(item.iphLen + udphLen),
+ gsoSize: item.gsoSize,
+ csumStart: uint16(item.iphLen),
+ csumOffset: 6,
+ }
+ pkt := bufs[item.bufsIndex][offset:]
+
+ // Recalculate the total len (IPv4) or payload len (IPv6).
+ // Recalculate the (IPv4) header checksum.
+ hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4
+ if item.key.isV6 {
+ binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
+ } else {
+ pkt[10], pkt[11] = 0, 0
+ binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
+ iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum
+ binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
+ }
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+
+ // Recalculate the UDP len field value
+ binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:])))
+
+ // Calculate the pseudo header checksum and place it at the UDP
+ // checksum offset. Downstream checksum offloading will combine
+ // this with computation of the udp header and payload checksum.
+ addrLen := 4
+ addrOffset := ipv4SrcAddrOffset
+ if item.key.isV6 {
+ addrLen = 16
+ addrOffset = ipv6SrcAddrOffset
+ }
+ srcAddrAt := offset + addrOffset
+ srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
+ dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
+ psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
+ binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum))
+ } else {
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ }
+ }
+ }
+ return nil
+}
+
+type groCandidateType uint8
+
+const (
+ notGROCandidate groCandidateType = iota
+ tcp4GROCandidate
+ tcp6GROCandidate
+ udp4GROCandidate
+ udp6GROCandidate
+)
+
+func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType {
+ if len(b) < 28 {
+ return notGROCandidate
+ }
+ if b[0]>>4 == 4 {
+ if b[0]&0x0F != 5 {
+ // IPv4 packets w/IP options do not coalesce
+ return notGROCandidate
+ }
+ if b[9] == unix.IPPROTO_TCP && len(b) >= 40 {
+ return tcp4GROCandidate
+ }
+ if b[9] == unix.IPPROTO_UDP && canUDPGRO {
+ return udp4GROCandidate
+ }
+ } else if b[0]>>4 == 6 {
+ if b[6] == unix.IPPROTO_TCP && len(b) >= 60 {
+ return tcp6GROCandidate
+ }
+ if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO {
+ return udp6GROCandidate
+ }
+ }
+ return notGROCandidate
+}
+
+const (
+ udphLen = 8
+)
+
+// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with
+// existing packets tracked in table. It returns a groResultNoop when no
+// action was taken, groResultTableInsert when the evaluated packet was
+// inserted into table, and groResultCoalesced when the evaluated packet was
+// coalesced with another packet in table.
+func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult {
+ pkt := bufs[pktI][offset:]
+ if len(pkt) > maxUint16 {
+ // A valid IPv4 or IPv6 packet will never exceed this.
+ return groResultNoop
+ }
+ iphLen := int((pkt[0] & 0x0F) * 4)
+ if isV6 {
+ iphLen = 40
+ ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
+ if ipv6HPayloadLen != len(pkt)-iphLen {
+ return groResultNoop
+ }
+ } else {
+ totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
+ if totalLen != len(pkt) {
+ return groResultNoop
+ }
+ }
+ if len(pkt) < iphLen {
+ return groResultNoop
+ }
+ if len(pkt) < iphLen+udphLen {
+ return groResultNoop
+ }
+ if !isV6 {
+ if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
+ // no GRO support for fragmented segments for now
+ return groResultNoop
+ }
+ }
+ gsoSize := uint16(len(pkt) - udphLen - iphLen)
+ // not a candidate if payload len is 0
+ if gsoSize < 1 {
+ return groResultNoop
+ }
+ srcAddrOffset := ipv4SrcAddrOffset
+ addrLen := 4
+ if isV6 {
+ srcAddrOffset = ipv6SrcAddrOffset
+ addrLen = 16
+ }
+ items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI)
+ if !existing {
+ return groResultTableInsert
+ }
+ // With UDP we only check the last item, otherwise we could reorder packets
+ // for a given flow. We must also always insert a new item, or successfully
+ // coalesce with an existing item, for the same reason.
+ item := items[len(items)-1]
+ can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset)
+ var pktCSumKnownInvalid bool
+ if can == coalesceAppend {
+ result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6)
+ switch result {
+ case coalesceSuccess:
+ table.updateAt(item, len(items)-1)
+ return groResultCoalesced
+ case coalesceItemInvalidCSum:
+ // If the existing item has an invalid csum we take no action. A new
+ // item will be stored after it, and the existing item will never be
+ // revisited as part of future coalescing candidacy checks.
+ case coalescePktInvalidCSum:
+ // We must insert a new item, but we also mark it as invalid csum
+ // to prevent a repeat checksum validation.
+ pktCSumKnownInvalid = true
+ default:
+ }
+ }
+ // failed to coalesce with any other packets; store the item in the flow
+ table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid)
+ return groResultTableInsert
+}
+
+// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
+// packets into toWrite. toWrite, tcpTable, and udpTable should initially be
+// empty (but non-nil), and are passed in to save allocs as the caller may reset
+// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is
+// supported.
+func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error {
+ for i := range bufs {
+ if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
+ return errors.New("invalid offset")
+ }
+ var result groResult
+ switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) {
+ case tcp4GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, false)
+ case tcp6GROCandidate:
+ result = tcpGRO(bufs, offset, i, tcpTable, true)
+ case udp4GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, false)
+ case udp6GROCandidate:
+ result = udpGRO(bufs, offset, i, udpTable, true)
+ }
+ switch result {
+ case groResultNoop:
+ hdr := virtioNetHdr{}
+ err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
+ if err != nil {
+ return err
+ }
+ fallthrough
+ case groResultTableInsert:
+ *toWrite = append(*toWrite, i)
+ }
+ }
+ errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable)
+ errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable)
+ return errors.Join(errTCP, errUDP)
+}
+
+// gsoSplit splits packets from in into outBuffs, writing the size of each
+// element into sizes. It returns the number of buffers populated, and/or an
+// error.
+func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) {
+ iphLen := int(hdr.csumStart)
+ srcAddrOffset := ipv6SrcAddrOffset
+ addrLen := 16
+ if !isV6 {
+ in[10], in[11] = 0, 0 // clear ipv4 header checksum
+ srcAddrOffset = ipv4SrcAddrOffset
+ addrLen = 4
+ }
+ transportCsumAt := int(hdr.csumStart + hdr.csumOffset)
+ in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum
+ var firstTCPSeqNum uint32
+ var protocol uint8
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 {
+ protocol = unix.IPPROTO_TCP
+ firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:])
+ } else {
+ protocol = unix.IPPROTO_UDP
+ }
+ nextSegmentDataAt := int(hdr.hdrLen)
+ i := 0
+ for ; nextSegmentDataAt < len(in); i++ {
+ if i == len(outBuffs) {
+ return i - 1, ErrTooManySegments
+ }
+ nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
+ if nextSegmentEnd > len(in) {
+ nextSegmentEnd = len(in)
+ }
+ segmentDataLen := nextSegmentEnd - nextSegmentDataAt
+ totalLen := int(hdr.hdrLen) + segmentDataLen
+ sizes[i] = totalLen
+ out := outBuffs[i][outOffset:]
+
+ copy(out, in[:iphLen])
+ if !isV6 {
+ // For IPv4 we are responsible for incrementing the ID field,
+ // updating the total len field, and recalculating the header
+ // checksum.
+ if i > 0 {
+ id := binary.BigEndian.Uint16(out[4:])
+ id += uint16(i)
+ binary.BigEndian.PutUint16(out[4:], id)
+ }
+ binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
+ ipv4CSum := ^checksum(out[:iphLen], 0)
+ binary.BigEndian.PutUint16(out[10:], ipv4CSum)
+ } else {
+ // For IPv6 we are responsible for updating the payload length field.
+ binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
+ }
+
+ // copy transport header
+ copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
+
+ if protocol == unix.IPPROTO_TCP {
+ // set TCP seq and adjust TCP flags
+ tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
+ binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
+ if nextSegmentEnd != len(in) {
+ // FIN and PSH should only be set on last segment
+ clearFlags := tcpFlagFIN | tcpFlagPSH
+ out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
+ }
+ } else {
+ // set UDP header len
+ binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart))
+ }
+
+ // payload
+ copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
+
+ // transport checksum
+ transportHeaderLen := int(hdr.hdrLen - hdr.csumStart)
+ lenForPseudo := uint16(transportHeaderLen + segmentDataLen)
+ transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo)
+ transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold)
+ binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum)
+
+ nextSegmentDataAt += int(hdr.gsoSize)
+ }
+ return i, nil
+}
+
+func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
+ cSumAt := cSumStart + cSumOffset
+ // The initial value at the checksum offset should be summed with the
+ // checksum we compute. This is typically the pseudo-header checksum.
+ initial := binary.BigEndian.Uint16(in[cSumAt:])
+ in[cSumAt], in[cSumAt+1] = 0, 0
+ binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial)))
+ return nil
+}
diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go
new file mode 100644
index 0000000..ae55c8c
--- /dev/null
+++ b/tun/offload_linux_test.go
@@ -0,0 +1,752 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tun
+
+import (
+ "net/netip"
+ "testing"
+
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+)
+
+const (
+ offset = virtioNetHdrLen
+)
+
+var (
+ ip4PortA = netip.MustParseAddrPort("192.0.2.1:1")
+ ip4PortB = netip.MustParseAddrPort("192.0.2.2:1")
+ ip4PortC = netip.MustParseAddrPort("192.0.2.3:1")
+ ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1")
+ ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1")
+ ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1")
+)
+
+func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 28 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_UDP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ udpH := header.UDP(b[offset+20:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 48 + payloadLen
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_UDP,
+ HopLimit: 64,
+ PayloadLength: uint16(payloadLen + udphLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ udpH := header.UDP(b[offset+40:])
+ udpH.Encode(&header.UDPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ Length: uint16(payloadLen + udphLen),
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen))
+ udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte {
+ return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil)
+}
+
+func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte {
+ totalLen := 40 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv4H := header.IPv4(b[offset:])
+ srcAs4 := srcIPPort.Addr().As4()
+ dstAs4 := dstIPPort.Addr().As4()
+ ipFields := &header.IPv4Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs4[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs4[:]),
+ Protocol: unix.IPPROTO_TCP,
+ TTL: 64,
+ TotalLength: uint16(totalLen),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv4H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+20:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte {
+ totalLen := 60 + segmentSize
+ b := make([]byte, offset+int(totalLen), 65535)
+ ipv6H := header.IPv6(b[offset:])
+ srcAs16 := srcIPPort.Addr().As16()
+ dstAs16 := dstIPPort.Addr().As16()
+ ipFields := &header.IPv6Fields{
+ SrcAddr: tcpip.AddrFromSlice(srcAs16[:]),
+ DstAddr: tcpip.AddrFromSlice(dstAs16[:]),
+ TransportProtocol: unix.IPPROTO_TCP,
+ HopLimit: 64,
+ PayloadLength: uint16(segmentSize + 20),
+ }
+ if ipFn != nil {
+ ipFn(ipFields)
+ }
+ ipv6H.Encode(ipFields)
+ tcpH := header.TCP(b[offset+40:])
+ tcpH.Encode(&header.TCPFields{
+ SrcPort: srcIPPort.Port(),
+ DstPort: dstIPPort.Port(),
+ SeqNum: seq,
+ AckNum: 1,
+ DataOffset: 20,
+ Flags: flags,
+ WindowSize: 3000,
+ })
+ pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize))
+ tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+ return b
+}
+
+func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte {
+ return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil)
+}
+
+func Test_handleVirtioRead(t *testing.T) {
+ tests := []struct {
+ name string
+ hdr virtioNetHdr
+ pktIn []byte
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "tcp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4,
+ gsoSize: 100,
+ hdrLen: 40,
+ csumStart: 20,
+ csumOffset: 16,
+ },
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{140, 140},
+ false,
+ },
+ {
+ "tcp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6,
+ gsoSize: 100,
+ hdrLen: 60,
+ csumStart: 40,
+ csumOffset: 16,
+ },
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1),
+ []int{160, 160},
+ false,
+ },
+ {
+ "udp4",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 28,
+ csumStart: 20,
+ csumOffset: 6,
+ },
+ udp4Packet(ip4PortA, ip4PortB, 200),
+ []int{128, 128},
+ false,
+ },
+ {
+ "udp6",
+ virtioNetHdr{
+ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM,
+ gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4,
+ gsoSize: 100,
+ hdrLen: 48,
+ csumStart: 40,
+ csumOffset: 6,
+ },
+ udp6Packet(ip6PortA, ip6PortB, 200),
+ []int{148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ out := make([][]byte, conn.IdealBatchSize)
+ sizes := make([]int, conn.IdealBatchSize)
+ for i := range out {
+ out[i] = make([]byte, 65535)
+ }
+ tt.hdr.encode(tt.pktIn)
+ n, err := handleVirtioRead(tt.pktIn, out, sizes, offset)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if n != len(tt.wantLens) {
+ t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens))
+ }
+ for i := range tt.wantLens {
+ if tt.wantLens[i] != sizes[i] {
+ t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i])
+ }
+ }
+ })
+ }
+}
+
+func flipTCP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func flipUDP4Checksum(b []byte) []byte {
+ at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6
+ b[at] ^= 0xFF
+ b[at+1] ^= 0xFF
+ return b
+}
+
+func Fuzz_handleGRO(f *testing.F) {
+ pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)
+ pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101)
+ pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201)
+ pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)
+ pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101)
+ pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201)
+ pkt6 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt7 := udp4Packet(ip4PortA, ip4PortB, 100)
+ pkt8 := udp4Packet(ip4PortA, ip4PortC, 100)
+ pkt9 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt10 := udp6Packet(ip6PortA, ip6PortB, 100)
+ pkt11 := udp6Packet(ip6PortA, ip6PortC, 100)
+ f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset)
+ f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) {
+ pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11}
+ toWrite := make([]int, 0, len(pkts))
+ handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite)
+ if len(toWrite) > len(pkts) {
+ t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts))
+ }
+ seenWriteI := make(map[int]bool)
+ for _, writeI := range toWrite {
+ if writeI < 0 || writeI > len(pkts)-1 {
+ t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts))
+ }
+ if seenWriteI[writeI] {
+ t.Errorf("duplicate toWrite value: %d", writeI)
+ }
+ seenWriteI[writeI] = true
+ }
+ })
+}
+
+func Test_handleGRO(t *testing.T) {
+ tests := []struct {
+ name string
+ pktsIn [][]byte
+ canUDPGRO bool
+ wantToWrite []int
+ wantLens []int
+ wantErr bool
+ }{
+ {
+ "multiple protocols and flows",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ true,
+ []int{0, 1, 2, 4, 5, 7, 9},
+ []int{240, 228, 128, 140, 260, 160, 248},
+ false,
+ },
+ {
+ "multiple protocols and flows no UDP GRO",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1
+ tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1
+ tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2
+ udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1
+ },
+ false,
+ []int{0, 1, 2, 4, 5, 7, 8, 9, 10},
+ []int{240, 128, 128, 140, 260, 160, 128, 148, 148},
+ false,
+ },
+ {
+ "PSH interleaved",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1
+ },
+ true,
+ []int{0, 2, 4, 6},
+ []int{240, 240, 260, 260},
+ false,
+ },
+ {
+ "coalesceItemInvalidCSum",
+ [][]byte{
+ flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ },
+ true,
+ []int{0, 1, 3, 4},
+ []int{140, 240, 128, 228},
+ false,
+ },
+ {
+ "out of order",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100
+ },
+ true,
+ []int{0},
+ []int{340},
+ false,
+ },
+ {
+ "unequal TTL",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TTL++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal ToS",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.TOS++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags more fragments set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 1
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "unequal flags DF set",
+ [][]byte{
+ tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1),
+ tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ udp4Packet(ip4PortA, ip4PortB, 100),
+ udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) {
+ fields.Flags = 2
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{140, 140, 128, 128},
+ false,
+ },
+ {
+ "ipv6 unequal hop limit",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.HopLimit++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ {
+ "ipv6 unequal traffic class",
+ [][]byte{
+ tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1),
+ tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ udp6Packet(ip6PortA, ip6PortB, 100),
+ udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) {
+ fields.TrafficClass++
+ }),
+ },
+ true,
+ []int{0, 1, 2, 3},
+ []int{160, 160, 148, 148},
+ false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ toWrite := make([]int, 0, len(tt.pktsIn))
+ err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite)
+ if err != nil {
+ if tt.wantErr {
+ return
+ }
+ t.Fatalf("got err: %v", err)
+ }
+ if len(toWrite) != len(tt.wantToWrite) {
+ t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite))
+ }
+ for i, pktI := range tt.wantToWrite {
+ if tt.wantToWrite[i] != toWrite[i] {
+ t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i])
+ }
+ if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) {
+ t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:]))
+ }
+ }
+ })
+ }
+}
+
+func Test_packetIsGROCandidate(t *testing.T) {
+ tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp4TooShort := tcp4[:39]
+ ip4InvalidHeaderLen := make([]byte, len(tcp4))
+ copy(ip4InvalidHeaderLen, tcp4)
+ ip4InvalidHeaderLen[0] = 0x46
+ ip4InvalidProtocol := make([]byte, len(tcp4))
+ copy(ip4InvalidProtocol, tcp4)
+ ip4InvalidProtocol[9] = unix.IPPROTO_GRE
+
+ tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:]
+ tcp6TooShort := tcp6[:59]
+ ip6InvalidProtocol := make([]byte, len(tcp6))
+ copy(ip6InvalidProtocol, tcp6)
+ ip6InvalidProtocol[6] = unix.IPPROTO_GRE
+
+ udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:]
+ udp4TooShort := udp4[:27]
+
+ udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:]
+ udp6TooShort := udp6[:47]
+
+ tests := []struct {
+ name string
+ b []byte
+ canUDPGRO bool
+ want groCandidateType
+ }{
+ {
+ "tcp4",
+ tcp4,
+ true,
+ tcp4GROCandidate,
+ },
+ {
+ "tcp6",
+ tcp6,
+ true,
+ tcp6GROCandidate,
+ },
+ {
+ "udp4",
+ udp4,
+ true,
+ udp4GROCandidate,
+ },
+ {
+ "udp4 no support",
+ udp4,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp6",
+ udp6,
+ true,
+ udp6GROCandidate,
+ },
+ {
+ "udp6 no support",
+ udp6,
+ false,
+ notGROCandidate,
+ },
+ {
+ "udp4 too short",
+ udp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "udp6 too short",
+ udp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp4 too short",
+ tcp4TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "tcp6 too short",
+ tcp6TooShort,
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP version",
+ []byte{0x00},
+ true,
+ notGROCandidate,
+ },
+ {
+ "invalid IP header len",
+ ip4InvalidHeaderLen,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip4 invalid protocol",
+ ip4InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ {
+ "ip6 invalid protocol",
+ ip6InvalidProtocol,
+ true,
+ notGROCandidate,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want {
+ t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func Test_udpPacketsCanCoalesce(t *testing.T) {
+ udp4a := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4b := udp4Packet(ip4PortA, ip4PortB, 100)
+ udp4c := udp4Packet(ip4PortA, ip4PortB, 110)
+
+ type args struct {
+ pkt []byte
+ iphLen uint8
+ gsoSize uint16
+ item udpGROItem
+ bufs [][]byte
+ bufsOffset int
+ }
+ tests := []struct {
+ name string
+ args args
+ want canCoalesce
+ }{
+ {
+ "coalesceAppend equal gso",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceAppend smaller gso",
+ args{
+ pkt: udp4a[offset : len(udp4a)-90],
+ iphLen: 20,
+ gsoSize: 10,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceAppend,
+ },
+ {
+ "coalesceUnavailable smaller gso previously appended",
+ args{
+ pkt: udp4a[offset:],
+ iphLen: 20,
+ gsoSize: 100,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4c,
+ udp4b,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ {
+ "coalesceUnavailable larger following smaller",
+ args{
+ pkt: udp4c[offset:],
+ iphLen: 20,
+ gsoSize: 110,
+ item: udpGROItem{
+ gsoSize: 100,
+ iphLen: 20,
+ },
+ bufs: [][]byte{
+ udp4a,
+ udp4c,
+ },
+ bufsOffset: offset,
+ },
+ coalesceUnavailable,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want {
+ t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/tun/operateonfd.go b/tun/operateonfd.go
index 31747a2..f1beb6d 100644
--- a/tun/operateonfd.go
+++ b/tun/operateonfd.go
@@ -1,8 +1,8 @@
-// +build !windows
+//go:build darwin || freebsd
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
diff --git a/tun/tun.go b/tun/tun.go
index 5395bdb..0ae53d0 100644
--- a/tun/tun.go
+++ b/tun/tun.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -18,12 +18,36 @@ const (
)
type Device interface {
- File() *os.File // returns the file descriptor of the device
- Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
- Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
- Flush() error // flush all previous writes to the device
- MTU() (int, error) // returns the MTU of the device
- Name() (string, error) // fetches and returns the current name
- Events() chan Event // returns a constant channel of events related to the device
- Close() error // stops the device and closes the event channel
+ // File returns the file descriptor of the device.
+ File() *os.File
+
+ // Read one or more packets from the Device (without any additional headers).
+ // On a successful read it returns the number of packets read, and sets
+ // packet lengths within the sizes slice. len(sizes) must be >= len(bufs).
+ // A nonzero offset can be used to instruct the Device on where to begin
+ // reading into each element of the bufs slice.
+ Read(bufs [][]byte, sizes []int, offset int) (n int, err error)
+
+ // Write one or more packets to the device (without any additional headers).
+ // On a successful write it returns the number of packets written. A nonzero
+ // offset can be used to instruct the Device on where to begin writing from
+ // each packet contained within the bufs slice.
+ Write(bufs [][]byte, offset int) (int, error)
+
+ // MTU returns the MTU of the Device.
+ MTU() (int, error)
+
+ // Name returns the current name of the Device.
+ Name() (string, error)
+
+ // Events returns a channel of type Event, which is fed Device events.
+ Events() <-chan Event
+
+ // Close stops the Device and closes the Event channel.
+ Close() error
+
+ // BatchSize returns the preferred/max number of packets that can be read or
+ // written in a single read/write call. BatchSize must not change over the
+ // lifetime of a Device.
+ BatchSize() int
}
diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go
index 6d2e6dd..c9a6c0b 100644
--- a/tun/tun_darwin.go
+++ b/tun/tun_darwin.go
@@ -1,46 +1,46 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
+ "errors"
"fmt"
- "io/ioutil"
+ "io"
"net"
"os"
+ "sync"
"syscall"
+ "time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
const utunControlName = "com.apple.net.utun_control"
-// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
-const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
-
-// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
-type sockaddrCtl struct {
- scLen uint8
- scFamily uint8
- ssSysaddr uint16
- scID uint32
- scUnit uint32
- scReserved [5]uint32
-}
-
type NativeTun struct {
name string
tunFile *os.File
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
-var sockaddrCtlSize uintptr = 32
+func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
+ for i := 0; i < 20; i++ {
+ iface, err = net.InterfaceByIndex(index)
+ if err != nil && errors.Is(err, unix.ENOMEM) {
+ time.Sleep(time.Duration(i) * time.Second / 3)
+ continue
+ }
+ return iface, err
+ }
+ return nil, err
+}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
@@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR {
goto retry
}
tun.errors <- err
@@ -74,7 +74,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue
}
- iface, err := net.InterfaceByIndex(ifindex)
+ iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
@@ -107,53 +107,33 @@ func CreateTUN(name string, mtu int) (Device, error) {
}
}
- fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
-
+ fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2)
if err != nil {
return nil, err
}
- var ctlInfo = &struct {
- ctlID uint32
- ctlName [96]byte
- }{}
-
- copy(ctlInfo.ctlName[:], []byte(utunControlName))
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(_CTLIOCGINFO),
- uintptr(unsafe.Pointer(ctlInfo)),
- )
-
- if errno != 0 {
- return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
+ ctlInfo := &unix.CtlInfo{}
+ copy(ctlInfo.Name[:], []byte(utunControlName))
+ err = unix.IoctlCtlInfo(fd, ctlInfo)
+ if err != nil {
+ unix.Close(fd)
+ return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
}
- sc := sockaddrCtl{
- scLen: uint8(sockaddrCtlSize),
- scFamily: unix.AF_SYSTEM,
- ssSysaddr: 2,
- scID: ctlInfo.ctlID,
- scUnit: uint32(ifIndex) + 1,
+ sc := &unix.SockaddrCtl{
+ ID: ctlInfo.Id,
+ Unit: uint32(ifIndex) + 1,
}
- scPointer := unsafe.Pointer(&sc)
-
- _, _, errno = unix.RawSyscall(
- unix.SYS_CONNECT,
- uintptr(fd),
- uintptr(scPointer),
- uintptr(sockaddrCtlSize),
- )
-
- if errno != 0 {
- return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
+ err = unix.Connect(fd, sc)
+ if err != nil {
+ unix.Close(fd)
+ return nil, err
}
- err = syscall.SetNonblock(fd, true)
+ err = unix.SetNonblock(fd, true)
if err != nil {
+ unix.Close(fd)
return nil, err
}
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
@@ -161,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "utun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -193,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -213,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
}
func (tun *NativeTun) Name() (string, error) {
- var ifName struct {
- name [16]byte
- }
- ifNameSize := uintptr(16)
-
- var errno syscall.Errno
+ var err error
tun.operateOnFd(func(fd uintptr) {
- _, _, errno = unix.Syscall6(
- unix.SYS_GETSOCKOPT,
- fd,
+ tun.name, err = unix.GetsockoptString(
+ int(fd),
2, /* #define SYSPROTO_CONTROL 2 */
2, /* #define UTUN_OPT_IFNAME 2 */
- uintptr(unsafe.Pointer(&ifName)),
- uintptr(unsafe.Pointer(&ifNameSize)), 0)
+ )
})
- if errno != 0 {
- return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
+ if err != nil {
+ return "", fmt.Errorf("GetSockoptString: %w", err)
}
- tun.name = string(ifName.name[:ifNameSize-1])
return tun.name, nil
}
@@ -241,61 +213,63 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ // TODO: the BSDs look very similar in Read() and Write(). They should be
+ // collapsed, with platform-specific files containing the varying parts of
+ // their implementations.
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err2 error
- err1 := tun.tunFile.Close()
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err2 = unix.Close(tun.routeSocket)
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err2 = unix.Close(tun.routeSocket)
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -303,71 +277,60 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
-
- // open datagram socket
-
- var fd int
-
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return err
}
defer unix.Close(fd)
- // do ioctl call
-
- var ifr [32]byte
- copy(ifr[:], tun.name)
- *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCSIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
-
- if errno != 0 {
- return fmt.Errorf("failed to set MTU on %s", tun.name)
+ var ifr unix.IfreqMTU
+ copy(ifr.Name[:], tun.name)
+ ifr.MTU = int32(n)
+ err = unix.IoctlSetIfreqMTU(fd, &ifr)
+ if err != nil {
+ return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
-
- // open datagram socket
-
- fd, err := unix.Socket(
+ fd, err := socketCloexec(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
-
if err != nil {
return 0, err
}
defer unix.Close(fd)
- // do ioctl call
-
- var ifr [64]byte
- copy(ifr[:], tun.name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFMTU),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
+ ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
+ if err != nil {
+ return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
}
- return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
+ return int(ifr.MTU), nil
+}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
+
+func socketCloexec(family, sotype, proto int) (fd int, err error) {
+ // See go/src/net/sys_cloexec.go for background.
+ syscall.ForkLock.RLock()
+ defer syscall.ForkLock.RUnlock()
+
+ fd, err = unix.Socket(family, sotype, proto)
+ if err == nil {
+ unix.CloseOnExec(fd)
+ }
+ return
}
diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go
index 6cf9313..7c65fd9 100644
--- a/tun/tun_freebsd.go
+++ b/tun/tun_freebsd.go
@@ -1,66 +1,57 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
- "bytes"
"errors"
"fmt"
+ "io"
"net"
"os"
+ "sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
-// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
-// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const (
_TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e
+ _TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f
-)
-// TODO: move into x/sys/unix
-const (
- SIOCGIFINFO_IN6 = 0xc048696c
- SIOCSIFINFO_IN6 = 0xc048696d
- ND6_IFF_AUTO_LINKLOCAL = 0x20
- ND6_IFF_NO_DAD = 0x100
+ _SIOCGIFINFO_IN6 = 0xc048696c
+ _SIOCSIFINFO_IN6 = 0xc048696d
+ _ND6_IFF_AUTO_LINKLOCAL = 0x20
+ _ND6_IFF_NO_DAD = 0x100
)
-// Iface status string max len
-const _IFSTATMAX = 800
-
-const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
+// Iface requests with just the name
+type ifreqName struct {
+ Name [unix.IFNAMSIZ]byte
+ _ [16]byte
+}
-// structure for iface requests with a pointer
-type ifreq_ptr struct {
+// Iface requests with a pointer
+type ifreqPtr struct {
Name [unix.IFNAMSIZ]byte
Data uintptr
- Pad0 [16 - SIZEOF_UINTPTR]byte
+ _ [16 - unsafe.Sizeof(uintptr(0))]byte
}
-// Structure for iface mtu get/set ioctls
-type ifreq_mtu struct {
+// Iface requests with MTU
+type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
- Pad0 [12]byte
-}
-
-// Structure for interface status request ioctl
-type ifstat struct {
- IfsName [unix.IFNAMSIZ]byte
- Ascii [_IFSTATMAX]byte
+ _ [12]byte
}
-// Structures for nd6 flag manipulation
-type in6_ndireq struct {
+// ND6 flag manipulation
+type nd6Req struct {
Name [unix.IFNAMSIZ]byte
Linkmtu uint32
Maxmtu uint32
@@ -82,6 +73,7 @@ type NativeTun struct {
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -97,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
+ if errors.Is(err, syscall.EINTR) {
goto retry
}
tun.errors <- err
@@ -141,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
}
func tunName(fd uintptr) (string, error) {
- //Terrible hack to make up for freebsd not having a TUNGIFNAME
-
- //First, make sure the tun pid matches this proc's pid
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(_TUNSIFPID),
- uintptr(0),
- )
-
- if errno != 0 {
- return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
- }
-
- // Open iface control socket
-
- confd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
- if err != nil {
+ var ifreq ifreqName
+ _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
+ if err != 0 {
return "", err
}
-
- defer unix.Close(confd)
-
- procPid := os.Getpid()
-
- //Try to find interface with matching PID
- for i := 1; ; i++ {
- iface, _ := net.InterfaceByIndex(i)
- if err != nil || iface == nil {
- break
- }
-
- // Structs for getting data in and out of SIOCGIFSTATUS ioctl
- var ifstatus ifstat
- copy(ifstatus.IfsName[:], iface.Name)
-
- // Make the syscall to get the status string
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd),
- uintptr(unix.SIOCGIFSTATUS),
- uintptr(unsafe.Pointer(&ifstatus)),
- )
-
- if errno != 0 {
- continue
- }
-
- nullStr := ifstatus.Ascii[:]
- i := bytes.IndexByte(nullStr, 0)
- if i < 1 {
- continue
- }
- statStr := string(nullStr[:i])
- var pidNum int = 0
-
- // Finally get the owning PID
- // Format string taken from sys/net/if_tun.c
- _, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
- if err != nil {
- continue
- }
-
- if pidNum == procPid {
- return iface.Name, nil
- }
- }
-
- return "", nil
+ return unix.ByteSliceToString(ifreq.Name[:]), nil
}
// Destroy a named system interface
func tunDestroy(name string) error {
- // Open control socket.
- var fd int
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
@@ -233,14 +151,9 @@ func tunDestroy(name string) error {
var ifr [32]byte
copy(ifr[:], name)
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCIFDESTROY),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
if errno != 0 {
- return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error())
+ return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
}
return nil
@@ -257,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, fmt.Errorf("interface %s already exists", name)
}
- tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0)
+ tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
return nil, err
}
@@ -276,103 +189,94 @@ func CreateTUN(name string, mtu int) (Device, error) {
ifheadmode := 1
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- fd,
- uintptr(_TUNSIFHEAD),
- uintptr(unsafe.Pointer(&ifheadmode)),
- )
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
})
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
+ return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
}
- // Open control sockets
- confd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
- if err != nil {
+ // Get out of PTP mode.
+ ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
+ tun.operateOnFd(func(fd uintptr) {
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
+ })
+
+ if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, err
+ return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
}
- defer unix.Close(confd)
- confd6, err := unix.Socket(
- unix.AF_INET6,
- unix.SOCK_DGRAM,
- 0,
- )
+
+ // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
+ // also because there are serious races with attaching and detaching LLv6 addresses
+ // in relation to interface lifetime within the FreeBSD kernel.
+ confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd6)
-
- // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
- // also because there are serious races with attaching and detaching LLv6 addresses
- // in relation to interface lifetime within the FreeBSD kernel.
- var ndireq in6_ndireq
+ var ndireq nd6Req
copy(ndireq.Name[:], assignedName)
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd6),
- uintptr(SIOCGIFINFO_IN6),
- uintptr(unsafe.Pointer(&ndireq)),
- )
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
+ return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
}
- ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
- ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd6),
- uintptr(SIOCSIFINFO_IN6),
- uintptr(unsafe.Pointer(&ndireq)),
- )
+ ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
+ ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
- return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
+ return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
}
- // Rename the interface
- var newnp [unix.IFNAMSIZ]byte
- copy(newnp[:], name)
- var ifr ifreq_ptr
- copy(ifr.Name[:], assignedName)
- ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
- _, _, errno = unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(confd),
- uintptr(unix.SIOCSIFNAME),
- uintptr(unsafe.Pointer(&ifr)),
- )
- if errno != 0 {
- tunFile.Close()
- tunDestroy(assignedName)
- return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
+ if name != "" {
+ confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
+ if err != nil {
+ tunFile.Close()
+ tunDestroy(assignedName)
+ return nil, err
+ }
+ defer unix.Close(confd)
+ var newnp [unix.IFNAMSIZ]byte
+ copy(newnp[:], name)
+ var ifr ifreqPtr
+ copy(ifr.Name[:], assignedName)
+ ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
+ if errno != 0 {
+ tunFile.Close()
+ tunDestroy(assignedName)
+ return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
+ }
}
return CreateTUNFromFile(tunFile, mtu)
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
-
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),
errors: make(chan error, 1),
}
+ var errno syscall.Errno
+ tun.operateOnFd(func(fd uintptr) {
+ _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
+ })
+ if errno != 0 {
+ tun.tunFile.Close()
+ return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
+ }
+
name, err := tun.Name()
if err != nil {
tun.tunFile.Close()
@@ -391,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -425,63 +329,65 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ if len(buf) < 5 {
+ return i, io.ErrShortBuffer
+ }
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err3 error
- err1 := tun.tunFile.Close()
- err2 := tunDestroy(tun.name)
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err3 = unix.Close(tun.routeSocket)
- tun.routeSocket = -1
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2, err3 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ err2 = tunDestroy(tun.name)
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err3 = unix.Close(tun.routeSocket)
+ tun.routeSocket = -1
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -492,70 +398,38 @@ func (tun *NativeTun) Close() error {
}
func (tun *NativeTun) setMTU(n int) error {
- // open datagram socket
-
- var fd int
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return err
}
-
defer unix.Close(fd)
- // do ioctl call
-
- var ifr ifreq_mtu
+ var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCSIFMTU),
- uintptr(unsafe.Pointer(&ifr)),
- )
-
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
- return fmt.Errorf("failed to set MTU on %s", tun.name)
+ return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
}
-
return nil
}
func (tun *NativeTun) MTU() (int, error) {
- // open datagram socket
-
- fd, err := unix.Socket(
- unix.AF_INET,
- unix.SOCK_DGRAM,
- 0,
- )
-
+ fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return 0, err
}
-
defer unix.Close(fd)
- // do ioctl call
- var ifr ifreq_mtu
+ var ifr ifreqMtu
copy(ifr.Name[:], tun.name)
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(fd),
- uintptr(unix.SIOCGIFMTU),
- uintptr(unsafe.Pointer(&ifr)),
- )
+ _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
- return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
+ return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
}
-
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_linux.go b/tun/tun_linux.go
index 61902e9..bd69cb5 100644
--- a/tun/tun_linux.go
+++ b/tun/tun_linux.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,18 +9,16 @@ package tun
*/
import (
- "bytes"
"errors"
"fmt"
- "net"
"os"
"sync"
"syscall"
"time"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -32,14 +30,29 @@ const (
type NativeTun struct {
tunFile *os.File
index int32 // if index
- name string // name of interface
errors chan error // async error handling
events chan Event // device related events
- nopi bool // the device was pased IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
+ batchSize int
+ vnetHdr bool
+ udpGSO bool
+
+ closeOnce sync.Once
+
+ nameOnce sync.Once // guards calling initNameCache, which sets following fields
+ nameCache string // name of interface
+ nameErr error
+
+ readOpMu sync.Mutex // readOpMu guards readBuff
+ readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
+
+ writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable
+ toWrite []int
+ tcpGROTable *tcpGROTable
+ udpGROTable *udpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -51,6 +64,11 @@ func (tun *NativeTun) routineHackListener() {
/* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch.
*/
+ last := 0
+ const (
+ up = 1
+ down = 2
+ )
for {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
@@ -64,14 +82,25 @@ func (tun *NativeTun) routineHackListener() {
}
switch err {
case unix.EINVAL:
- tun.events <- EventUp
+ if last != up {
+ // If the tunnel is up, it reports that write() is
+ // allowed but we provided invalid data.
+ tun.events <- EventUp
+ last = up
+ }
case unix.EIO:
- tun.events <- EventDown
+ if last != down {
+ // If the tunnel is down, it reports that no I/O
+ // is possible, without checking our provided data.
+ tun.events <- EventDown
+ last = down
+ }
default:
return
}
select {
case <-time.After(time.Second):
+ // nothing
case <-tun.statusListenersShutdown:
return
}
@@ -79,13 +108,13 @@ func (tun *NativeTun) routineHackListener() {
}
func createNetlinkSocket() (int, error) {
- sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
+ Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -99,10 +128,10 @@ func (tun *NativeTun) routineNetlinkListener() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
+ tun.netlinkCancel.Close()
}()
for msg := make([]byte, 1<<16); ; {
-
var err error
var msgn int
for {
@@ -111,12 +140,12 @@ func (tun *NativeTun) routineNetlinkListener() {
break
}
if !tun.netlinkCancel.ReadyRead() {
- tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error())
+ tun.errors <- fmt.Errorf("netlink socket closed: %w", err)
return
}
}
if err != nil {
- tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
+ tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err)
return
}
@@ -126,6 +155,7 @@ func (tun *NativeTun) routineNetlinkListener() {
default:
}
+ wasEverUp := false
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
@@ -149,10 +179,16 @@ func (tun *NativeTun) routineNetlinkListener() {
if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- EventUp
+ wasEverUp = true
}
if info.Flags&unix.IFF_RUNNING == 0 {
- tun.events <- EventDown
+ // Don't emit EventDown before we've ever emitted EventUp.
+ // This avoids a startup race with HackListener, which
+ // might detect Up before we have finished reporting Down.
+ if wasEverUp {
+ tun.events <- EventDown
+ }
}
tun.events <- EventMTUUpdate
@@ -164,15 +200,10 @@ func (tun *NativeTun) routineNetlinkListener() {
}
}
-func (tun *NativeTun) isUp() (bool, error) {
- inter, err := net.InterfaceByName(tun.name)
- return inter.Flags&net.FlagUp != 0, err
-}
-
func getIFIndex(name string) (int32, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
if err != nil {
@@ -198,13 +229,17 @@ func getIFIndex(name string) (int32, error) {
}
func (tun *NativeTun) setMTU(n int) error {
+ name, err := tun.Name()
+ if err != nil {
+ return err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -212,9 +247,8 @@ func (tun *NativeTun) setMTU(n int) error {
defer unix.Close(fd)
// do ioctl call
-
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
@@ -224,20 +258,24 @@ func (tun *NativeTun) setMTU(n int) error {
)
if errno != 0 {
- return errors.New("failed to set MTU of TUN device")
+ return fmt.Errorf("failed to set MTU of TUN device: %w", errno)
}
return nil
}
func (tun *NativeTun) MTU() (int, error) {
+ name, err := tun.Name()
+ if err != nil {
+ return 0, err
+ }
+
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -247,7 +285,7 @@ func (tun *NativeTun) MTU() (int, error) {
// do ioctl call
var ifr [ifReqSize]byte
- copy(ifr[:], tun.name)
+ copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
@@ -255,13 +293,22 @@ func (tun *NativeTun) MTU() (int, error) {
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
- return 0, errors.New("failed to get MTU of TUN device: " + errno.Error())
+ return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno)
}
return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
}
func (tun *NativeTun) Name() (string, error) {
+ tun.nameOnce.Do(tun.initNameCache)
+ return tun.nameCache, tun.nameErr
+}
+
+func (tun *NativeTun) initNameCache() {
+ tun.nameCache, tun.nameErr = tun.nameSlow()
+}
+
+func (tun *NativeTun) nameSlow() (string, error) {
sysconn, err := tun.tunFile.SyscallConn()
if err != nil {
return "", err
@@ -277,147 +324,287 @@ func (tun *NativeTun) Name() (string, error) {
)
})
if err != nil {
- return "", errors.New("failed to get name of TUN device: " + err.Error())
+ return "", fmt.Errorf("failed to get name of TUN device: %w", err)
}
if errno != 0 {
- return "", errors.New("failed to get name of TUN device: " + errno.Error())
- }
- nullStr := ifr[:]
- i := bytes.IndexByte(nullStr, 0)
- if i != -1 {
- nullStr = nullStr[:i]
+ return "", fmt.Errorf("failed to get name of TUN device: %w", errno)
}
- tun.name = string(nullStr)
- return tun.name, nil
+ return unix.ByteSliceToString(ifr[:]), nil
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- if tun.nopi {
- buff = buff[offset:]
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.writeOpMu.Lock()
+ defer func() {
+ tun.tcpGROTable.reset()
+ tun.udpGROTable.reset()
+ tun.writeOpMu.Unlock()
+ }()
+ var (
+ errs error
+ total int
+ )
+ tun.toWrite = tun.toWrite[:0]
+ if tun.vnetHdr {
+ err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite)
+ if err != nil {
+ return 0, err
+ }
+ offset -= virtioNetHdrLen
} else {
- // reserve space for header
+ for i := range bufs {
+ tun.toWrite = append(tun.toWrite, i)
+ }
+ }
+ for _, bufsI := range tun.toWrite {
+ n, err := tun.tunFile.Write(bufs[bufsI][offset:])
+ if errors.Is(err, syscall.EBADFD) {
+ return total, os.ErrClosed
+ }
+ if err != nil {
+ errs = errors.Join(errs, err)
+ } else {
+ total += n
+ }
+ }
+ return total, errs
+}
- buff = buff[offset-4:]
+// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
+// each buffer. It mutates sizes to reflect the size of each element of bufs,
+// and returns the number of packets read.
+func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
+ var hdr virtioNetHdr
+ err := hdr.decode(in)
+ if err != nil {
+ return 0, err
+ }
+ in = in[virtioNetHdrLen:]
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
+ if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
+ // This means CHECKSUM_PARTIAL in skb context. We are responsible
+ // for computing the checksum starting at hdr.csumStart and placing
+ // at hdr.csumOffset.
+ err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if len(in) > len(bufs[0][offset:]) {
+ return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
+ }
+ n := copy(bufs[0][offset:], in)
+ sizes[0] = n
+ return 1, nil
+ }
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
+ }
- // add packet information header
+ ipVersion := in[0] >> 4
+ switch ipVersion {
+ case 4:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ case 6:
+ if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
+ }
+ default:
+ return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
+ }
- buff[0] = 0x00
- buff[1] = 0x00
+ // Don't trust hdr.hdrLen from the kernel as it can be equal to the length
+ // of the entire first packet when the kernel is handling it as part of a
+ // FORWARD path. Instead, parse the transport header length and add it onto
+ // csumStart, which is synonymous for IP header length.
+ if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 {
+ hdr.hdrLen = hdr.csumStart + 8
+ } else {
+ if len(in) <= int(hdr.csumStart+12) {
+ return 0, errors.New("packet is too short")
+ }
- if buff[4]>>4 == ipv6.Version {
- buff[2] = 0x86
- buff[3] = 0xdd
- } else {
- buff[2] = 0x08
- buff[3] = 0x00
+ tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
+ if tcpHLen < 20 || tcpHLen > 60 {
+ // A TCP header must be between 20 and 60 bytes in length.
+ return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
+ hdr.hdrLen = hdr.csumStart + tcpHLen
}
- // write
+ if len(in) < int(hdr.hdrLen) {
+ return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
+ }
- return tun.tunFile.Write(buff)
-}
+ if hdr.hdrLen < hdr.csumStart {
+ return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
+ }
+ cSumAt := int(hdr.csumStart + hdr.csumOffset)
+ if cSumAt+1 >= len(in) {
+ return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
+ }
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6)
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.readOpMu.Lock()
+ defer tun.readOpMu.Unlock()
select {
case err := <-tun.errors:
return 0, err
default:
- if tun.nopi {
- return tun.tunFile.Read(buff[offset:])
+ readInto := bufs[0][offset:]
+ if tun.vnetHdr {
+ readInto = tun.readBuff[:]
+ }
+ n, err := tun.tunFile.Read(readInto)
+ if errors.Is(err, syscall.EBADFD) {
+ err = os.ErrClosed
+ }
+ if err != nil {
+ return 0, err
+ }
+ if tun.vnetHdr {
+ return handleVirtioRead(readInto[:n], bufs, sizes, offset)
} else {
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
- if n < 4 {
- return 0, err
- }
- return n - 4, err
+ sizes[0] = n
+ return 1, nil
}
}
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
- var err1 error
- if tun.statusListenersShutdown != nil {
- close(tun.statusListenersShutdown)
- if tun.netlinkCancel != nil {
- err1 = tun.netlinkCancel.Cancel()
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ if tun.statusListenersShutdown != nil {
+ close(tun.statusListenersShutdown)
+ if tun.netlinkCancel != nil {
+ err1 = tun.netlinkCancel.Cancel()
+ }
+ } else if tun.events != nil {
+ close(tun.events)
}
- } else if tun.events != nil {
- close(tun.events)
- }
- err2 := tun.tunFile.Close()
-
+ err2 = tun.tunFile.Close()
+ })
if err1 != nil {
return err1
}
return err2
}
+func (tun *NativeTun) BatchSize() int {
+ return tun.batchSize
+}
+
+const (
+ // TODO: support TSO with ECN bits
+ tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
+ tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6
+)
+
+func (tun *NativeTun) initFromFlags(name string) error {
+ sc, err := tun.tunFile.SyscallConn()
+ if err != nil {
+ return err
+ }
+ if e := sc.Control(func(fd uintptr) {
+ var (
+ ifr *unix.Ifreq
+ )
+ ifr, err = unix.NewIfreq(name)
+ if err != nil {
+ return
+ }
+ err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
+ if err != nil {
+ return
+ }
+ got := ifr.Uint16()
+ if got&unix.IFF_VNET_HDR != 0 {
+ // tunTCPOffloads were added in Linux v2.6. We require their support
+ // if IFF_VNET_HDR is set.
+ err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads)
+ if err != nil {
+ return
+ }
+ tun.vnetHdr = true
+ tun.batchSize = conn.IdealBatchSize
+ // tunUDPOffloads were added in Linux v6.2. We do not return an
+ // error if they are unsupported at runtime.
+ tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil
+ } else {
+ tun.batchSize = 1
+ }
+ }); e != nil {
+ return e
+ }
+ return err
+}
+
+// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
- nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
+ nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
+ if os.IsNotExist(err) {
+ return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
+ }
return nil, err
}
- var ifr [ifReqSize]byte
- var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
- nameBytes := []byte(name)
- if len(nameBytes) >= unix.IFNAMSIZ {
- return nil, errors.New("interface name too long")
+ ifr, err := unix.NewIfreq(name)
+ if err != nil {
+ return nil, err
}
- copy(ifr[:], nameBytes)
- *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
-
- _, _, errno := unix.Syscall(
- unix.SYS_IOCTL,
- uintptr(nfd),
- uintptr(unix.TUNSETIFF),
- uintptr(unsafe.Pointer(&ifr[0])),
- )
- if errno != 0 {
- return nil, errno
+ // IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
+ // where a null write will return EINVAL indicating the TUN is up.
+ ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
+ err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
+ if err != nil {
+ return nil, err
}
- err = unix.SetNonblock(nfd, true)
- // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
-
- fd := os.NewFile(uintptr(nfd), cloneDevicePath)
+ err = unix.SetNonblock(nfd, true)
if err != nil {
+ unix.Close(nfd)
return nil, err
}
+ // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line.
+
+ fd := os.NewFile(uintptr(nfd), cloneDevicePath)
return CreateTUNFromFile(fd, mtu)
}
+// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
- nopi: false,
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
- var err error
- _, err = tun.Name()
+ name, err := tun.Name()
if err != nil {
return nil, err
}
- // start event listener
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, err
+ }
- tun.index, err = getIFIndex(tun.name)
+ // start event listener
+ tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
}
@@ -445,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
+// CreateUnmonitoredTUNFromFD creates a Device from the provided file
+// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -452,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
- tunFile: file,
- events: make(chan Event, 5),
- errors: make(chan error, 5),
- nopi: true,
+ tunFile: file,
+ events: make(chan Event, 5),
+ errors: make(chan error, 5),
+ tcpGROTable: newTCPGROTable(),
+ udpGROTable: newUDPGROTable(),
+ toWrite: make([]int, 0, conn.IdealBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
- return tun, name, nil
+ err = tun.initFromFlags(name)
+ if err != nil {
+ return nil, "", err
+ }
+ return tun, name, err
}
diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go
index 44cedaa..ae571b9 100644
--- a/tun/tun_openbsd.go
+++ b/tun/tun_openbsd.go
@@ -1,19 +1,20 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
+ "errors"
"fmt"
- "io/ioutil"
+ "io"
"net"
"os"
+ "sync"
"syscall"
"unsafe"
- "golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
)
@@ -32,6 +33,7 @@ type NativeTun struct {
events chan Event
errors chan error
routeSocket int
+ closeOnce sync.Once
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
@@ -99,16 +101,6 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
}
}
-func errorIsEBUSY(err error) bool {
- if pe, ok := err.(*os.PathError); ok {
- err = pe.Err
- }
- if errno, ok := err.(syscall.Errno); ok && errno == syscall.EBUSY {
- return true
- }
- return false
-}
-
func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1
if name != "tun" {
@@ -122,11 +114,11 @@ func CreateTUN(name string, mtu int) (Device, error) {
var err error
if ifIndex != -1 {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
} else {
- for ifIndex = 0; ifIndex < 256; ifIndex += 1 {
- tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
- if err == nil || !errorIsEBUSY(err) {
+ for ifIndex = 0; ifIndex < 256; ifIndex++ {
+ tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0)
+ if err == nil || !errors.Is(err, syscall.EBUSY) {
break
}
}
@@ -141,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
- ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400)
+ os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400)
}
}
@@ -173,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
- tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
+ tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC)
if err != nil {
tun.tunFile.Close()
return nil, err
@@ -208,62 +200,61 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
- buff := buff[offset-4:]
- n, err := tun.tunFile.Read(buff[:])
+ buf := bufs[0][offset-4:]
+ n, err := tun.tunFile.Read(buf[:])
if n < 4 {
return 0, err
}
- return n - 4, err
+ sizes[0] = n - 4
+ return 1, err
}
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
-
- // reserve space for header
-
- buff = buff[offset-4:]
-
- // add packet information header
-
- buff[0] = 0x00
- buff[1] = 0x00
- buff[2] = 0x00
-
- if buff[4]>>4 == ipv6.Version {
- buff[3] = unix.AF_INET6
- } else {
- buff[3] = unix.AF_INET
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ if offset < 4 {
+ return 0, io.ErrShortBuffer
}
-
- // write
-
- return tun.tunFile.Write(buff)
-}
-
-func (tun *NativeTun) Flush() error {
- // TODO: can flushing be implemented by buffering and using sendmmsg?
- return nil
+ for i, buf := range bufs {
+ buf = buf[offset-4:]
+ buf[0] = 0x00
+ buf[1] = 0x00
+ buf[2] = 0x00
+ switch buf[4] >> 4 {
+ case 4:
+ buf[3] = unix.AF_INET
+ case 6:
+ buf[3] = unix.AF_INET6
+ default:
+ return i, unix.EAFNOSUPPORT
+ }
+ if _, err := tun.tunFile.Write(buf); err != nil {
+ return i, err
+ }
+ }
+ return len(bufs), nil
}
func (tun *NativeTun) Close() error {
- var err2 error
- err1 := tun.tunFile.Close()
- if tun.routeSocket != -1 {
- unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
- err2 = unix.Close(tun.routeSocket)
- tun.routeSocket = -1
- } else if tun.events != nil {
- close(tun.events)
- }
+ var err1, err2 error
+ tun.closeOnce.Do(func() {
+ err1 = tun.tunFile.Close()
+ if tun.routeSocket != -1 {
+ unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
+ err2 = unix.Close(tun.routeSocket)
+ tun.routeSocket = -1
+ } else if tun.events != nil {
+ close(tun.events)
+ }
+ })
if err1 != nil {
return err1
}
@@ -277,10 +268,9 @@ func (tun *NativeTun) setMTU(n int) error {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return err
}
@@ -312,10 +302,9 @@ func (tun *NativeTun) MTU() (int, error) {
fd, err := unix.Socket(
unix.AF_INET,
- unix.SOCK_DGRAM,
+ unix.SOCK_DGRAM|unix.SOCK_CLOEXEC,
0,
)
-
if err != nil {
return 0, err
}
@@ -338,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) {
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}
+
+func (tun *NativeTun) BatchSize() int {
+ return 1
+}
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index daad4aa..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,13 +9,13 @@ import (
"errors"
"fmt"
"os"
+ "sync"
"sync/atomic"
"time"
- "unsafe"
+ _ "unsafe"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wintun"
)
const (
@@ -25,24 +25,31 @@ const (
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
- wt *wintun.Interface
+ wt *wintun.Adapter
+ name string
handle windows.Handle
- close bool
- rings wintun.RingDescriptor
+ rate rateJuggler
+ session wintun.Session
+ readWait windows.Handle
events chan Event
- errors chan error
+ running sync.WaitGroup
+ closeOnce sync.Once
+ close atomic.Bool
forcedMTU int
- rate rateJuggler
+ outSizes []int
}
-const WintunPool = wintun.Pool("WireGuard")
+var (
+ WintunTunnelType = "WireGuard"
+ WintunStaticRequestedGUID *windows.GUID
+)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -50,34 +57,18 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
- return CreateTUNWithRequestedGUID(ifname, nil, mtu)
+ return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
- var err error
- var wt *wintun.Interface
-
- // Does an interface with this name already exist?
- wt, err = WintunPool.GetInterface(ifname)
- if err == nil {
- // If so, we delete it, in case it has weird residual configuration.
- _, err = wt.DeleteInterface()
- if err != nil {
- return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
- }
- }
- wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
+ wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
- return nil, fmt.Errorf("Error creating interface: %v", err)
+ return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420
@@ -87,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{
wt: wt,
+ name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
- errors: make(chan error, 1),
forcedMTU: forcedMTU,
}
- err = tun.rings.Init()
- if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error creating events: %v", err)
- }
-
- tun.handle, err = tun.wt.Register(&tun.rings)
+ tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error registering rings: %v", err)
+ tun.wt.Close()
+ close(tun.events)
+ return nil, fmt.Errorf("Error starting session: %w", err)
}
+ tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
- return tun.wt.Name()
+ return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
- tun.close = true
- if tun.rings.Send.TailMoved != 0 {
- windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
- }
- if tun.handle != windows.InvalidHandle {
- windows.CloseHandle(tun.handle)
- }
- tun.rings.Close()
var err error
- if tun.wt != nil {
- _, err = tun.wt.DeleteInterface()
- }
- close(tun.events)
+ tun.closeOnce.Do(func() {
+ tun.close.Store(true)
+ windows.SetEvent(tun.readWait)
+ tun.running.Wait()
+ tun.session.End()
+ if tun.wt != nil {
+ tun.wt.Close()
+ }
+ close(tun.events)
+ })
return err
}
@@ -142,129 +127,115 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
+ update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
+ if update {
+ tun.events <- EventMTUUpdate
+ }
+}
+
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
retry:
- select {
- case err := <-tun.errors:
- return 0, err
- default:
- }
- if tun.close {
- return 0, os.ErrClosed
- }
-
- buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
-
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
- var buffTail uint32
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
- if buffHead != buffTail {
- break
- }
- if tun.close {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
- windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
- goto retry
+ packet, err := tun.session.ReceivePacket()
+ switch err {
+ case nil:
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
+ tun.session.ReleaseReceivePacket(packet)
+ tun.rate.update(uint64(n))
+ return 1, nil
+ case windows.ERROR_NO_MORE_ITEMS:
+ if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
+ windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
+ goto retry
+ }
+ procyield(1)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return 0, os.ErrClosed
+ case windows.ERROR_INVALID_DATA:
+ return 0, errors.New("Send ring corrupt")
}
- procyield(1)
- }
- if buffTail >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
- if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
- return 0, errors.New("incomplete packet header in send ring")
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
- if packet.Size > wintun.PacketSizeMax {
- return 0, errors.New("packet too big in send ring")
- }
-
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
- if alignedPacketSize > buffContent {
- return 0, errors.New("incomplete packet in send ring")
+ return 0, fmt.Errorf("Read failed: %w", err)
}
-
- copy(buff[offset:], packet.Data[:packet.Size])
- buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
- atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
- tun.rate.update(uint64(packet.Size))
- return int(packet.Size), nil
-}
-
-func (tun *NativeTun) Flush() error {
- return nil
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- if tun.close {
- return 0, os.ErrClosed
- }
-
- packetSize := uint32(len(buff) - offset)
- tun.rate.update(uint64(packetSize))
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
-
- buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
- if buffTail >= wintun.PacketCapacity {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
- if alignedPacketSize > buffSpace {
- return 0, nil // Dropping when ring is full.
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
- packet.Size = packetSize
- copy(packet.Data[:packetSize], buff[offset:])
- atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
- if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
- windows.SetEvent(tun.rings.Receive.TailMoved)
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
+
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return int(packetSize), nil
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
+ return 0
+ }
return tun.wt.LUID()
}
-// Version returns the version of the Wintun driver and NDIS system currently loaded.
-func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
- return tun.wt.Version()
+// RunningVersion returns the running version of the Wintun driver.
+func (tun *NativeTun) RunningVersion() (version uint32, err error) {
+ return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
new file mode 100644
index 0000000..d07e860
--- /dev/null
+++ b/tun/tuntest/tuntest.go
@@ -0,0 +1,155 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package tuntest
+
+import (
+ "encoding/binary"
+ "io"
+ "net/netip"
+ "os"
+
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+func Ping(dst, src netip.Addr) []byte {
+ localPort := uint16(1337)
+ seq := uint16(0)
+
+ payload := make([]byte, 4)
+ binary.BigEndian.PutUint16(payload[0:], localPort)
+ binary.BigEndian.PutUint16(payload[2:], seq)
+
+ return genICMPv4(payload, dst, src)
+}
+
+// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
+func checksum(buf []byte, initial uint16) uint16 {
+ v := uint32(initial)
+ for i := 0; i < len(buf)-1; i += 2 {
+ v += uint32(binary.BigEndian.Uint16(buf[i:]))
+ }
+ if len(buf)%2 == 1 {
+ v += uint32(buf[len(buf)-1]) << 8
+ }
+ for v > 0xffff {
+ v = (v >> 16) + (v & 0xffff)
+ }
+ return ^uint16(v)
+}
+
+func genICMPv4(payload []byte, dst, src netip.Addr) []byte {
+ const (
+ icmpv4ProtocolNumber = 1
+ icmpv4Echo = 8
+ icmpv4ChecksumOffset = 2
+ icmpv4Size = 8
+ ipv4Size = 20
+ ipv4TotalLenOffset = 2
+ ipv4ChecksumOffset = 10
+ ttl = 65
+ headerSize = ipv4Size + icmpv4Size
+ )
+
+ pkt := make([]byte, headerSize+len(payload))
+
+ ip := pkt[0:ipv4Size]
+ icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size]
+
+ // https://tools.ietf.org/html/rfc792
+ icmpv4[0] = icmpv4Echo // type
+ icmpv4[1] = 0 // code
+ chksum := ^checksum(icmpv4, checksum(payload, 0))
+ binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+
+ // https://tools.ietf.org/html/rfc760 section 3.1
+ length := uint16(len(pkt))
+ ip[0] = (4 << 4) | (ipv4Size / 4)
+ binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
+ ip[8] = ttl
+ ip[9] = icmpv4ProtocolNumber
+ copy(ip[12:], src.AsSlice())
+ copy(ip[16:], dst.AsSlice())
+ chksum = ^checksum(ip[:], 0)
+ binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+
+ copy(pkt[headerSize:], payload)
+ return pkt
+}
+
+type ChannelTUN struct {
+ Inbound chan []byte // incoming packets, closed on TUN close
+ Outbound chan []byte // outbound packets, blocks forever on TUN close
+
+ closed chan struct{}
+ events chan tun.Event
+ tun chTun
+}
+
+func NewChannelTUN() *ChannelTUN {
+ c := &ChannelTUN{
+ Inbound: make(chan []byte),
+ Outbound: make(chan []byte),
+ closed: make(chan struct{}),
+ events: make(chan tun.Event, 1),
+ }
+ c.tun.c = c
+ c.events <- tun.EventUp
+ return c
+}
+
+func (c *ChannelTUN) TUN() tun.Device {
+ return &c.tun
+}
+
+type chTun struct {
+ c *ChannelTUN
+}
+
+func (t *chTun) File() *os.File { return nil }
+
+func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) {
+ select {
+ case <-t.c.closed:
+ return 0, os.ErrClosed
+ case msg := <-t.c.Outbound:
+ n := copy(packets[0][offset:], msg)
+ sizes[0] = n
+ return 1, nil
+ }
+}
+
+// Write is called by the wireguard device to deliver a packet for routing.
+func (t *chTun) Write(packets [][]byte, offset int) (int, error) {
+ if offset == -1 {
+ close(t.c.closed)
+ close(t.c.events)
+ return 0, io.EOF
+ }
+ for i, data := range packets {
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return i, os.ErrClosed
+ case t.c.Inbound <- msg:
+ }
+ }
+ return len(packets), nil
+}
+
+func (t *chTun) BatchSize() int {
+ return 1
+}
+
+const DefaultMTU = 1420
+
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() <-chan tun.Event { return t.c.events }
+func (t *chTun) Close() error {
+ t.Write(nil, -1)
+ return nil
+}
diff --git a/tun/wintun/iphlpapi/conversion_windows.go b/tun/wintun/iphlpapi/conversion_windows.go
deleted file mode 100644
index a19e961..0000000
--- a/tun/wintun/iphlpapi/conversion_windows.go
+++ /dev/null
@@ -1,25 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package iphlpapi
-
-import "golang.org/x/sys/windows"
-
-//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
-//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
-
-func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
- var luid uint64
- var guid windows.GUID
- err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
- if err != nil {
- return nil, err
- }
- err = convertInterfaceLUIDToGUID(&luid, &guid)
- if err != nil {
- return nil, err
- }
- return &guid, nil
-}
diff --git a/tun/wintun/iphlpapi/mksyscall.go b/tun/wintun/iphlpapi/mksyscall.go
deleted file mode 100644
index fc7dba4..0000000
--- a/tun/wintun/iphlpapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package iphlpapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go
diff --git a/tun/wintun/iphlpapi/zsyscall_windows.go b/tun/wintun/iphlpapi/zsyscall_windows.go
deleted file mode 100644
index dc14294..0000000
--- a/tun/wintun/iphlpapi/zsyscall_windows.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package iphlpapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
-
- procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
- procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
-)
-
-func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
-
-func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
- r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go
deleted file mode 100644
index f4316fe..0000000
--- a/tun/wintun/namespace_windows.go
+++ /dev/null
@@ -1,98 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "encoding/hex"
- "errors"
- "fmt"
- "sync"
- "unsafe"
-
- "golang.org/x/crypto/blake2s"
- "golang.org/x/sys/windows"
- "golang.org/x/text/unicode/norm"
-
- "golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
-)
-
-var (
- wintunObjectSecurityAttributes *windows.SecurityAttributes
- hasInitializedNamespace bool
- initializingNamespace sync.Mutex
-)
-
-func initializeNamespace() error {
- initializingNamespace.Lock()
- defer initializingNamespace.Unlock()
- if hasInitializedNamespace {
- return nil
- }
- sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
- if err != nil {
- return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
- }
- wintunObjectSecurityAttributes = &windows.SecurityAttributes{
- Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
- SecurityDescriptor: sd,
- }
- sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
- if err != nil {
- return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
- }
-
- boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
- if err != nil {
- return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
- }
- err = boundary.AddSid(sid)
- if err != nil {
- return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
- }
- for {
- _, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
- if err == windows.ERROR_ALREADY_EXISTS {
- _, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
- if err == windows.ERROR_PATH_NOT_FOUND {
- continue
- }
- }
- if err != nil {
- return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
- }
- break
- }
- hasInitializedNamespace = true
- return nil
-}
-
-func (pool Pool) takeNameMutex() (windows.Handle, error) {
- err := initializeNamespace()
- if err != nil {
- return 0, err
- }
-
- const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
- b2, _ := blake2s.New256(nil)
- b2.Write([]byte(mutexLabel))
- b2.Write(norm.NFC.Bytes([]byte(string(pool))))
- mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
- mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
- if err != nil {
- err = fmt.Errorf("Error creating name mutex: %v", err)
- return 0, err
- }
- event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
- if err != nil {
- windows.CloseHandle(mutex)
- return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
- }
- if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
- windows.CloseHandle(mutex)
- return 0, errors.New("Error with event trigger of name mutex")
- }
- return mutex, nil
-}
diff --git a/tun/wintun/namespaceapi/mksyscall.go b/tun/wintun/namespaceapi/mksyscall.go
deleted file mode 100644
index 93d43b0..0000000
--- a/tun/wintun/namespaceapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package namespaceapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go
diff --git a/tun/wintun/namespaceapi/namespaceapi_windows.go b/tun/wintun/namespaceapi/namespaceapi_windows.go
deleted file mode 100644
index a3a6274..0000000
--- a/tun/wintun/namespaceapi/namespaceapi_windows.go
+++ /dev/null
@@ -1,83 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package namespaceapi
-
-import "golang.org/x/sys/windows"
-
-//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
-//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
-//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
-//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
-//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
-//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
-
-// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
-type BoundaryDescriptor windows.Handle
-
-// CreateBoundaryDescriptor creates a boundary descriptor.
-func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
- name16, err := windows.UTF16PtrFromString(name)
- if err != nil {
- return 0, err
- }
- handle, err := createBoundaryDescriptor(name16, 0)
- if err != nil {
- return 0, err
- }
- return BoundaryDescriptor(handle), nil
-}
-
-// Delete deletes the specified boundary descriptor.
-func (bd BoundaryDescriptor) Delete() {
- deleteBoundaryDescriptor(windows.Handle(bd))
-}
-
-// AddSid adds a security identifier (SID) to the specified boundary descriptor.
-func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
- return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
-}
-
-// PrivateNamespace represents a private namespace.
-type PrivateNamespace windows.Handle
-
-// CreatePrivateNamespace creates a private namespace.
-func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
- aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
- if err != nil {
- return 0, err
- }
- handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
- if err != nil {
- return 0, err
- }
- return PrivateNamespace(handle), nil
-}
-
-// OpenPrivateNamespace opens a private namespace.
-func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
- aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
- if err != nil {
- return 0, err
- }
- handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
- if err != nil {
- return 0, err
- }
- return PrivateNamespace(handle), nil
-}
-
-// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
-type ClosePrivateNamespaceFlags uint32
-
-const (
- // PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
- PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
-)
-
-// Close closes an open namespace handle.
-func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
- return closePrivateNamespace(windows.Handle(pns), uint32(flags))
-}
diff --git a/tun/wintun/namespaceapi/zsyscall_windows.go b/tun/wintun/namespaceapi/zsyscall_windows.go
deleted file mode 100644
index 508c223..0000000
--- a/tun/wintun/namespaceapi/zsyscall_windows.go
+++ /dev/null
@@ -1,116 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package namespaceapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
-
- procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
- procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
- procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
- procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
- procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
- procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
-)
-
-func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
- syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
- return
-}
-
-func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
- r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
- handle = windows.Handle(r0)
- if handle == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
- r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
diff --git a/tun/wintun/nci/mksyscall.go b/tun/wintun/nci/mksyscall.go
deleted file mode 100644
index 019da93..0000000
--- a/tun/wintun/nci/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package nci
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go
diff --git a/tun/wintun/nci/nci_windows.go b/tun/wintun/nci/nci_windows.go
deleted file mode 100644
index 9dc6699..0000000
--- a/tun/wintun/nci/nci_windows.go
+++ /dev/null
@@ -1,28 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package nci
-
-import "golang.org/x/sys/windows"
-
-//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName
-//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName
-
-func SetConnectionName(guid *windows.GUID, newName string) error {
- newName16, err := windows.UTF16PtrFromString(newName)
- if err != nil {
- return err
- }
- return nciSetConnectionName(guid, newName16)
-}
-
-func ConnectionName(guid *windows.GUID) (string, error) {
- var name [0x400]uint16
- err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil)
- if err != nil {
- return "", err
- }
- return windows.UTF16ToString(name[:]), nil
-}
diff --git a/tun/wintun/nci/zsyscall_windows.go b/tun/wintun/nci/zsyscall_windows.go
deleted file mode 100644
index 2a7b79e..0000000
--- a/tun/wintun/nci/zsyscall_windows.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package nci
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modnci = windows.NewLazySystemDLL("nci.dll")
-
- procNciSetConnectionName = modnci.NewProc("NciSetConnectionName")
- procNciGetConnectionName = modnci.NewProc("NciGetConnectionName")
-)
-
-func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) {
- r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
-
-func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) {
- r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0)
- if r0 != 0 {
- ret = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/registry/mksyscall.go b/tun/wintun/registry/mksyscall.go
deleted file mode 100644
index 6ad82d2..0000000
--- a/tun/wintun/registry/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go
diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go
deleted file mode 100644
index 12a0336..0000000
--- a/tun/wintun/registry/registry_windows.go
+++ /dev/null
@@ -1,272 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-import (
- "errors"
- "fmt"
- "runtime"
- "strings"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-)
-
-const (
- // REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted.
- REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001
-
- // REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information.
- REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002
-
- // REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value.
- REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004
-
- // REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key.
- REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008
-
- // REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later.
- REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000
-)
-
-//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
-
-func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- deadline := time.Now().Add(timeout)
- pathSpl := strings.Split(path, "\\")
- for i := 0; ; i++ {
- keyName := pathSpl[i]
- isLast := i+1 == len(pathSpl)
-
- event, err := windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return 0, fmt.Errorf("Error creating event: %v", err)
- }
- defer windows.CloseHandle(event)
-
- var key registry.Key
- for {
- err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true)
- if err != nil {
- return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err)
- }
-
- var accessFlags uint32
- if isLast {
- accessFlags = access
- } else {
- accessFlags = registry.NOTIFY
- }
- key, err = registry.OpenKey(k, keyName, accessFlags)
- if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
- timeout := time.Until(deadline) / time.Millisecond
- if timeout < 0 {
- timeout = 0
- }
- s, err := windows.WaitForSingleObject(event, uint32(timeout))
- if err != nil {
- return 0, fmt.Errorf("Unable to wait on registry key: %v", err)
- }
- if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
- return 0, errors.New("Timeout waiting for registry key")
- }
- } else if err != nil {
- return 0, fmt.Errorf("Error opening registry key %v: %v", path, err)
- } else {
- if isLast {
- return key, nil
- }
- defer key.Close()
- break
- }
- }
-
- k = key
- }
-}
-
-func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
- key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout)
- if err != nil {
- return err
- }
- key.Close()
- return nil
-}
-
-//
-// getValue is more or less the same as windows/registry's getValue.
-//
-func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
- var name16 *uint16
- name16, err = windows.UTF16PtrFromString(name)
- if err != nil {
- return
- }
- n := uint32(len(buf))
- for {
- err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
- if err == nil {
- value = buf[:n]
- return
- }
- if err != windows.ERROR_MORE_DATA {
- return
- }
- if n <= uint32(len(buf)) {
- return
- }
- buf = make([]byte, n)
- }
-}
-
-//
-// getValueRetry function reads any value from registry. It waits for
-// the registry value to become available or returns error on timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
- runtime.LockOSThread()
- defer runtime.UnlockOSThread()
-
- event, err := windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return nil, 0, fmt.Errorf("Error creating event: %v", err)
- }
- defer windows.CloseHandle(event)
-
- deadline := time.Now().Add(timeout)
- for {
- err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
- if err != nil {
- return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
- }
-
- buf, valueType, err := getValue(key, name, buf)
- if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
- timeout := time.Until(deadline) / time.Millisecond
- if timeout < 0 {
- timeout = 0
- }
- s, err := windows.WaitForSingleObject(event, uint32(timeout))
- if err != nil {
- return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
- }
- if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
- return nil, 0, errors.New("Timeout waiting for registry value")
- }
- } else if err != nil {
- return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
- } else {
- return buf, valueType, nil
- }
- }
-}
-
-func toString(buf []byte, valueType uint32, err error) (string, error) {
- if err != nil {
- return "", err
- }
-
- var value string
- switch valueType {
- case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
- if len(buf) == 0 {
- return "", nil
- }
- value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
-
- default:
- return "", registry.ErrUnexpectedType
- }
-
- if valueType != registry.EXPAND_SZ {
- // Value does not require expansion.
- return value, nil
- }
-
- valueExp, err := registry.ExpandString(value)
- if err != nil {
- // Expanding failed: return original sting value.
- return value, nil
- }
-
- // Return expanded value.
- return valueExp, nil
-}
-
-func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
- if err != nil {
- return 0, err
- }
-
- switch valueType {
- case registry.DWORD:
- if len(buf) != 4 {
- return 0, errors.New("DWORD value is not 4 bytes long")
- }
- var val uint32
- copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
- return uint64(val), nil
-
- case registry.QWORD:
- if len(buf) != 8 {
- return 0, errors.New("QWORD value is not 8 bytes long")
- }
- var val uint64
- copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
- return val, nil
-
- default:
- return 0, registry.ErrUnexpectedType
- }
-}
-
-//
-// GetStringValueWait function reads a string value from registry. It waits
-// for the registry value to become available or returns error on timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-// If the value type is REG_EXPAND_SZ the environment variables are expanded.
-// Should expanding fail, original string value and nil error are returned.
-//
-// If the value type is REG_MULTI_SZ only the first string is returned.
-//
-func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
- return toString(getValueRetry(key, name, make([]byte, 256), timeout))
-}
-
-//
-// GetStringValue function reads a string value from registry.
-//
-// Key must be opened with at least QUERY_VALUE access.
-//
-// If the value type is REG_EXPAND_SZ the environment variables are expanded.
-// Should expanding fail, original string value and nil error are returned.
-//
-// If the value type is REG_MULTI_SZ only the first string is returned.
-//
-func GetStringValue(key registry.Key, name string) (string, error) {
- return toString(getValue(key, name, make([]byte, 256)))
-}
-
-//
-// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry.
-// It waits for the registry value to become available or returns error on
-// timeout.
-//
-// Key must be opened with at least QUERY_VALUE|NOTIFY access.
-//
-func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
- return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
-}
diff --git a/tun/wintun/registry/registry_windows_test.go b/tun/wintun/registry/registry_windows_test.go
deleted file mode 100644
index c56b51b..0000000
--- a/tun/wintun/registry/registry_windows_test.go
+++ /dev/null
@@ -1,103 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package registry
-
-import (
- "testing"
- "time"
-
- "golang.org/x/sys/windows/registry"
-)
-
-const keyRoot = registry.CURRENT_USER
-const pathRoot = "Software\\WireGuardRegistryTest"
-const path = pathRoot + "\\foobar"
-const pathFake = pathRoot + "\\raboof"
-
-func Test_WaitForKey(t *testing.T) {
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
- go func() {
- time.Sleep(time.Second * 1)
- key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- key.Close()
- registry.DeleteKey(keyRoot, pathFake)
-
- key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- key.Close()
- }()
- err := WaitForKey(keyRoot, path, time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry key: %v", err)
- }
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
-
- err = WaitForKey(keyRoot, path, time.Second*1)
- if err == nil {
- t.Error("Registry key notification expected to timeout but it succeeded.")
- }
-}
-
-func Test_GetValueWait(t *testing.T) {
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
- go func() {
- time.Sleep(time.Second * 1)
- key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE)
- if err != nil {
- t.Errorf("Error creating registry key: %v", err)
- }
- time.Sleep(time.Second * 1)
- key.SetStringValue("name1", "eulav")
- key.SetExpandStringValue("name2", "value")
- time.Sleep(time.Second * 1)
- key.SetDWordValue("name3", ^uint32(123))
- key.SetDWordValue("name4", 123)
- key.Close()
- }()
-
- key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry key: %v", err)
- }
-
- valueStr, err := GetStringValueWait(key, "name2", time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry value: %v", err)
- }
- if valueStr != "value" {
- t.Errorf("Wrong value read: %v", valueStr)
- }
-
- _, err = GetStringValueWait(key, "nonexisting", time.Second*1)
- if err == nil {
- t.Error("Registry value notification expected to timeout but it succeeded.")
- }
-
- valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2)
- if err != nil {
- t.Errorf("Error waiting for registry value: %v", err)
- }
- if valueInt != 123 {
- t.Errorf("Wrong value read: %v", valueInt)
- }
-
- _, err = GetIntegerValueWait(key, "nonexisting", time.Second*1)
- if err == nil {
- t.Error("Registry value notification expected to timeout but it succeeded.")
- }
-
- key.Close()
- registry.DeleteKey(keyRoot, path)
- registry.DeleteKey(keyRoot, pathRoot)
-}
diff --git a/tun/wintun/registry/zregistry_windows.go b/tun/wintun/registry/zregistry_windows.go
deleted file mode 100644
index f7ac33b..0000000
--- a/tun/wintun/registry/zregistry_windows.go
+++ /dev/null
@@ -1,63 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package registry
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
-
- procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
-)
-
-func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) {
- var _p0 uint32
- if watchSubtree {
- _p0 = 1
- } else {
- _p0 = 0
- }
- var _p1 uint32
- if asynchronous {
- _p1 = 1
- } else {
- _p1 = 0
- }
- r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
- if r0 != 0 {
- regerrno = syscall.Errno(r0)
- }
- return
-}
diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go
deleted file mode 100644
index 8f46bc9..0000000
--- a/tun/wintun/ring_windows.go
+++ /dev/null
@@ -1,97 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- PacketAlignment = 4 // Number of bytes packets are aligned to in rings
- PacketSizeMax = 0xffff // Maximum packet size
- PacketCapacity = 0x800000 // Ring capacity, 8MiB
- PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
- ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
-)
-
-type PacketHeader struct {
- Size uint32
-}
-
-type Packet struct {
- PacketHeader
- Data [PacketSizeMax]byte
-}
-
-type Ring struct {
- Head uint32
- Tail uint32
- Alertable int32
- Data [PacketCapacity + PacketTrailingSize]byte
-}
-
-type RingDescriptor struct {
- Send, Receive struct {
- Size uint32
- Ring *Ring
- TailMoved windows.Handle
- }
-}
-
-// Wrap returns value modulo ring capacity
-func (rb *Ring) Wrap(value uint32) uint32 {
- return value & (PacketCapacity - 1)
-}
-
-// Aligns a packet size to PacketAlignment
-func PacketAlign(size uint32) uint32 {
- return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
-}
-
-func (descriptor *RingDescriptor) Init() (err error) {
- descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Send.Ring = &Ring{}
- descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- return
- }
-
- descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Receive.Ring = &Ring{}
- descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
- if err != nil {
- windows.CloseHandle(descriptor.Send.TailMoved)
- return
- }
-
- return
-}
-
-func (descriptor *RingDescriptor) Close() {
- if descriptor.Send.TailMoved != 0 {
- windows.CloseHandle(descriptor.Send.TailMoved)
- descriptor.Send.TailMoved = 0
- }
- if descriptor.Send.TailMoved != 0 {
- windows.CloseHandle(descriptor.Receive.TailMoved)
- descriptor.Receive.TailMoved = 0
- }
-}
-
-func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
- handle, err := wintun.handle()
- if err != nil {
- return 0, err
- }
- var bytesReturned uint32
- err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
- if err != nil {
- return 0, err
- }
- return handle, nil
-}
diff --git a/tun/wintun/setupapi/mksyscall.go b/tun/wintun/setupapi/mksyscall.go
deleted file mode 100644
index ac103a1..0000000
--- a/tun/wintun/setupapi/mksyscall.go
+++ /dev/null
@@ -1,8 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go
diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go
deleted file mode 100644
index 60a8eb7..0000000
--- a/tun/wintun/setupapi/setupapi_windows.go
+++ /dev/null
@@ -1,506 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "encoding/binary"
- "fmt"
- "runtime"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-)
-
-//sys setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW
-
-// SetupDiCreateDeviceInfoListEx function creates an empty device information set on a remote or a local computer and optionally associates the set with a device setup class.
-func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) {
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
- return setupDiCreateDeviceInfoListEx(classGUID, hwndParent, machineNameUTF16, 0)
-}
-
-//sys setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) = setupapi.SetupDiGetDeviceInfoListDetailW
-
-// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
-func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) {
- data := &DevInfoListDetailData{}
- data.size = sizeofDevInfoListDetailData
-
- return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
-}
-
-// DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
-func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) {
- return SetupDiGetDeviceInfoListDetail(deviceInfoSet)
-}
-
-//sys setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCreateDeviceInfoW
-
-// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set.
-func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) {
- deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName)
- if err != nil {
- return
- }
-
- var deviceDescriptionUTF16 *uint16
- if deviceDescription != "" {
- deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription)
- if err != nil {
- return
- }
- }
-
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiCreateDeviceInfo(deviceInfoSet, deviceNameUTF16, classGUID, deviceDescriptionUTF16, hwndParent, creationFlags, data)
-}
-
-// CreateDeviceInfo method creates a new device information element and adds it as a new member to the specified device information set.
-func (deviceInfoSet DevInfo) CreateDeviceInfo(deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (*DevInfoData, error) {
- return SetupDiCreateDeviceInfo(deviceInfoSet, deviceName, classGUID, deviceDescription, hwndParent, creationFlags)
-}
-
-//sys setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiEnumDeviceInfo
-
-// SetupDiEnumDeviceInfo function returns a DevInfoData structure that specifies a device information element in a device information set.
-func SetupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex int) (*DevInfoData, error) {
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiEnumDeviceInfo(deviceInfoSet, uint32(memberIndex), data)
-}
-
-// EnumDeviceInfo method returns a DevInfoData structure that specifies a device information element in a device information set.
-func (deviceInfoSet DevInfo) EnumDeviceInfo(memberIndex int) (*DevInfoData, error) {
- return SetupDiEnumDeviceInfo(deviceInfoSet, memberIndex)
-}
-
-// SetupDiDestroyDeviceInfoList function deletes a device information set and frees all associated memory.
-//sys SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiDestroyDeviceInfoList
-
-// Close method deletes a device information set and frees all associated memory.
-func (deviceInfoSet DevInfo) Close() error {
- return SetupDiDestroyDeviceInfoList(deviceInfoSet)
-}
-
-//sys SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiBuildDriverInfoList
-
-// BuildDriverInfoList method builds a list of drivers that is associated with a specific device or with the global class driver list for a device information set.
-func (deviceInfoSet DevInfo) BuildDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error {
- return SetupDiBuildDriverInfoList(deviceInfoSet, deviceInfoData, driverType)
-}
-
-//sys SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiCancelDriverInfoSearch
-
-// CancelDriverInfoSearch method cancels a driver list search that is currently in progress in a different thread.
-func (deviceInfoSet DevInfo) CancelDriverInfoSearch() error {
- return SetupDiCancelDriverInfoSearch(deviceInfoSet)
-}
-
-//sys setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiEnumDriverInfoW
-
-// SetupDiEnumDriverInfo function enumerates the members of a driver list.
-func SetupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) {
- data := &DrvInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, uint32(memberIndex), data)
-}
-
-// EnumDriverInfo method enumerates the members of a driver list.
-func (deviceInfoSet DevInfo) EnumDriverInfo(deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) {
- return SetupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, memberIndex)
-}
-
-//sys setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiGetSelectedDriverW
-
-// SetupDiGetSelectedDriver function retrieves the selected driver for a device information set or a particular device information element.
-func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DrvInfoData, error) {
- data := &DrvInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data)
-}
-
-// SelectedDriver method retrieves the selected driver for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) {
- return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData)
-}
-
-//sys SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiSetSelectedDriverW
-
-// SetSelectedDriver method sets, or resets, the selected driver for a device information element or the selected class driver for a device information set.
-func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) error {
- return SetupDiSetSelectedDriver(deviceInfoSet, deviceInfoData, driverInfoData)
-}
-
-//sys setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDriverInfoDetailW
-
-// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set.
-func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
- reqSize := uint32(2048)
- for {
- buf := make([]byte, reqSize)
- data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
- data.size = sizeofDrvInfoDetailData
- err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return nil, err
- }
- data.size = reqSize
- return data, nil
- }
-}
-
-// DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set.
-func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
- return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData)
-}
-
-//sys SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiDestroyDriverInfoList
-
-// DestroyDriverInfoList method deletes a driver list.
-func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error {
- return SetupDiDestroyDriverInfoList(deviceInfoSet, deviceInfoData, driverType)
-}
-
-//sys setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiGetClassDevsExW
-
-// SetupDiGetClassDevsEx function returns a handle to a device information set that contains requested device information elements for a local or a remote computer.
-func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) {
- var enumeratorUTF16 *uint16
- if enumerator != "" {
- enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator)
- if err != nil {
- return
- }
- }
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
- return setupDiGetClassDevsEx(classGUID, enumeratorUTF16, hwndParent, flags, deviceInfoSet, machineNameUTF16, 0)
-}
-
-// SetupDiCallClassInstaller function calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code).
-//sys SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCallClassInstaller
-
-// CallClassInstaller member calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code).
-func (deviceInfoSet DevInfo) CallClassInstaller(installFunction DI_FUNCTION, deviceInfoData *DevInfoData) error {
- return SetupDiCallClassInstaller(installFunction, deviceInfoSet, deviceInfoData)
-}
-
-//sys setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) [failretval==windows.InvalidHandle] = setupapi.SetupDiOpenDevRegKey
-
-// SetupDiOpenDevRegKey function opens a registry key for device-specific configuration information.
-func SetupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, scope DICS_FLAG, hwProfile uint32, keyType DIREG, samDesired uint32) (registry.Key, error) {
- handle, err := setupDiOpenDevRegKey(deviceInfoSet, deviceInfoData, scope, hwProfile, keyType, samDesired)
- return registry.Key(handle), err
-}
-
-// OpenDevRegKey method opens a registry key for device-specific configuration information.
-func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (registry.Key, error) {
- return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired)
-}
-
-//sys setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW
-
-// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
-func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) {
- reqSize := uint32(256)
- for {
- var dataType uint32
- buf := make([]byte, reqSize)
- err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return
- }
- return getRegistryValue(buf[:reqSize], dataType)
- }
-}
-
-func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
- switch dataType {
- case windows.REG_SZ:
- ret := windows.UTF16ToString(bufToUTF16(buf))
- runtime.KeepAlive(buf)
- return ret, nil
- case windows.REG_EXPAND_SZ:
- ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf)))
- runtime.KeepAlive(buf)
- return ret, err
- case windows.REG_BINARY:
- return buf, nil
- case windows.REG_DWORD_LITTLE_ENDIAN:
- return binary.LittleEndian.Uint32(buf), nil
- case windows.REG_DWORD_BIG_ENDIAN:
- return binary.BigEndian.Uint32(buf), nil
- case windows.REG_MULTI_SZ:
- bufW := bufToUTF16(buf)
- a := []string{}
- for i := 0; i < len(bufW); {
- j := i + wcslen(bufW[i:])
- if i < j {
- a = append(a, windows.UTF16ToString(bufW[i:j]))
- }
- i = j + 1
- }
- runtime.KeepAlive(buf)
- return a, nil
- case windows.REG_QWORD_LITTLE_ENDIAN:
- return binary.LittleEndian.Uint64(buf), nil
- default:
- return nil, fmt.Errorf("Unsupported registry value type: %v", dataType)
- }
-}
-
-// bufToUTF16 function reinterprets []byte buffer as []uint16
-func bufToUTF16(buf []byte) []uint16 {
- sl := struct {
- addr *uint16
- len int
- cap int
- }{(*uint16)(unsafe.Pointer(&buf[0])), len(buf) / 2, cap(buf) / 2}
- return *(*[]uint16)(unsafe.Pointer(&sl))
-}
-
-// utf16ToBuf function reinterprets []uint16 as []byte
-func utf16ToBuf(buf []uint16) []byte {
- sl := struct {
- addr *byte
- len int
- cap int
- }{(*byte)(unsafe.Pointer(&buf[0])), len(buf) * 2, cap(buf) * 2}
- return *(*[]byte)(unsafe.Pointer(&sl))
-}
-
-func wcslen(str []uint16) int {
- for i := 0; i < len(str); i++ {
- if str[i] == 0 {
- return i
- }
- }
- return len(str)
-}
-
-// DeviceRegistryProperty method retrieves a specified Plug and Play device property.
-func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) {
- return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property)
-}
-
-//sys setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) = setupapi.SetupDiSetDeviceRegistryPropertyW
-
-// SetupDiSetDeviceRegistryProperty function sets a Plug and Play device property for a device.
-func SetupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error {
- return setupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &propertyBuffers[0], uint32(len(propertyBuffers)))
-}
-
-// SetDeviceRegistryProperty function sets a Plug and Play device property for a device.
-func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error {
- return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
-}
-
-// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device.
-func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
- str16, err := windows.UTF16FromString(str)
- if err != nil {
- return err
- }
- err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0)))
- runtime.KeepAlive(str16)
- return err
-}
-
-//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW
-
-// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element.
-func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DevInstallParams, error) {
- params := &DevInstallParams{}
- params.size = uint32(unsafe.Sizeof(*params))
-
- return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params)
-}
-
-// DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) {
- return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData)
-}
-
-//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW
-
-// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device.
-func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) {
- reqSize := uint32(1024)
- for {
- buf := make([]uint16, reqSize)
- err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return "", err
- }
- return windows.UTF16ToString(buf), nil
- }
-}
-
-// DeviceInstanceID method retrieves the instance ID of the device.
-func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) {
- return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData)
-}
-
-// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element.
-//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW
-
-// ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error {
- return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize)
-}
-
-//sys SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiSetDeviceInstallParamsW
-
-// SetDeviceInstallParams member sets device installation parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SetDeviceInstallParams(deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) error {
- return SetupDiSetDeviceInstallParams(deviceInfoSet, deviceInfoData, deviceInstallParams)
-}
-
-// SetupDiSetClassInstallParams function sets or clears class install parameters for a device information set or a particular device information element.
-//sys SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) = setupapi.SetupDiSetClassInstallParamsW
-
-// SetClassInstallParams method sets or clears class install parameters for a device information set or a particular device information element.
-func (deviceInfoSet DevInfo) SetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) error {
- return SetupDiSetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize)
-}
-
-//sys setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassNameFromGuidExW
-
-// SetupDiClassNameFromGuidEx function retrieves the class name associated with a class GUID. The class can be installed on a local or remote computer.
-func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (className string, err error) {
- var classNameUTF16 [MAX_CLASS_NAME_LEN]uint16
-
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return
- }
- }
-
- err = setupDiClassNameFromGuidEx(classGUID, &classNameUTF16[0], MAX_CLASS_NAME_LEN, nil, machineNameUTF16, 0)
- if err != nil {
- return
- }
-
- className = windows.UTF16ToString(classNameUTF16[:])
- return
-}
-
-//sys setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassGuidsFromNameExW
-
-// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer.
-func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) {
- classNameUTF16, err := windows.UTF16PtrFromString(className)
- if err != nil {
- return nil, err
- }
-
- var machineNameUTF16 *uint16
- if machineName != "" {
- machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
- if err != nil {
- return nil, err
- }
- }
-
- reqSize := uint32(4)
- for {
- buf := make([]windows.GUID, reqSize)
- err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0)
- if err == windows.ERROR_INSUFFICIENT_BUFFER {
- continue
- }
- if err != nil {
- return nil, err
- }
- return buf[:reqSize], nil
- }
-}
-
-//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice
-
-// SetupDiGetSelectedDevice function retrieves the selected device information element in a device information set.
-func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) {
- data := &DevInfoData{}
- data.size = uint32(unsafe.Sizeof(*data))
-
- return data, setupDiGetSelectedDevice(deviceInfoSet, data)
-}
-
-// SelectedDevice method retrieves the selected device information element in a device information set.
-func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) {
- return SetupDiGetSelectedDevice(deviceInfoSet)
-}
-
-// SetupDiSetSelectedDevice function sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard.
-//sys SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiSetSelectedDevice
-
-// SetSelectedDevice method sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard.
-func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error {
- return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData)
-}
-
-//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW
-//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW
-
-func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) {
- deviceID16, err := windows.UTF16PtrFromString(deviceID)
- if err != nil {
- return nil, err
- }
- var buf []uint16
- var buflen uint32
- for {
- if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS {
- return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
- }
- buf = make([]uint16, buflen)
- if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS {
- break
- } else if ret != CR_BUFFER_SMALL {
- return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
- }
- }
- var interfaces []string
- for i := 0; i < len(buf); {
- j := i + wcslen(buf[i:])
- if i < j {
- interfaces = append(interfaces, windows.UTF16ToString(buf[i:j]))
- }
- i = j + 1
- }
- if interfaces == nil {
- return nil, fmt.Errorf("no interfaces found")
- }
- return interfaces, nil
-}
diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go
deleted file mode 100644
index a9e6b89..0000000
--- a/tun/wintun/setupapi/setupapi_windows_test.go
+++ /dev/null
@@ -1,488 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "runtime"
- "strings"
- "testing"
-
- "golang.org/x/sys/windows"
-)
-
-var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
-var computerName string
-
-func init() {
- computerName, _ = windows.ComputerName()
-}
-
-func TestSetupDiCreateDeviceInfoListEx(t *testing.T) {
- devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "")
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error())
- } else {
- devInfoList.Close()
- }
-}
-
-func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- data, err := devInfoList.DeviceInfoListDetail()
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
- } else {
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID")
- }
-
- if data.RemoteMachineHandle != windows.Handle(0) {
- t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle")
- }
-
- if data.RemoteMachineName() != "" {
- t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name")
- }
- }
-
- devInfoList, err = SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- data, err = devInfoList.DeviceInfoListDetail()
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
- } else {
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID")
- }
-
- if data.RemoteMachineHandle == windows.Handle(0) {
- t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle")
- }
-
- if data.RemoteMachineName() != computerName {
- t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name")
- }
- }
-
- data = &DevInfoListDetailData{}
- data.SetRemoteMachineName("foobar")
- if data.RemoteMachineName() != "foobar" {
- t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ")
- }
-}
-
-func TestSetupDiCreateDeviceInfo(t *testing.T) {
- devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- }
-
- devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID)
- if err != nil {
- // Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed.
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED {
- t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error())
- }
- } else if devInfoData.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiCreateDeviceInfo returned different class GUID")
- }
-}
-
-func TestSetupDiEnumDeviceInfo(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- if data.ClassGUID != deviceClassNetGUID {
- t.Error("SetupDiEnumDeviceInfo returned different class GUID")
- }
-
- _, err = devInfoList.DeviceInstanceID(data)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error())
- }
- }
-}
-
-func TestDevInfo_BuildDriverInfoList(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- deviceData, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- const driverType SPDIT = SPDIT_COMPATDRIVER
- err = devInfoList.BuildDriverInfoList(deviceData, driverType)
- if err != nil {
- t.Errorf("Error calling SetupDiBuildDriverInfoList: %s", err.Error())
- }
- defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
-
- var selectedDriverData *DrvInfoData
- for j := 0; true; j++ {
- driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- if driverData.DriverType == 0 {
- continue
- }
-
- if !driverData.IsNewer(windows.Filetime{}, 0) {
- t.Error("Driver should have non-zero date and version")
- }
- if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime}, 0) {
- t.Error("Driver should have non-zero date and version")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime + 1}, 0) {
- t.Error("Driver should report newer version on high-date-time")
- }
- if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, 0) {
- t.Error("Driver should have non-zero version")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime + 1}, 0) {
- t.Error("Driver should report newer version on low-date-time")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion) {
- t.Error("Driver should not be newer than itself")
- }
- if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion+1) {
- t.Error("Driver should report newer version on version")
- }
-
- err = devInfoList.SetSelectedDriver(deviceData, driverData)
- if err != nil {
- t.Errorf("Error calling SetupDiSetSelectedDriver: %s", err.Error())
- } else {
- selectedDriverData = driverData
- }
-
- driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error())
- }
-
- if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") {
- t.Error("Invalid HWID compatibitlity reported")
- }
- if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) {
- t.Error("HWID compatibitlity missed")
- }
- a := driverDetailData.CompatIDs()
- for k := range a {
- if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) {
- t.Error("HWID compatibitlity missed")
- }
- }
- }
-
- selectedDriverData2, err := devInfoList.SelectedDriver(deviceData)
- if err != nil {
- t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error())
- } else if *selectedDriverData != *selectedDriverData2 {
- t.Error("SetupDiGetSelectedDriver should return driver selected with SetupDiSetSelectedDriver")
- }
- }
-
- data := &DrvInfoData{}
- data.SetDescription("foobar")
- if data.Description() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)Description() differ")
- }
- data.SetMfgName("foobar")
- if data.MfgName() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)MfgName() differ")
- }
- data.SetProviderName("foobar")
- if data.ProviderName() != "foobar" {
- t.Error("DrvInfoData.(Get|Set)ProviderName() differ")
- }
-}
-
-func TestSetupDiGetClassDevsEx(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- } else {
- devInfoList.Close()
- }
-
- devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
- t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
- }
- } else {
- devInfoList.Close()
- t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
- }
-}
-
-func TestSetupDiOpenDevRegKey(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- key, err := devInfoList.OpenDevRegKey(data, DICS_FLAG_GLOBAL, 0, DIREG_DRV, windows.KEY_READ)
- if err != nil {
- t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error())
- }
- defer key.Close()
- }
-}
-
-func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error())
- } else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"")
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
- } else if valStr, ok := val.(string); !ok {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string")
- } else {
- classGUID, err := windows.GUIDFromString(valStr)
- if err != nil {
- t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
- } else if classGUID != deviceClassNetGUID {
- t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID)
- }
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
- if err != nil {
- // Some devices have no SPDRP_COMPATIBLEIDS.
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
- }
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error())
- }
-
- val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error())
- }
- }
-}
-
-func TestSetupDiGetDeviceInstallParams(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- _, err = devInfoList.DeviceInstallParams(data)
- if err != nil {
- t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error())
- }
- }
-
- params := &DevInstallParams{}
- params.SetDriverPath("foobar")
- if params.DriverPath() != "foobar" {
- t.Error("DevInstallParams.(Get|Set)DriverPath() differ")
- }
-}
-
-func TestSetupDiClassNameFromGuidEx(t *testing.T) {
- deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- } else if strings.ToLower(deviceClassNetName) != "net" {
- t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID)
- }
-
- deviceClassNetName, err = SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error())
- } else if strings.ToLower(deviceClassNetName) != "net" {
- t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID)
- }
-
- _, err = SetupDiClassNameFromGuidEx(nil, "")
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
- t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- } else {
- t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
- }
-}
-
-func TestSetupDiClassGuidsFromNameEx(t *testing.T) {
- ClassGUIDs, err := SetupDiClassGuidsFromNameEx("Net", "")
- if err != nil {
- t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error())
- } else {
- found := false
- for i := range ClassGUIDs {
- if ClassGUIDs[i] == deviceClassNetGUID {
- found = true
- break
- }
- }
- if !found {
- t.Errorf("SetupDiClassGuidsFromNameEx(\"Net\") should return %x", deviceClassNetGUID)
- }
- }
-
- ClassGUIDs, err = SetupDiClassGuidsFromNameEx("foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe", computerName)
- if err != nil {
- t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error())
- } else if len(ClassGUIDs) != 0 {
- t.Errorf("SetupDiClassGuidsFromNameEx(\"foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe\") should return an empty GUID set")
- }
-}
-
-func TestSetupDiGetSelectedDevice(t *testing.T) {
- devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "")
- if err != nil {
- t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
- }
- defer devInfoList.Close()
-
- for i := 0; true; i++ {
- data, err := devInfoList.EnumDeviceInfo(i)
- if err != nil {
- if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- err = devInfoList.SetSelectedDevice(data)
- if err != nil {
- t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error())
- }
-
- data2, err := devInfoList.SelectedDevice()
- if err != nil {
- t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error())
- } else if *data != *data2 {
- t.Error("SetupDiGetSelectedDevice returned different data than was set by SetupDiSetSelectedDevice")
- }
- }
-
- err = devInfoList.SetSelectedDevice(nil)
- if err != nil {
- if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- } else {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
- }
-}
-
-func TestUTF16ToBuf(t *testing.T) {
- buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef}
- buf2 := utf16ToBuf(buf)
- if len(buf)*2 != len(buf2) ||
- cap(buf)*2 != cap(buf2) ||
- buf2[0] != 0x23 || buf2[1] != 0x01 ||
- buf2[2] != 0x67 || buf2[3] != 0x45 ||
- buf2[4] != 0xab || buf2[5] != 0x89 ||
- buf2[6] != 0xef || buf2[7] != 0xcd {
- t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
- }
- runtime.KeepAlive(buf)
-}
diff --git a/tun/wintun/setupapi/types_windows.go b/tun/wintun/setupapi/types_windows.go
deleted file mode 100644
index 136b4be..0000000
--- a/tun/wintun/setupapi/types_windows.go
+++ /dev/null
@@ -1,568 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "strings"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-const (
- MAX_DEVICE_ID_LEN = 200
- MAX_DEVNODE_ID_LEN = MAX_DEVICE_ID_LEN
- MAX_GUID_STRING_LEN = 39 // 38 chars + terminator null
- MAX_CLASS_NAME_LEN = 32
- MAX_PROFILE_LEN = 80
- MAX_CONFIG_VALUE = 9999
- MAX_INSTANCE_VALUE = 9999
- CONFIGMG_VERSION = 0x0400
-)
-
-//
-// Define maximum string length constants
-//
-const (
- ANYSIZE_ARRAY = 1
- LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF.
- MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions).
- MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters.
- MAX_TITLE_LEN = 60
- MAX_INSTRUCTION_LEN = 256
- MAX_LABEL_LEN = 30
- MAX_SERVICE_NAME_LEN = 256
- MAX_SUBTITLE_LEN = 256
-)
-
-const (
- // SP_MAX_MACHINENAME_LENGTH defines maximum length of a machine name in the format expected by ConfigMgr32 CM_Connect_Machine (i.e., "\\\\MachineName\0").
- SP_MAX_MACHINENAME_LENGTH = windows.MAX_PATH + 3
-)
-
-// HSPFILEQ is type for setup file queue
-type HSPFILEQ uintptr
-
-// DevInfo holds reference to device information set
-type DevInfo windows.Handle
-
-// DevInfoData is a device information structure (references a device instance that is a member of a device information set)
-type DevInfoData struct {
- size uint32
- ClassGUID windows.GUID
- DevInst uint32 // DEVINST handle
- _ uintptr
-}
-
-// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
-type DevInfoListDetailData struct {
- size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
- ClassGUID windows.GUID
- RemoteMachineHandle windows.Handle
- remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
-}
-
-func (data *DevInfoListDetailData) RemoteMachineName() string {
- return windows.UTF16ToString(data.remoteMachineName[:])
-}
-
-func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error {
- str, err := windows.UTF16FromString(remoteMachineName)
- if err != nil {
- return err
- }
- copy(data.remoteMachineName[:], str)
- return nil
-}
-
-// DI_FUNCTION is function type for device installer
-type DI_FUNCTION uint32
-
-const (
- DIF_SELECTDEVICE DI_FUNCTION = 0x00000001
- DIF_INSTALLDEVICE DI_FUNCTION = 0x00000002
- DIF_ASSIGNRESOURCES DI_FUNCTION = 0x00000003
- DIF_PROPERTIES DI_FUNCTION = 0x00000004
- DIF_REMOVE DI_FUNCTION = 0x00000005
- DIF_FIRSTTIMESETUP DI_FUNCTION = 0x00000006
- DIF_FOUNDDEVICE DI_FUNCTION = 0x00000007
- DIF_SELECTCLASSDRIVERS DI_FUNCTION = 0x00000008
- DIF_VALIDATECLASSDRIVERS DI_FUNCTION = 0x00000009
- DIF_INSTALLCLASSDRIVERS DI_FUNCTION = 0x0000000A
- DIF_CALCDISKSPACE DI_FUNCTION = 0x0000000B
- DIF_DESTROYPRIVATEDATA DI_FUNCTION = 0x0000000C
- DIF_VALIDATEDRIVER DI_FUNCTION = 0x0000000D
- DIF_DETECT DI_FUNCTION = 0x0000000F
- DIF_INSTALLWIZARD DI_FUNCTION = 0x00000010
- DIF_DESTROYWIZARDDATA DI_FUNCTION = 0x00000011
- DIF_PROPERTYCHANGE DI_FUNCTION = 0x00000012
- DIF_ENABLECLASS DI_FUNCTION = 0x00000013
- DIF_DETECTVERIFY DI_FUNCTION = 0x00000014
- DIF_INSTALLDEVICEFILES DI_FUNCTION = 0x00000015
- DIF_UNREMOVE DI_FUNCTION = 0x00000016
- DIF_SELECTBESTCOMPATDRV DI_FUNCTION = 0x00000017
- DIF_ALLOW_INSTALL DI_FUNCTION = 0x00000018
- DIF_REGISTERDEVICE DI_FUNCTION = 0x00000019
- DIF_NEWDEVICEWIZARD_PRESELECT DI_FUNCTION = 0x0000001A
- DIF_NEWDEVICEWIZARD_SELECT DI_FUNCTION = 0x0000001B
- DIF_NEWDEVICEWIZARD_PREANALYZE DI_FUNCTION = 0x0000001C
- DIF_NEWDEVICEWIZARD_POSTANALYZE DI_FUNCTION = 0x0000001D
- DIF_NEWDEVICEWIZARD_FINISHINSTALL DI_FUNCTION = 0x0000001E
- DIF_INSTALLINTERFACES DI_FUNCTION = 0x00000020
- DIF_DETECTCANCEL DI_FUNCTION = 0x00000021
- DIF_REGISTER_COINSTALLERS DI_FUNCTION = 0x00000022
- DIF_ADDPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000023
- DIF_ADDPROPERTYPAGE_BASIC DI_FUNCTION = 0x00000024
- DIF_TROUBLESHOOTER DI_FUNCTION = 0x00000026
- DIF_POWERMESSAGEWAKE DI_FUNCTION = 0x00000027
- DIF_ADDREMOTEPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000028
- DIF_UPDATEDRIVER_UI DI_FUNCTION = 0x00000029
- DIF_FINISHINSTALL_ACTION DI_FUNCTION = 0x0000002A
-)
-
-// DevInstallParams is device installation parameters structure (associated with a particular device information element, or globally with a device information set)
-type DevInstallParams struct {
- size uint32
- Flags DI_FLAGS
- FlagsEx DI_FLAGSEX
- hwndParent uintptr
- InstallMsgHandler uintptr
- InstallMsgHandlerContext uintptr
- FileQueue HSPFILEQ
- _ uintptr
- _ uint32
- driverPath [windows.MAX_PATH]uint16
-}
-
-func (params *DevInstallParams) DriverPath() string {
- return windows.UTF16ToString(params.driverPath[:])
-}
-
-func (params *DevInstallParams) SetDriverPath(driverPath string) error {
- str, err := windows.UTF16FromString(driverPath)
- if err != nil {
- return err
- }
- copy(params.driverPath[:], str)
- return nil
-}
-
-// DI_FLAGS is SP_DEVINSTALL_PARAMS.Flags values
-type DI_FLAGS uint32
-
-const (
- // Flags for choosing a device
- DI_SHOWOEM DI_FLAGS = 0x00000001 // support Other... button
- DI_SHOWCOMPAT DI_FLAGS = 0x00000002 // show compatibility list
- DI_SHOWCLASS DI_FLAGS = 0x00000004 // show class list
- DI_SHOWALL DI_FLAGS = 0x00000007 // both class & compat list shown
- DI_NOVCP DI_FLAGS = 0x00000008 // don't create a new copy queue--use caller-supplied FileQueue
- DI_DIDCOMPAT DI_FLAGS = 0x00000010 // Searched for compatible devices
- DI_DIDCLASS DI_FLAGS = 0x00000020 // Searched for class devices
- DI_AUTOASSIGNRES DI_FLAGS = 0x00000040 // No UI for resources if possible
-
- // Flags returned by DiInstallDevice to indicate need to reboot/restart
- DI_NEEDRESTART DI_FLAGS = 0x00000080 // Reboot required to take effect
- DI_NEEDREBOOT DI_FLAGS = 0x00000100 // ""
-
- // Flags for device installation
- DI_NOBROWSE DI_FLAGS = 0x00000200 // no Browse... in InsertDisk
-
- // Flags set by DiBuildDriverInfoList
- DI_MULTMFGS DI_FLAGS = 0x00000400 // Set if multiple manufacturers in class driver list
-
- // Flag indicates that device is disabled
- DI_DISABLED DI_FLAGS = 0x00000800 // Set if device disabled
-
- // Flags for Device/Class Properties
- DI_GENERALPAGE_ADDED DI_FLAGS = 0x00001000
- DI_RESOURCEPAGE_ADDED DI_FLAGS = 0x00002000
-
- // Flag to indicate the setting properties for this Device (or class) caused a change so the Dev Mgr UI probably needs to be updated.
- DI_PROPERTIES_CHANGE DI_FLAGS = 0x00004000
-
- // Flag to indicate that the sorting from the INF file should be used.
- DI_INF_IS_SORTED DI_FLAGS = 0x00008000
-
- // Flag to indicate that only the the INF specified by SP_DEVINSTALL_PARAMS.DriverPath should be searched.
- DI_ENUMSINGLEINF DI_FLAGS = 0x00010000
-
- // Flag that prevents ConfigMgr from removing/re-enumerating devices during device
- // registration, installation, and deletion.
- DI_DONOTCALLCONFIGMG DI_FLAGS = 0x00020000
-
- // The following flag can be used to install a device disabled
- DI_INSTALLDISABLED DI_FLAGS = 0x00040000
-
- // Flag that causes SetupDiBuildDriverInfoList to build a device's compatible driver
- // list from its existing class driver list, instead of the normal INF search.
- DI_COMPAT_FROM_CLASS DI_FLAGS = 0x00080000
-
- // This flag is set if the Class Install params should be used.
- DI_CLASSINSTALLPARAMS DI_FLAGS = 0x00100000
-
- // This flag is set if the caller of DiCallClassInstaller does NOT want the internal default action performed if the Class installer returns ERROR_DI_DO_DEFAULT.
- DI_NODI_DEFAULTACTION DI_FLAGS = 0x00200000
-
- // Flags for device installation
- DI_QUIETINSTALL DI_FLAGS = 0x00800000 // don't confuse the user with questions or excess info
- DI_NOFILECOPY DI_FLAGS = 0x01000000 // No file Copy necessary
- DI_FORCECOPY DI_FLAGS = 0x02000000 // Force files to be copied from install path
- DI_DRIVERPAGE_ADDED DI_FLAGS = 0x04000000 // Prop provider added Driver page.
- DI_USECI_SELECTSTRINGS DI_FLAGS = 0x08000000 // Use Class Installer Provided strings in the Select Device Dlg
- DI_OVERRIDE_INFFLAGS DI_FLAGS = 0x10000000 // Override INF flags
- DI_PROPS_NOCHANGEUSAGE DI_FLAGS = 0x20000000 // No Enable/Disable in General Props
-
- DI_NOSELECTICONS DI_FLAGS = 0x40000000 // No small icons in select device dialogs
-
- DI_NOWRITE_IDS DI_FLAGS = 0x80000000 // Don't write HW & Compat IDs on install
-)
-
-// DI_FLAGSEX is SP_DEVINSTALL_PARAMS.FlagsEx values
-type DI_FLAGSEX uint32
-
-const (
- DI_FLAGSEX_CI_FAILED DI_FLAGSEX = 0x00000004 // Failed to Load/Call class installer
- DI_FLAGSEX_FINISHINSTALL_ACTION DI_FLAGSEX = 0x00000008 // Class/co-installer wants to get a DIF_FINISH_INSTALL action in client context.
- DI_FLAGSEX_DIDINFOLIST DI_FLAGSEX = 0x00000010 // Did the Class Info List
- DI_FLAGSEX_DIDCOMPATINFO DI_FLAGSEX = 0x00000020 // Did the Compat Info List
- DI_FLAGSEX_FILTERCLASSES DI_FLAGSEX = 0x00000040
- DI_FLAGSEX_SETFAILEDINSTALL DI_FLAGSEX = 0x00000080
- DI_FLAGSEX_DEVICECHANGE DI_FLAGSEX = 0x00000100
- DI_FLAGSEX_ALWAYSWRITEIDS DI_FLAGSEX = 0x00000200
- DI_FLAGSEX_PROPCHANGE_PENDING DI_FLAGSEX = 0x00000400 // One or more device property sheets have had changes made to them, and need to have a DIF_PROPERTYCHANGE occur.
- DI_FLAGSEX_ALLOWEXCLUDEDDRVS DI_FLAGSEX = 0x00000800
- DI_FLAGSEX_NOUIONQUERYREMOVE DI_FLAGSEX = 0x00001000
- DI_FLAGSEX_USECLASSFORCOMPAT DI_FLAGSEX = 0x00002000 // Use the device's class when building compat drv list. (Ignored if DI_COMPAT_FROM_CLASS flag is specified.)
- DI_FLAGSEX_NO_DRVREG_MODIFY DI_FLAGSEX = 0x00008000 // Don't run AddReg and DelReg for device's software (driver) key.
- DI_FLAGSEX_IN_SYSTEM_SETUP DI_FLAGSEX = 0x00010000 // Installation is occurring during initial system setup.
- DI_FLAGSEX_INET_DRIVER DI_FLAGSEX = 0x00020000 // Driver came from Windows Update
- DI_FLAGSEX_APPENDDRIVERLIST DI_FLAGSEX = 0x00040000 // Cause SetupDiBuildDriverInfoList to append a new driver list to an existing list.
- DI_FLAGSEX_PREINSTALLBACKUP DI_FLAGSEX = 0x00080000 // not used
- DI_FLAGSEX_BACKUPONREPLACE DI_FLAGSEX = 0x00100000 // not used
- DI_FLAGSEX_DRIVERLIST_FROM_URL DI_FLAGSEX = 0x00200000 // build driver list from INF(s) retrieved from URL specified in SP_DEVINSTALL_PARAMS.DriverPath (empty string means Windows Update website)
- DI_FLAGSEX_EXCLUDE_OLD_INET_DRIVERS DI_FLAGSEX = 0x00800000 // Don't include old Internet drivers when building a driver list. Ignored on Windows Vista and later.
- DI_FLAGSEX_POWERPAGE_ADDED DI_FLAGSEX = 0x01000000 // class installer added their own power page
- DI_FLAGSEX_FILTERSIMILARDRIVERS DI_FLAGSEX = 0x02000000 // only include similar drivers in class list
- DI_FLAGSEX_INSTALLEDDRIVER DI_FLAGSEX = 0x04000000 // only add the installed driver to the class or compat driver list. Used in calls to SetupDiBuildDriverInfoList
- DI_FLAGSEX_NO_CLASSLIST_NODE_MERGE DI_FLAGSEX = 0x08000000 // Don't remove identical driver nodes from the class list
- DI_FLAGSEX_ALTPLATFORM_DRVSEARCH DI_FLAGSEX = 0x10000000 // Build driver list based on alternate platform information specified in associated file queue
- DI_FLAGSEX_RESTART_DEVICE_ONLY DI_FLAGSEX = 0x20000000 // only restart the device drivers are being installed on as opposed to restarting all devices using those drivers.
- DI_FLAGSEX_RECURSIVESEARCH DI_FLAGSEX = 0x40000000 // Tell SetupDiBuildDriverInfoList to do a recursive search
- DI_FLAGSEX_SEARCH_PUBLISHED_INFS DI_FLAGSEX = 0x80000000 // Tell SetupDiBuildDriverInfoList to do a "published INF" search
-)
-
-// ClassInstallHeader is the first member of any class install parameters structure. It contains the device installation request code that defines the format of the rest of the install parameters structure.
-type ClassInstallHeader struct {
- size uint32
- InstallFunction DI_FUNCTION
-}
-
-func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader {
- hdr := &ClassInstallHeader{InstallFunction: installFunction}
- hdr.size = uint32(unsafe.Sizeof(*hdr))
- return hdr
-}
-
-// DICS_STATE specifies values indicating a change in a device's state
-type DICS_STATE uint32
-
-const (
- DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled.
- DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled.
- DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed.
- DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile).
- DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device.
-)
-
-// DICS_FLAG specifies the scope of a device property change
-type DICS_FLAG uint32
-
-const (
- DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles
- DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only
- DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete)
-)
-
-// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function.
-type PropChangeParams struct {
- ClassInstallHeader ClassInstallHeader
- StateChange DICS_STATE
- Scope DICS_FLAG
- HwProfile uint32
-}
-
-// DI_REMOVEDEVICE specifies the scope of the device removal
-type DI_REMOVEDEVICE uint32
-
-const (
- DI_REMOVEDEVICE_GLOBAL DI_REMOVEDEVICE = 0x00000001 // Make this change in all hardware profiles. Remove information about the device from the registry.
- DI_REMOVEDEVICE_CONFIGSPECIFIC DI_REMOVEDEVICE = 0x00000002 // Make this change to only the hardware profile specified by HwProfile. this flag only applies to root-enumerated devices. When Windows removes the device from the last hardware profile in which it was configured, Windows performs a global removal.
-)
-
-// RemoveDeviceParams is a structure corresponding to a DIF_REMOVE install function.
-type RemoveDeviceParams struct {
- ClassInstallHeader ClassInstallHeader
- Scope DI_REMOVEDEVICE
- HwProfile uint32
-}
-
-// DrvInfoData is driver information structure (member of a driver info list that may be associated with a particular device instance, or (globally) with a device information set)
-type DrvInfoData struct {
- size uint32
- DriverType uint32
- _ uintptr
- description [LINE_LEN]uint16
- mfgName [LINE_LEN]uint16
- providerName [LINE_LEN]uint16
- DriverDate windows.Filetime
- DriverVersion uint64
-}
-
-func (data *DrvInfoData) Description() string {
- return windows.UTF16ToString(data.description[:])
-}
-
-func (data *DrvInfoData) SetDescription(description string) error {
- str, err := windows.UTF16FromString(description)
- if err != nil {
- return err
- }
- copy(data.description[:], str)
- return nil
-}
-
-func (data *DrvInfoData) MfgName() string {
- return windows.UTF16ToString(data.mfgName[:])
-}
-
-func (data *DrvInfoData) SetMfgName(mfgName string) error {
- str, err := windows.UTF16FromString(mfgName)
- if err != nil {
- return err
- }
- copy(data.mfgName[:], str)
- return nil
-}
-
-func (data *DrvInfoData) ProviderName() string {
- return windows.UTF16ToString(data.providerName[:])
-}
-
-func (data *DrvInfoData) SetProviderName(providerName string) error {
- str, err := windows.UTF16FromString(providerName)
- if err != nil {
- return err
- }
- copy(data.providerName[:], str)
- return nil
-}
-
-// IsNewer method returns true if DrvInfoData date and version is newer than supplied parameters.
-func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint64) bool {
- if data.DriverDate.HighDateTime > driverDate.HighDateTime {
- return true
- }
- if data.DriverDate.HighDateTime < driverDate.HighDateTime {
- return false
- }
-
- if data.DriverDate.LowDateTime > driverDate.LowDateTime {
- return true
- }
- if data.DriverDate.LowDateTime < driverDate.LowDateTime {
- return false
- }
-
- if data.DriverVersion > driverVersion {
- return true
- }
- if data.DriverVersion < driverVersion {
- return false
- }
-
- return false
-}
-
-// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure)
-type DrvInfoDetailData struct {
- size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const.
- InfDate windows.Filetime
- compatIDsOffset uint32
- compatIDsLength uint32
- _ uintptr
- sectionName [LINE_LEN]uint16
- infFileName [windows.MAX_PATH]uint16
- drvDescription [LINE_LEN]uint16
- hardwareID [ANYSIZE_ARRAY]uint16
-}
-
-func (data *DrvInfoDetailData) SectionName() string {
- return windows.UTF16ToString(data.sectionName[:])
-}
-
-func (data *DrvInfoDetailData) InfFileName() string {
- return windows.UTF16ToString(data.infFileName[:])
-}
-
-func (data *DrvInfoDetailData) DrvDescription() string {
- return windows.UTF16ToString(data.drvDescription[:])
-}
-
-func (data *DrvInfoDetailData) HardwareID() string {
- if data.compatIDsOffset > 1 {
- bufW := data.getBuf()
- return windows.UTF16ToString(bufW[:wcslen(bufW)])
- }
-
- return ""
-}
-
-func (data *DrvInfoDetailData) CompatIDs() []string {
- a := make([]string, 0)
-
- if data.compatIDsLength > 0 {
- bufW := data.getBuf()
- bufW = bufW[data.compatIDsOffset : data.compatIDsOffset+data.compatIDsLength]
- for i := 0; i < len(bufW); {
- j := i + wcslen(bufW[i:])
- if i < j {
- a = append(a, windows.UTF16ToString(bufW[i:j]))
- }
- i = j + 1
- }
- }
-
- return a
-}
-
-func (data *DrvInfoDetailData) getBuf() []uint16 {
- len := (data.size - uint32(unsafe.Offsetof(data.hardwareID))) / 2
- sl := struct {
- addr *uint16
- len int
- cap int
- }{&data.hardwareID[0], int(len), int(len)}
- return *(*[]uint16)(unsafe.Pointer(&sl))
-}
-
-// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list.
-func (data *DrvInfoDetailData) IsCompatible(hwid string) bool {
- hwidLC := strings.ToLower(hwid)
- if strings.ToLower(data.HardwareID()) == hwidLC {
- return true
- }
- a := data.CompatIDs()
- for i := range a {
- if strings.ToLower(a[i]) == hwidLC {
- return true
- }
- }
-
- return false
-}
-
-// DICD flags control SetupDiCreateDeviceInfo
-type DICD uint32
-
-const (
- DICD_GENERATE_ID DICD = 0x00000001
- DICD_INHERIT_CLASSDRVS DICD = 0x00000002
-)
-
-//
-// SPDIT flags to distinguish between class drivers and
-// device drivers.
-// (Passed in 'DriverType' parameter of driver information list APIs)
-//
-type SPDIT uint32
-
-const (
- SPDIT_NODRIVER SPDIT = 0x00000000
- SPDIT_CLASSDRIVER SPDIT = 0x00000001
- SPDIT_COMPATDRIVER SPDIT = 0x00000002
-)
-
-// DIGCF flags control what is included in the device information set built by SetupDiGetClassDevs
-type DIGCF uint32
-
-const (
- DIGCF_DEFAULT DIGCF = 0x00000001 // only valid with DIGCF_DEVICEINTERFACE
- DIGCF_PRESENT DIGCF = 0x00000002
- DIGCF_ALLCLASSES DIGCF = 0x00000004
- DIGCF_PROFILE DIGCF = 0x00000008
- DIGCF_DEVICEINTERFACE DIGCF = 0x00000010
-)
-
-// DIREG specifies values for SetupDiCreateDevRegKey, SetupDiOpenDevRegKey, and SetupDiDeleteDevRegKey.
-type DIREG uint32
-
-const (
- DIREG_DEV DIREG = 0x00000001 // Open/Create/Delete device key
- DIREG_DRV DIREG = 0x00000002 // Open/Create/Delete driver key
- DIREG_BOTH DIREG = 0x00000004 // Delete both driver and Device key
-)
-
-//
-// SPDRP specifies device registry property codes
-// (Codes marked as read-only (R) may only be used for
-// SetupDiGetDeviceRegistryProperty)
-//
-// These values should cover the same set of registry properties
-// as defined by the CM_DRP codes in cfgmgr32.h.
-//
-// Note that SPDRP codes are zero based while CM_DRP codes are one based!
-//
-type SPDRP uint32
-
-const (
- SPDRP_DEVICEDESC SPDRP = 0x00000000 // DeviceDesc (R/W)
- SPDRP_HARDWAREID SPDRP = 0x00000001 // HardwareID (R/W)
- SPDRP_COMPATIBLEIDS SPDRP = 0x00000002 // CompatibleIDs (R/W)
- SPDRP_SERVICE SPDRP = 0x00000004 // Service (R/W)
- SPDRP_CLASS SPDRP = 0x00000007 // Class (R--tied to ClassGUID)
- SPDRP_CLASSGUID SPDRP = 0x00000008 // ClassGUID (R/W)
- SPDRP_DRIVER SPDRP = 0x00000009 // Driver (R/W)
- SPDRP_CONFIGFLAGS SPDRP = 0x0000000A // ConfigFlags (R/W)
- SPDRP_MFG SPDRP = 0x0000000B // Mfg (R/W)
- SPDRP_FRIENDLYNAME SPDRP = 0x0000000C // FriendlyName (R/W)
- SPDRP_LOCATION_INFORMATION SPDRP = 0x0000000D // LocationInformation (R/W)
- SPDRP_PHYSICAL_DEVICE_OBJECT_NAME SPDRP = 0x0000000E // PhysicalDeviceObjectName (R)
- SPDRP_CAPABILITIES SPDRP = 0x0000000F // Capabilities (R)
- SPDRP_UI_NUMBER SPDRP = 0x00000010 // UiNumber (R)
- SPDRP_UPPERFILTERS SPDRP = 0x00000011 // UpperFilters (R/W)
- SPDRP_LOWERFILTERS SPDRP = 0x00000012 // LowerFilters (R/W)
- SPDRP_BUSTYPEGUID SPDRP = 0x00000013 // BusTypeGUID (R)
- SPDRP_LEGACYBUSTYPE SPDRP = 0x00000014 // LegacyBusType (R)
- SPDRP_BUSNUMBER SPDRP = 0x00000015 // BusNumber (R)
- SPDRP_ENUMERATOR_NAME SPDRP = 0x00000016 // Enumerator Name (R)
- SPDRP_SECURITY SPDRP = 0x00000017 // Security (R/W, binary form)
- SPDRP_SECURITY_SDS SPDRP = 0x00000018 // Security (W, SDS form)
- SPDRP_DEVTYPE SPDRP = 0x00000019 // Device Type (R/W)
- SPDRP_EXCLUSIVE SPDRP = 0x0000001A // Device is exclusive-access (R/W)
- SPDRP_CHARACTERISTICS SPDRP = 0x0000001B // Device Characteristics (R/W)
- SPDRP_ADDRESS SPDRP = 0x0000001C // Device Address (R)
- SPDRP_UI_NUMBER_DESC_FORMAT SPDRP = 0x0000001D // UiNumberDescFormat (R/W)
- SPDRP_DEVICE_POWER_DATA SPDRP = 0x0000001E // Device Power Data (R)
- SPDRP_REMOVAL_POLICY SPDRP = 0x0000001F // Removal Policy (R)
- SPDRP_REMOVAL_POLICY_HW_DEFAULT SPDRP = 0x00000020 // Hardware Removal Policy (R)
- SPDRP_REMOVAL_POLICY_OVERRIDE SPDRP = 0x00000021 // Removal Policy Override (RW)
- SPDRP_INSTALL_STATE SPDRP = 0x00000022 // Device Install State (R)
- SPDRP_LOCATION_PATHS SPDRP = 0x00000023 // Device Location Paths (R)
- SPDRP_BASE_CONTAINERID SPDRP = 0x00000024 // Base ContainerID (R)
-
- SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals
-)
-
-const (
- CR_SUCCESS = 0x0
- CR_BUFFER_SMALL = 0x1a
-)
-
-const (
- CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces
- CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not
-)
diff --git a/tun/wintun/setupapi/types_windows_386.go b/tun/wintun/setupapi/types_windows_386.go
deleted file mode 100644
index 132f921..0000000
--- a/tun/wintun/setupapi/types_windows_386.go
+++ /dev/null
@@ -1,11 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-const (
- sizeofDevInfoListDetailData uint32 = 550
- sizeofDrvInfoDetailData uint32 = 1570
-)
diff --git a/tun/wintun/setupapi/types_windows_amd64.go b/tun/wintun/setupapi/types_windows_amd64.go
deleted file mode 100644
index d4dd65c..0000000
--- a/tun/wintun/setupapi/types_windows_amd64.go
+++ /dev/null
@@ -1,11 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-const (
- sizeofDevInfoListDetailData uint32 = 560
- sizeofDrvInfoDetailData uint32 = 1584
-)
diff --git a/tun/wintun/setupapi/zsetupapi_windows.go b/tun/wintun/setupapi/zsetupapi_windows.go
deleted file mode 100644
index 375862d..0000000
--- a/tun/wintun/setupapi/zsetupapi_windows.go
+++ /dev/null
@@ -1,398 +0,0 @@
-// Code generated by 'go generate'; DO NOT EDIT.
-
-package setupapi
-
-import (
- "syscall"
- "unsafe"
-
- "golang.org/x/sys/windows"
-)
-
-var _ unsafe.Pointer
-
-// Do the interface allocations only once for common
-// Errno values.
-const (
- errnoERROR_IO_PENDING = 997
-)
-
-var (
- errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
-)
-
-// errnoErr returns common boxed Errno values, to prevent
-// allocations at runtime.
-func errnoErr(e syscall.Errno) error {
- switch e {
- case 0:
- return nil
- case errnoERROR_IO_PENDING:
- return errERROR_IO_PENDING
- }
- // TODO: add more here, after collecting data on the common
- // error values see on Windows. (perhaps when running
- // all.bat?)
- return e
-}
-
-var (
- modsetupapi = windows.NewLazySystemDLL("setupapi.dll")
- modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll")
-
- procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW")
- procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW")
- procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW")
- procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo")
- procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList")
- procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList")
- procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch")
- procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW")
- procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW")
- procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW")
- procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW")
- procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList")
- procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW")
- procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller")
- procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey")
- procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW")
- procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW")
- procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW")
- procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW")
- procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW")
- procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW")
- procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW")
- procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW")
- procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW")
- procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice")
- procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice")
- procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW")
- procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW")
-)
-
-func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
- r0, _, e1 := syscall.Syscall6(procSetupDiCreateDeviceInfoListExW.Addr(), 4, uintptr(unsafe.Pointer(classGUID)), uintptr(hwndParent), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0)
- handle = DevInfo(r0)
- if handle == DevInfo(windows.InvalidHandle) {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInfoListDetailW.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoSetDetailData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall9(procSetupDiCreateDeviceInfoW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(DeviceName)), uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(DeviceDescription)), uintptr(hwndParent), uintptr(CreationFlags), uintptr(unsafe.Pointer(deviceInfoData)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiEnumDeviceInfo.Addr(), 3, uintptr(deviceInfoSet), uintptr(memberIndex), uintptr(unsafe.Pointer(deviceInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiDestroyDeviceInfoList.Addr(), 1, uintptr(deviceInfoSet), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiBuildDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiCancelDriverInfoSearch.Addr(), 1, uintptr(deviceInfoSet), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiEnumDriverInfoW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType), uintptr(memberIndex), uintptr(unsafe.Pointer(driverInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetDriverInfoDetailW.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)), uintptr(unsafe.Pointer(driverInfoDetailData)), uintptr(driverInfoDetailDataSize), uintptr(unsafe.Pointer(requiredSize)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiDestroyDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
- r0, _, e1 := syscall.Syscall9(procSetupDiGetClassDevsExW.Addr(), 7, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(Enumerator)), uintptr(hwndParent), uintptr(Flags), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0)
- handle = DevInfo(r0)
- if handle == DevInfo(windows.InvalidHandle) {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiCallClassInstaller.Addr(), 3, uintptr(installFunction), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) {
- r0, _, e1 := syscall.Syscall6(procSetupDiOpenDevRegKey.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(Scope), uintptr(HwProfile), uintptr(KeyType), uintptr(samDesired))
- key = windows.Handle(r0)
- if key == windows.InvalidHandle {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall9(procSetupDiGetDeviceRegistryPropertyW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyRegDataType)), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), uintptr(unsafe.Pointer(requiredSize)), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiSetDeviceRegistryPropertyW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams)))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiSetClassInstallParamsW.Addr(), 4, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), 0, 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiClassNameFromGuidExW.Addr(), 6, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(className)), uintptr(classNameSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) {
- r1, _, e1 := syscall.Syscall6(procSetupDiClassGuidsFromNameExW.Addr(), 6, uintptr(unsafe.Pointer(className)), uintptr(unsafe.Pointer(classGuidList)), uintptr(classGuidListSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved))
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) {
- r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0)
- if r1 == 0 {
- if e1 != 0 {
- err = errnoErr(e1)
- } else {
- err = syscall.EINVAL
- }
- }
- return
-}
-
-func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) {
- r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0)
- ret = uint32(r0)
- return
-}
-
-func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) {
- r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0)
- ret = uint32(r0)
- return
-}
diff --git a/tun/wintun/setupapi/zsetupapi_windows_test.go b/tun/wintun/setupapi/zsetupapi_windows_test.go
deleted file mode 100644
index 915b427..0000000
--- a/tun/wintun/setupapi/zsetupapi_windows_test.go
+++ /dev/null
@@ -1,20 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package setupapi
-
-import (
- "syscall"
- "testing"
-
- "golang.org/x/sys/windows"
-)
-
-func TestSetupDiDestroyDeviceInfoList(t *testing.T) {
- err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle))
- if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE {
- t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE")
- }
-}
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
deleted file mode 100644
index 4c12d97..0000000
--- a/tun/wintun/wintun_windows.go
+++ /dev/null
@@ -1,803 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
- */
-
-package wintun
-
-import (
- "errors"
- "fmt"
- "strings"
- "time"
- "unsafe"
-
- "golang.org/x/sys/windows"
- "golang.org/x/sys/windows/registry"
-
- "golang.zx2c4.com/wireguard/tun/wintun/iphlpapi"
- "golang.zx2c4.com/wireguard/tun/wintun/nci"
- registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
- "golang.zx2c4.com/wireguard/tun/wintun/setupapi"
-)
-
-type Pool string
-
-type Interface struct {
- cfgInstanceID windows.GUID
- devInstanceID string
- luidIndex uint32
- ifType uint32
- pool Pool
-}
-
-var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
-var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}}
-
-const (
- hardwareID = "Wintun"
- waitForRegistryTimeout = time.Second * 10
-)
-
-// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
-func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
- // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
- key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
- if err != nil {
- return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
- }
- defer key.Close()
-
- // Read the NetCfgInstanceId value.
- valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId")
- if err != nil {
- return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err)
- }
-
- // Convert to GUID.
- ifid, err := windows.GUIDFromString(valueStr)
- if err != nil {
- return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr)
- }
-
- // Read the NetLuidIndex value.
- luidIdx, _, err := key.GetIntegerValue("NetLuidIndex")
- if err != nil {
- return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err)
- }
-
- // Read the NetLuidIndex value.
- ifType, _, err := key.GetIntegerValue("*IfType")
- if err != nil {
- return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
- }
-
- instanceID, err := devInfo.DeviceInstanceID(devInfoData)
- if err != nil {
- return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
- }
-
- return &Interface{
- cfgInstanceID: ifid,
- devInstanceID: instanceID,
- luidIndex: uint32(luidIdx),
- ifType: uint32(ifType),
- pool: pool,
- }, nil
-}
-
-func removeNumberedSuffix(ifname string) string {
- removed := strings.TrimRight(ifname, "0123456789")
- if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' {
- return removed[:len(removed)-1]
- }
- return ifname
-}
-
-// GetInterface finds a Wintun interface by its name. This function returns
-// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If
-// the interface is found but not a Wintun-class or a member of the pool,
-// this function returns windows.ERROR_ALREADY_EXISTS.
-func (pool Pool) GetInterface(ifname string) (*Interface, error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- return nil, err
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- // Create a list of network devices.
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
- }
- defer devInfo.Close()
-
- // Windows requires each interface to have a different name. When
- // enforcing this, Windows treats interface names case-insensitive. If an
- // interface "FooBar" exists and this function reports there is no
- // interface "foobar", an attempt to create a new interface and name it
- // "foobar" would cause conflict with "FooBar".
- ifname = strings.ToLower(ifname)
-
- for index := 0; ; index++ {
- devInfoData, err := devInfo.EnumDeviceInfo(index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
- property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
- if err != nil {
- continue
- }
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
- continue
- }
-
- wintun, err := makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- continue
- }
-
- // TODO: is there a better way than comparing ifnames?
- ifname2, err := wintun.Name()
- if err != nil {
- continue
- }
- ifname2 = strings.ToLower(ifname2)
- ifname3 := removeNumberedSuffix(ifname2)
-
- if ifname == ifname2 || ifname == ifname3 {
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
- if err != nil {
- return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- for index := 0; ; index++ {
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Get driver info details.
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- if driverDetailData.IsCompatible(hardwareID) {
- isMember, err := pool.isMember(devInfo, devInfoData)
- if err != nil {
- return nil, err
- }
- if !isMember {
- return nil, windows.ERROR_ALREADY_EXISTS
- }
-
- return wintun, nil
- }
- }
-
- // This interface is not using Wintun driver.
- return nil, windows.ERROR_ALREADY_EXISTS
- }
- }
-
- return nil, windows.ERROR_OBJECT_NOT_FOUND
-}
-
-// CreateInterface creates a Wintun interface. ifname is the requested name of
-// the interface, while requestedGUID is the GUID of the created network
-// interface, which then influences NLA generation deterministically. If it is
-// set to nil, the GUID is chosen by the system at random, and hence a new NLA
-// entry is created for each new interface. It is called "requested" GUID
-// because the API it uses is completely undocumented, and so there could be minor
-// interesting complications with its usage. This function returns the network
-// interface ID and a flag if reboot is required.
-func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- return
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- // Create an empty device info set for network adapter device class.
- devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
- if err != nil {
- err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
- return
- }
- defer devInfo.Close()
-
- // Get the device class name from GUID.
- className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
- if err != nil {
- err = fmt.Errorf("SetupDiClassNameFromGuidEx(%v) failed: %v", deviceClassNetGUID, err)
- return
- }
-
- // Create a new device info element and add it to the device info set.
- deviceTypeName := pool.deviceTypeName()
- devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
- if err != nil {
- err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
- return
- }
-
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- err = fmt.Errorf("Setting quiet installation failed: %v", err)
- return
- }
-
- // Set a device information element as the selected member of a device information set.
- err = devInfo.SetSelectedDevice(devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
- return
- }
-
- // Set Plug&Play device hardware ID property.
- err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
- if err != nil {
- err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
- return
- }
-
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
- if err != nil {
- err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
- return
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- driverDate := windows.Filetime{}
- driverVersion := uint64(0)
- for index := 0; ; index++ { // TODO: This loop takes ~600ms
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
- if driverData.IsNewer(driverDate, driverVersion) {
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- if driverDetailData.IsCompatible(hardwareID) {
- err := devInfo.SetSelectedDriver(devInfoData, driverData)
- if err != nil {
- continue
- }
-
- driverDate = driverData.DriverDate
- driverVersion = driverData.DriverVersion
- }
- }
- }
-
- if driverVersion == 0 {
- err = fmt.Errorf("No driver for device %q installed", hardwareID)
- return
- }
-
- defer func() {
- if err != nil {
- // The interface failed to install, or the interface ID was unobtainable. Clean-up.
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
-
- // Set class installer parameters for DIF_REMOVE.
- if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
- // Call appropriate class installer.
- if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
- rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
- }
- }
-
- wintun = nil
- }
- }()
-
- // Call appropriate class installer.
- err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
- return
- }
-
- // Register device co-installers if any. (Ignore errors)
- devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
-
- var netDevRegKey registry.Key
- const pollTimeout = time.Millisecond * 50
- for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
- if i != 0 {
- time.Sleep(pollTimeout)
- }
- netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
- if err == nil {
- break
- }
- }
- if err != nil {
- err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
- return
- }
- defer netDevRegKey.Close()
- if requestedGUID != nil {
- err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
- if err != nil {
- err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
- return
- }
- }
-
- // Install interfaces if any. (Ignore errors)
- devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
-
- // Install the device.
- err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
- if err != nil {
- err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
- return
- }
- rebootRequired = checkReboot(devInfo, devInfoData)
-
- err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
- if err != nil {
- err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
- return
- }
-
- // DIF_INSTALLDEVICE returns almost immediately, while the device installation
- // continues in the background. It might take a while, before all registry
- // keys and values are populated.
- _, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)
- return
- }
- _, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err)
- return
- }
- _, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err)
- return
- }
-
- // Get network interface.
- wintun, err = makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- err = fmt.Errorf("makeWintun failed: %v", err)
- return
- }
-
- // Wait for TCP/IP adapter registry key to emerge and populate.
- tcpipAdapterRegKey, err := registryEx.OpenKeyWait(
- registry.LOCAL_MACHINE,
- wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
- waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err)
- return
- }
- defer tcpipAdapterRegKey.Close()
- _, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err)
- return
- }
-
- tcpipInterfaceRegKeyName, err := wintun.tcpipInterfaceRegKeyName()
- if err != nil {
- err = fmt.Errorf("tcpipInterfaceRegKeyName failed: %v", err)
- return
- }
-
- // Wait for TCP/IP interface registry key to emerge.
- tcpipInterfaceRegKey, err := registryEx.OpenKeyWait(
- registry.LOCAL_MACHINE,
- tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE,
- waitForRegistryTimeout)
- if err != nil {
- err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err)
- return
- }
- defer tcpipInterfaceRegKey.Close()
- // Disable dead gateway detection on our interface.
- tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0)
-
- err = wintun.SetName(ifname)
- if err != nil {
- err = fmt.Errorf("Unable to set name of Wintun interface: %v", err)
- return
- }
-
- return
-}
-
-// DeleteInterface deletes a Wintun interface. This function succeeds
-// if the interface was not found. It returns a bool indicating whether
-// a reboot is required.
-func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
- devInfo, devInfoData, err := wintun.devInfoData()
- if err == windows.ERROR_OBJECT_NOT_FOUND {
- return false, nil
- }
- if err != nil {
- return false, err
- }
- defer devInfo.Close()
-
- // Remove the device.
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
-
- // Set class installer parameters for DIF_REMOVE.
- err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
- if err != nil {
- return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
- }
-
- // Call appropriate class installer.
- err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
- if err != nil {
- return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
- }
-
- return checkReboot(devInfo, devInfoData), nil
-}
-
-// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
-// given criteria, and returns which ones it deleted, whether a reboot
-// is required after, and which errors occurred during the process.
-func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
- mutex, err := pool.takeNameMutex()
- if err != nil {
- errors = append(errors, err)
- return
- }
- defer func() {
- windows.ReleaseMutex(mutex)
- windows.CloseHandle(mutex)
- }()
-
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
- }
- defer devInfo.Close()
-
- for i := 0; ; i++ {
- devInfoData, err := devInfo.EnumDeviceInfo(i)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
- property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
- if err != nil {
- continue
- }
- if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
- continue
- }
-
- err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
- if err != nil {
- continue
- }
- defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
-
- isWintun := false
- for j := 0; ; j++ {
- driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
- driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
- if err != nil {
- continue
- }
- if driverDetailData.IsCompatible(hardwareID) {
- isWintun = true
- break
- }
- }
- if !isWintun {
- continue
- }
-
- isMember, err := pool.isMember(devInfo, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
- if !isMember {
- continue
- }
-
- wintun, err := makeWintun(devInfo, devInfoData, pool)
- if err != nil {
- errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
- continue
- }
- if !matches(wintun) {
- continue
- }
-
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
-
- inst := devInfoData.DevInst
- removeDeviceParams := setupapi.RemoveDeviceParams{
- ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
- Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
- }
- err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
- if err != nil {
- errors = append(errors, err)
- continue
- }
- err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
- if err != nil {
- errors = append(errors, err)
- continue
- }
- rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
- deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
- }
- return
-}
-
-// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
-func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
- deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
- if err != nil {
- return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
- }
- deviceDesc, _ := deviceDescVal.(string)
- friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
- if err != nil {
- return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
- }
- friendlyName, _ := friendlyNameVal.(string)
- deviceTypeName := pool.deviceTypeName()
- return friendlyName == deviceTypeName || deviceDesc == deviceTypeName ||
- removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil
-}
-
-// checkReboot checks device install parameters if a system reboot is required.
-func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
- devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
- if err != nil {
- return false
- }
-
- return (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0
-}
-
-// setQuietInstall sets device install parameters for a quiet installation
-func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
- devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
- if err != nil {
- return err
- }
-
- devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
- return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
-}
-
-// deviceTypeName returns pool-specific device type name.
-func (pool Pool) deviceTypeName() string {
- return fmt.Sprintf("%s Tunnel", pool)
-}
-
-// Name returns the name of the Wintun interface.
-func (wintun *Interface) Name() (string, error) {
- return nci.ConnectionName(&wintun.cfgInstanceID)
-}
-
-// SetName sets name of the Wintun interface.
-func (wintun *Interface) SetName(ifname string) error {
- const maxSuffix = 1000
- availableIfname := ifname
- for i := 0; ; i++ {
- err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
- if err == windows.ERROR_DUP_NAME {
- duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname)
- if err2 == nil {
- for j := 0; j < maxSuffix; j++ {
- proposal := fmt.Sprintf("%s %d", ifname, j+1)
- if proposal == availableIfname {
- continue
- }
- err2 = nci.SetConnectionName(duplicateGuid, proposal)
- if err2 == windows.ERROR_DUP_NAME {
- continue
- }
- if err2 == nil {
- err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
- if err == nil {
- break
- }
- }
- break
- }
- }
- }
- if err == nil {
- break
- }
-
- if i > maxSuffix || err != windows.ERROR_DUP_NAME {
- return fmt.Errorf("NciSetConnectionName failed: %v", err)
- }
- availableIfname = fmt.Sprintf("%s %d", ifname, i+1)
- }
-
- // TODO: This should use NetSetup2 so that it doesn't get unset.
- deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE)
- if err != nil {
- return fmt.Errorf("Device-level registry key open failed: %v", err)
- }
- defer deviceRegKey.Close()
- err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName())
- if err != nil {
- return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err)
- }
- return nil
-}
-
-// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name.
-func (wintun *Interface) tcpipAdapterRegKeyName() string {
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
-}
-
-// deviceRegKeyName returns the device-level registry key name.
-func (wintun *Interface) deviceRegKeyName() string {
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
-}
-
-// Version returns the version of the Wintun driver and NDIS system currently loaded.
-func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
- if err != nil {
- return
- }
- defer key.Close()
- driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
- if err != nil {
- return
- }
- driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
- if err != nil {
- return
- }
- ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
- if err != nil {
- return
- }
- ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
- if err != nil {
- return
- }
- driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
- ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
- return
-}
-
-// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
-func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
- if err != nil {
- return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
- }
- paths, _, err := key.GetStringsValue("IpConfig")
- key.Close()
- if err != nil {
- return "", fmt.Errorf("Error reading IpConfig registry key: %v", err)
- }
- if len(paths) == 0 {
- return "", errors.New("No TCP/IP interfaces found on adapter")
- }
- return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
-}
-
-// devInfoData returns TUN device info list handle and interface device info
-// data. The device info list handle must be closed after use. In case the
-// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
-func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
- // Create a list of network devices.
- devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
- if err != nil {
- return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
- }
-
- for index := 0; ; index++ {
- devInfoData, err := devInfo.EnumDeviceInfo(index)
- if err != nil {
- if err == windows.ERROR_NO_MORE_ITEMS {
- break
- }
- continue
- }
-
- // Get interface ID.
- // TODO: Store some ID in the Wintun object such that this call isn't required.
- wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
- if err != nil {
- continue
- }
-
- if wintun.cfgInstanceID == wintun2.cfgInstanceID {
- err = setQuietInstall(devInfo, devInfoData)
- if err != nil {
- devInfo.Close()
- return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
- }
- return devInfo, devInfoData, nil
- }
- }
-
- devInfo.Close()
- return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
-}
-
-// handle returns a handle to the interface device object.
-func (wintun *Interface) handle() (windows.Handle, error) {
- interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
- if err != nil {
- return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err)
- }
- handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
- if err != nil {
- return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err)
- }
- return handle, nil
-}
-
-// GUID returns the GUID of the interface.
-func (wintun *Interface) GUID() windows.GUID {
- return wintun.cfgInstanceID
-}
-
-// LUID returns the LUID of the interface.
-func (wintun *Interface) LUID() uint64 {
- return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
-}