diff options
Diffstat (limited to 'tun')
42 files changed, 4127 insertions, 4590 deletions
diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go new file mode 100644 index 0000000..67a785e --- /dev/null +++ b/tun/alignment_windows_test.go @@ -0,0 +1,67 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "reflect" + "testing" + "unsafe" +) + +func checkAlignment(t *testing.T, name string, offset uintptr) { + t.Helper() + if offset%8 != 0 { + t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) + } +} + +// TestRateJugglerAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestRateJugglerAlignment(t *testing.T) { + var r rateJuggler + + typ := reflect.TypeOf(&r).Elem() + t.Logf("Peer type size: %d, with fields:", typ.Size()) + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", + field.Name, + field.Offset, + field.Type.Size(), + field.Type.Align(), + ) + } + + checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) + checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) + checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) +} + +// TestNativeTunAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestNativeTunAlignment(t *testing.T) { + var tun NativeTun + + typ := reflect.TypeOf(&tun).Elem() + t.Logf("Peer type size: %d, with fields:", typ.Size()) + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", + field.Name, + field.Offset, + field.Type.Size(), + field.Type.Align(), + ) + } + + checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) +} diff --git a/tun/checksum.go b/tun/checksum.go new file mode 100644 index 0000000..29a8fc8 --- /dev/null +++ b/tun/checksum.go @@ -0,0 +1,118 @@ +package tun + +import "encoding/binary" + +// TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. +func checksumNoFold(b []byte, initial uint64) uint64 { + ac := initial + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + return ac +} + +func checksum(b []byte, initial uint64) uint16 { + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumNoFold(srcAddr, 0) + sum = checksumNoFold(dstAddr, sum) + sum = checksumNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumNoFold(tmp, sum) +} diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 0000000..c1ccff5 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} diff --git a/tun/errors.go b/tun/errors.go new file mode 100644 index 0000000..75ae3a4 --- /dev/null +++ b/tun/errors.go @@ -0,0 +1,12 @@ +package tun + +import ( + "errors" +) + +var ( + // ErrTooManySegments is returned by Device.Read() when segmentation + // overflows the length of supplied buffers. This error should not cause + // reads to cease. + ErrTooManySegments = errors.New("too many segments") +) diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go new file mode 100644 index 0000000..ccd32ed --- /dev/null +++ b/tun/netstack/examples/http_client.go @@ -0,0 +1,54 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "io" + "log" + "net/http" + "net/netip" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.28")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + 1420) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379 +public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28 +allowed_ip=0.0.0.0/0 +endpoint=127.0.0.1:58120 +`) + err = dev.Up() + if err != nil { + log.Panic(err) + } + + client := http.Client{ + Transport: &http.Transport{ + DialContext: tnet.DialContext, + }, + } + resp, err := client.Get("http://192.168.4.29/") + if err != nil { + log.Panic(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Panic(err) + } + log.Println(string(body)) +} diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go new file mode 100644 index 0000000..f5b7a8f --- /dev/null +++ b/tun/netstack/examples/http_server.go @@ -0,0 +1,51 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "io" + "log" + "net" + "net/http" + "net/netip" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, + 1420, + ) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641 +listen_port=58120 +public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c +allowed_ip=192.168.4.28/32 +persistent_keepalive_interval=25 +`) + dev.Up() + listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) + if err != nil { + log.Panicln(err) + } + http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) + io.WriteString(writer, "Hello from userspace TCP!") + }) + err = http.Serve(listener, nil) + if err != nil { + log.Panicln(err) + } +} diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go new file mode 100644 index 0000000..2eef0fb --- /dev/null +++ b/tun/netstack/examples/ping_client.go @@ -0,0 +1,75 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "bytes" + "log" + "math/rand" + "net/netip" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + 1420) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f +public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b +endpoint=163.172.161.0:12912 +allowed_ip=0.0.0.0/0 +`) + err = dev.Up() + if err != nil { + log.Panic(err) + } + + socket, err := tnet.Dial("ping4", "zx2c4.com") + if err != nil { + log.Panic(err) + } + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: []byte("gopher burrow"), + } + icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + socket.SetReadDeadline(time.Now().Add(time.Second * 10)) + start := time.Now() + _, err = socket.Write(icmpBytes) + if err != nil { + log.Panic(err) + } + n, err := socket.Read(icmpBytes[:]) + if err != nil { + log.Panic(err) + } + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + log.Panic(err) + } + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + log.Panicf("invalid reply type: %v", replyPacket) + } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + log.Panicf("invalid ping reply: %v", replyPing) + } + log.Printf("Ping latency: %v", time.Since(start)) +} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go new file mode 100644 index 0000000..2b73054 --- /dev/null +++ b/tun/netstack/tun.go @@ -0,0 +1,1055 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package netstack + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "golang.zx2c4.com/wireguard/tun" + + "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool +} + +type Net netTun + +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 10), + incomingPacket: make(chan *buffer.View), + dnsServers: dnsServers, + mtu: mtu, + } + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + dev.ep.AddNotify(dev) + tcpipErr = dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), nil +} + +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +func (tun *netTun) File() *os.File { + return nil +} + +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt.IsNil() { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + if tun.events != nil { + close(tun.events) + } + + tun.ep.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *netTun) BatchSize() int { + return 1 +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) +} + +func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) +} + +func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { + if addr == nil { + return net.ListenTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} + +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + ip, _ := netip.AddrFromSlice(laddr.IP) + la = netip.AddrPortFrom(ip, uint16(laddr.Port)) + } + if raddr != nil { + ip, _ := netip.AddrFromSlice(raddr.IP) + ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + deadline *time.Timer +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (ia PingAddr) Addr() netip.Addr { + return ia.addr +} + +func PingAddrFromAddr(addr netip.Addr) *PingAddr { + return &PingAddr{addr} +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + if !laddr.IsValid() && !raddr.IsValid() { + return nil, errors.New("ping dial: invalid address") + } + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{ + laddr: PingAddr{laddr}, + deadline: time.NewTimer(time.Hour << 10), + } + pc.deadline.Stop() + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { + return net.DialPingAddr(laddr, netip.Addr{}) +} + +func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { + var la, ra netip.Addr + if laddr != nil { + la = laddr.addr + } + if raddr != nil { + ra = raddr.addr + } + return net.DialPingAddr(la, ra) +} + +func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { + var la netip.Addr + if laddr != nil { + la = laddr.addr + } + return net.ListenPingAddr(la) +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + pc.deadline.Reset(0) + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var na netip.Addr + switch v := addr.(type) { + case *PingAddr: + na = v.addr + case *net.IPAddr: + na, _ = netip.AddrFromSlice(v.IP) + default: + return 0, fmt.Errorf("ping write: wrong net.Addr type") + } + if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + buf := bytes.NewReader(p) + rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, &pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) + defer pc.wq.EventUnregister(&e) + + select { + case <-pc.deadline.C: + return 0, nil, os.ErrDeadlineExceeded + case <-notifyCh: + } + + w := tcpip.SliceWriter(p) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) + return res.Count, &PingAddr{remoteAddr}, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.deadline.Reset(time.Until(t)) + return nil +} + +var ( + errNoSuchHost = errors.New("no such host") + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") +) + +func (net *Net) LookupHost(host string) (addrs []string, err error) { + return net.LookupContextHost(context.Background(), host) +} + +func isDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} + +func randU16() uint16 { + var b [2]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return binary.LittleEndian.Uint16(b[:]) +} + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = randU16() + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func equalASCIIName(x, y dnsmessage.Name) bool { + if x.Length != y.Length { + return false + } + for i := 0; i < int(x.Length); i++ { + a := x.Data[i] + b := y.Data[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 512) + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 1280) + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + + for _, useUDP := range []bool{true, false} { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + var c net.Conn + var err error + if useUDP { + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) + } else { + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) + } + + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + err := c.SetDeadline(d) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + } + var p dnsmessage.Parser + var h dnsmessage.Header + if useUDP { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < 2; i++ { + for _, server := range tnet.dnsServers { + p, h, err := tnet.exchange(ctx, server, q, time.Second*5) + if err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + if _, ok := err.(*net.OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + dnsErr.IsNotFound = true + return p, server.String(), dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server.String(), nil + } + lastErr = &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errNoSuchHost { + lastErr.(*net.DNSError).IsNotFound = true + return p, server.String(), lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { + if host == "" || (!tnet.hasV6 && !tnet.hasV4) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + zlen := len(host) + if strings.IndexByte(host, ':') != -1 { + if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { + zlen = zidx + } + } + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil + } + + if !isDomainName(host) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + type result struct { + p dnsmessage.Parser + server string + error + } + var addrsV4, addrsV6 []netip.Addr + lanes := 0 + if tnet.hasV4 { + lanes++ + } + if tnet.hasV6 { + lanes++ + } + lane := make(chan result, lanes) + var lastErr error + if tnet.hasV4 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) + lane <- result{p, server, err} + }() + } + if tnet.hasV6 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) + lane <- result{p, server, err} + }() + } + for l := 0; l < lanes; l++ { + result := <-lane + if result.error != nil { + if lastErr == nil { + lastErr = result.error + } + continue + } + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + continue + } + } + } + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled + var addrs []netip.Addr + if tnet.hasV6 { + addrs = append(addrsV6, addrsV4...) + } else { + addrs = append(addrsV4, addrsV6...) + } + + if len(addrs) == 0 && lastErr != nil { + return nil, lastErr + } + saddrs := make([]string, 0, len(addrs)) + for _, ip := range addrs { + saddrs = append(saddrs, ip.String()) + } + return saddrs, nil +} + +func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { + if deadline.IsZero() { + return deadline, nil + } + timeRemaining := deadline.Sub(now) + if timeRemaining <= 0 { + return time.Time{}, errTimeout + } + timeout := timeRemaining / time.Duration(addrsRemaining) + const saneMinimum = 2 * time.Second + if timeout < saneMinimum { + if timeRemaining < saneMinimum { + timeout = timeRemaining + } else { + timeout = saneMinimum + } + } + return now.Add(timeout), nil +} + +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + +func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if ctx == nil { + panic("nil context") + } + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { + acceptV4 = true + acceptV6 = true + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 + } + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } + } + allAddr, err := tnet.LookupContextHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + var addrs []netip.AddrPort + for _, addr := range allAddr { + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} + } + + var firstErr error + for i, addr := range addrs { + select { + case <-ctx.Done(): + err := ctx.Err() + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return nil, &net.OpError{Op: "dial", Err: err} + default: + } + + dialCtx := ctx + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) + if err != nil { + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: err} + } + break + } + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() + } + } + + var c net.Conn + switch matches[1] { + case "tcp": + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) + } + if err == nil { + return c, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} + } + return nil, firstErr +} + +func (tnet *Net) Dial(network, address string) (net.Conn, error) { + return tnet.DialContext(context.Background(), network, address) +} diff --git a/tun/offload_linux.go b/tun/offload_linux.go new file mode 100644 index 0000000..9ff7fea --- /dev/null +++ b/tun/offload_linux.go @@ -0,0 +1,993 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. +type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), + itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && canUDPGRO { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. + item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} + +// gsoSplit splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. +func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if !isV6 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if !isV6 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go new file mode 100644 index 0000000..ae55c8c --- /dev/null +++ b/tun/offload_linux_test.go @@ -0,0 +1,752 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + tcpH := header.TCP(b[offset+20:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.IdealBatchSize) + sizes := make([]int, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + canUDPGRO bool + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + true, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, + }, + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + true, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + true, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + true, + []int{0}, + []int{340}, + false, + }, + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} + +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] + + tests := []struct { + name string + b []byte + canUDPGRO bool + want groCandidateType + }{ + { + "tcp4", + tcp4, + true, + tcp4GROCandidate, + }, + { + "tcp6", + tcp6, + true, + tcp6GROCandidate, + }, + { + "udp4", + udp4, + true, + udp4GROCandidate, + }, + { + "udp4 no support", + udp4, + false, + notGROCandidate, + }, + { + "udp6", + udp6, + true, + udp6GROCandidate, + }, + { + "udp6 no support", + udp6, + false, + notGROCandidate, + }, + { + "udp4 too short", + udp4TooShort, + true, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + true, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + true, + notGROCandidate, + }, + { + "tcp6 too short", + tcp6TooShort, + true, + notGROCandidate, + }, + { + "invalid IP version", + []byte{0x00}, + true, + notGROCandidate, + }, + { + "invalid IP header len", + ip4InvalidHeaderLen, + true, + notGROCandidate, + }, + { + "ip4 invalid protocol", + ip4InvalidProtocol, + true, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + true, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/operateonfd.go b/tun/operateonfd.go index 31747a2..f1beb6d 100644 --- a/tun/operateonfd.go +++ b/tun/operateonfd.go @@ -1,8 +1,8 @@ -// +build !windows +//go:build darwin || freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -18,12 +18,36 @@ const ( ) type Device interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - Flush() error // flush all previous writes to the device - MTU() (int, error) // returns the MTU of the device - Name() (string, error) // fetches and returns the current name - Events() chan Event // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel + // File returns the file descriptor of the device. + File() *os.File + + // Read one or more packets from the Device (without any additional headers). + // On a successful read it returns the number of packets read, and sets + // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // A nonzero offset can be used to instruct the Device on where to begin + // reading into each element of the bufs slice. + Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + + // Write one or more packets to the device (without any additional headers). + // On a successful write it returns the number of packets written. A nonzero + // offset can be used to instruct the Device on where to begin writing from + // each packet contained within the bufs slice. + Write(bufs [][]byte, offset int) (int, error) + + // MTU returns the MTU of the Device. + MTU() (int, error) + + // Name returns the current name of the Device. + Name() (string, error) + + // Events returns a channel of type Event, which is fed Device events. + Events() <-chan Event + + // Close stops the Device and closes the Event channel. + Close() error + + // BatchSize returns the preferred/max number of packets that can be read or + // written in a single read/write call. BatchSize must not change over the + // lifetime of a Device. + BatchSize() int } diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index 6d2e6dd..c9a6c0b 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -1,46 +1,46 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( + "errors" "fmt" - "io/ioutil" + "io" "net" "os" + "sync" "syscall" + "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) const utunControlName = "com.apple.net.utun_control" -// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h -const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3 - -// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type NativeTun struct { name string tunFile *os.File events chan Event errors chan error routeSocket int + closeOnce sync.Once } -var sockaddrCtlSize uintptr = 32 +func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { + for i := 0; i < 20; i++ { + iface, err = net.InterfaceByIndex(index) + if err != nil && errors.Is(err, unix.ENOMEM) { + time.Sleep(time.Duration(i) * time.Second / 3) + continue + } + return iface, err + } + return nil, err +} func (tun *NativeTun) routineRouteListener(tunIfindex int) { var ( @@ -55,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR { goto retry } tun.errors <- err @@ -74,7 +74,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { continue } - iface, err := net.InterfaceByIndex(ifindex) + iface, err := retryInterfaceByIndex(ifindex) if err != nil { tun.errors <- err return @@ -107,53 +107,33 @@ func CreateTUN(name string, mtu int) (Device, error) { } } - fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) - + fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) if err != nil { return nil, err } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} - - copy(ctlInfo.ctlName[:], []byte(utunControlName)) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_CTLIOCGINFO), - uintptr(unsafe.Pointer(ctlInfo)), - ) - - if errno != 0 { - return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno) + ctlInfo := &unix.CtlInfo{} + copy(ctlInfo.Name[:], []byte(utunControlName)) + err = unix.IoctlCtlInfo(fd, ctlInfo) + if err != nil { + unix.Close(fd) + return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: 2, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, + sc := &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, } - scPointer := unsafe.Pointer(&sc) - - _, _, errno = unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(scPointer), - uintptr(sockaddrCtlSize), - ) - - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + err = unix.Connect(fd, sc) + if err != nil { + unix.Close(fd) + return nil, err } - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { + unix.Close(fd) return nil, err } tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) @@ -161,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) { if err == nil && name == "utun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { - ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) + os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } @@ -193,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -213,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { } func (tun *NativeTun) Name() (string, error) { - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(16) - - var errno syscall.Errno + var err error tun.operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall6( - unix.SYS_GETSOCKOPT, - fd, + tun.name, err = unix.GetsockoptString( + int(fd), 2, /* #define SYSPROTO_CONTROL 2 */ 2, /* #define UTUN_OPT_IFNAME 2 */ - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) + ) }) - if errno != 0 { - return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno) + if err != nil { + return "", fmt.Errorf("GetSockoptString: %w", err) } - tun.name = string(ifName.name[:ifNameSize-1]) return tun.name, nil } @@ -241,61 +213,63 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + // TODO: the BSDs look very similar in Read() and Write(). They should be + // collapsed, with platform-specific files containing the varying parts of + // their implementations. select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -303,71 +277,60 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) setMTU(n int) error { - - // open datagram socket - - var fd int - - fd, err := unix.Socket( + fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) - if err != nil { return err } defer unix.Close(fd) - // do ioctl call - - var ifr [32]byte - copy(ifr[:], tun.name) - *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("failed to set MTU on %s", tun.name) + var ifr unix.IfreqMTU + copy(ifr.Name[:], tun.name) + ifr.MTU = int32(n) + err = unix.IoctlSetIfreqMTU(fd, &ifr) + if err != nil { + return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) } return nil } func (tun *NativeTun) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( + fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) - if err != nil { return 0, err } defer unix.Close(fd) - // do ioctl call - - var ifr [64]byte - copy(ifr[:], tun.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, fmt.Errorf("failed to get MTU on %s", tun.name) + ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) + if err != nil { + return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) } - return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil + return int(ifr.MTU), nil +} + +func (tun *NativeTun) BatchSize() int { + return 1 +} + +func socketCloexec(family, sotype, proto int) (fd int, err error) { + // See go/src/net/sys_cloexec.go for background. + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + + fd, err = unix.Socket(family, sotype, proto) + if err == nil { + unix.CloseOnExec(fd) + } + return } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 6cf9313..7c65fd9 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -1,66 +1,57 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( - "bytes" "errors" "fmt" + "io" "net" "os" + "sync" "syscall" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) -// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h -// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96)) const ( _TUNSIFHEAD = 0x80047460 _TUNSIFMODE = 0x8004745e + _TUNGIFNAME = 0x4020745d _TUNSIFPID = 0x2000745f -) -// TODO: move into x/sys/unix -const ( - SIOCGIFINFO_IN6 = 0xc048696c - SIOCSIFINFO_IN6 = 0xc048696d - ND6_IFF_AUTO_LINKLOCAL = 0x20 - ND6_IFF_NO_DAD = 0x100 + _SIOCGIFINFO_IN6 = 0xc048696c + _SIOCSIFINFO_IN6 = 0xc048696d + _ND6_IFF_AUTO_LINKLOCAL = 0x20 + _ND6_IFF_NO_DAD = 0x100 ) -// Iface status string max len -const _IFSTATMAX = 800 - -const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1) +// Iface requests with just the name +type ifreqName struct { + Name [unix.IFNAMSIZ]byte + _ [16]byte +} -// structure for iface requests with a pointer -type ifreq_ptr struct { +// Iface requests with a pointer +type ifreqPtr struct { Name [unix.IFNAMSIZ]byte Data uintptr - Pad0 [16 - SIZEOF_UINTPTR]byte + _ [16 - unsafe.Sizeof(uintptr(0))]byte } -// Structure for iface mtu get/set ioctls -type ifreq_mtu struct { +// Iface requests with MTU +type ifreqMtu struct { Name [unix.IFNAMSIZ]byte MTU uint32 - Pad0 [12]byte -} - -// Structure for interface status request ioctl -type ifstat struct { - IfsName [unix.IFNAMSIZ]byte - Ascii [_IFSTATMAX]byte + _ [12]byte } -// Structures for nd6 flag manipulation -type in6_ndireq struct { +// ND6 flag manipulation +type nd6Req struct { Name [unix.IFNAMSIZ]byte Linkmtu uint32 Maxmtu uint32 @@ -82,6 +73,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -97,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + if errors.Is(err, syscall.EINTR) { goto retry } tun.errors <- err @@ -141,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } func tunName(fd uintptr) (string, error) { - //Terrible hack to make up for freebsd not having a TUNGIFNAME - - //First, make sure the tun pid matches this proc's pid - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_TUNSIFPID), - uintptr(0), - ) - - if errno != 0 { - return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error()) - } - - // Open iface control socket - - confd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { + var ifreq ifreqName + _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq))) + if err != 0 { return "", err } - - defer unix.Close(confd) - - procPid := os.Getpid() - - //Try to find interface with matching PID - for i := 1; ; i++ { - iface, _ := net.InterfaceByIndex(i) - if err != nil || iface == nil { - break - } - - // Structs for getting data in and out of SIOCGIFSTATUS ioctl - var ifstatus ifstat - copy(ifstatus.IfsName[:], iface.Name) - - // Make the syscall to get the status string - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd), - uintptr(unix.SIOCGIFSTATUS), - uintptr(unsafe.Pointer(&ifstatus)), - ) - - if errno != 0 { - continue - } - - nullStr := ifstatus.Ascii[:] - i := bytes.IndexByte(nullStr, 0) - if i < 1 { - continue - } - statStr := string(nullStr[:i]) - var pidNum int = 0 - - // Finally get the owning PID - // Format string taken from sys/net/if_tun.c - _, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum) - if err != nil { - continue - } - - if pidNum == procPid { - return iface.Name, nil - } - } - - return "", nil + return unix.ByteSliceToString(ifreq.Name[:]), nil } // Destroy a named system interface func tunDestroy(name string) error { - // Open control socket. - var fd int - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } @@ -233,14 +151,9 @@ func tunDestroy(name string) error { var ifr [32]byte copy(ifr[:], name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCIFDESTROY), - uintptr(unsafe.Pointer(&ifr[0])), - ) + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0]))) if errno != 0 { - return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) + return fmt.Errorf("failed to destroy interface %s: %w", name, errno) } return nil @@ -257,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, fmt.Errorf("interface %s already exists", name) } - tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0) + tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { return nil, err } @@ -276,103 +189,94 @@ func CreateTUN(name string, mtu int) (Device, error) { ifheadmode := 1 var errno syscall.Errno tun.operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - fd, - uintptr(_TUNSIFHEAD), - uintptr(unsafe.Pointer(&ifheadmode)), - ) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode))) }) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno) + return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno) } - // Open control sockets - confd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - if err != nil { + // Get out of PTP mode. + ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST + tun.operateOnFd(func(fd uintptr) { + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags))) + }) + + if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, err + return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno) } - defer unix.Close(confd) - confd6, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) + + // Disable link-local v6, not just because WireGuard doesn't do that anyway, but + // also because there are serious races with attaching and detaching LLv6 addresses + // in relation to interface lifetime within the FreeBSD kernel. + confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { tunFile.Close() tunDestroy(assignedName) return nil, err } defer unix.Close(confd6) - - // Disable link-local v6, not just because WireGuard doesn't do that anyway, but - // also because there are serious races with attaching and detaching LLv6 addresses - // in relation to interface lifetime within the FreeBSD kernel. - var ndireq in6_ndireq + var ndireq nd6Req copy(ndireq.Name[:], assignedName) - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd6), - uintptr(SIOCGIFINFO_IN6), - uintptr(unsafe.Pointer(&ndireq)), - ) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno) + return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno) } - ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL - ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd6), - uintptr(SIOCSIFINFO_IN6), - uintptr(unsafe.Pointer(&ndireq)), - ) + ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL + ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno) + return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno) } - // Rename the interface - var newnp [unix.IFNAMSIZ]byte - copy(newnp[:], name) - var ifr ifreq_ptr - copy(ifr.Name[:], assignedName) - ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd), - uintptr(unix.SIOCSIFNAME), - uintptr(unsafe.Pointer(&ifr)), - ) - if errno != 0 { - tunFile.Close() - tunDestroy(assignedName) - return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno) + if name != "" { + confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) + if err != nil { + tunFile.Close() + tunDestroy(assignedName) + return nil, err + } + defer unix.Close(confd) + var newnp [unix.IFNAMSIZ]byte + copy(newnp[:], name) + var ifr ifreqPtr + copy(ifr.Name[:], assignedName) + ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + tunFile.Close() + tunDestroy(assignedName) + return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno) + } } return CreateTUNFromFile(tunFile, mtu) } func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { - tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } + var errno syscall.Errno + tun.operateOnFd(func(fd uintptr) { + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0)) + }) + if errno != 0 { + tun.tunFile.Close() + return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno) + } + name, err := tun.Name() if err != nil { tun.tunFile.Close() @@ -391,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -425,63 +329,65 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + if len(buf) < 5 { + return i, io.ErrShortBuffer + } + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err3 error - err1 := tun.tunFile.Close() - err2 := tunDestroy(tun.name) - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err3 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2, err3 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + err2 = tunDestroy(tun.name) + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err3 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -492,70 +398,38 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) setMTU(n int) error { - // open datagram socket - - var fd int - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } - defer unix.Close(fd) - // do ioctl call - - var ifr ifreq_mtu + var ifr ifreqMtu copy(ifr.Name[:], tun.name) ifr.MTU = uint32(n) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr)), - ) - + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { - return fmt.Errorf("failed to set MTU on %s", tun.name) + return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno) } - return nil } func (tun *NativeTun) MTU() (int, error) { - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return 0, err } - defer unix.Close(fd) - // do ioctl call - var ifr ifreq_mtu + var ifr ifreqMtu copy(ifr.Name[:], tun.name) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr)), - ) + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { - return 0, fmt.Errorf("failed to get MTU on %s", tun.name) + return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno) } - return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 61902e9..bd69cb5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -9,18 +9,16 @@ package tun */ import ( - "bytes" "errors" "fmt" - "net" "os" "sync" "syscall" "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -32,14 +30,29 @@ const ( type NativeTun struct { tunFile *os.File index int32 // if index - name string // name of interface errors chan error // async error handling events chan Event // device related events - nopi bool // the device was pased IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool + udpGSO bool + + closeOnce sync.Once + + nameOnce sync.Once // guards calling initNameCache, which sets following fields + nameCache string // name of interface + nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -51,6 +64,11 @@ func (tun *NativeTun) routineHackListener() { /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ + last := 0 + const ( + up = 1 + down = 2 + ) for { sysconn, err := tun.tunFile.SyscallConn() if err != nil { @@ -64,14 +82,25 @@ func (tun *NativeTun) routineHackListener() { } switch err { case unix.EINVAL: - tun.events <- EventUp + if last != up { + // If the tunnel is up, it reports that write() is + // allowed but we provided invalid data. + tun.events <- EventUp + last = up + } case unix.EIO: - tun.events <- EventDown + if last != down { + // If the tunnel is down, it reports that no I/O + // is possible, without checking our provided data. + tun.events <- EventDown + last = down + } default: return } select { case <-time.After(time.Second): + // nothing case <-tun.statusListenersShutdown: return } @@ -79,13 +108,13 @@ func (tun *NativeTun) routineHackListener() { } func createNetlinkSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))), + Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR, } err = unix.Bind(sock, saddr) if err != nil { @@ -99,10 +128,10 @@ func (tun *NativeTun) routineNetlinkListener() { unix.Close(tun.netlinkSock) tun.hackListenerClosed.Lock() close(tun.events) + tun.netlinkCancel.Close() }() for msg := make([]byte, 1<<16); ; { - var err error var msgn int for { @@ -111,12 +140,12 @@ func (tun *NativeTun) routineNetlinkListener() { break } if !tun.netlinkCancel.ReadyRead() { - tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error()) + tun.errors <- fmt.Errorf("netlink socket closed: %w", err) return } } if err != nil { - tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error()) + tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err) return } @@ -126,6 +155,7 @@ func (tun *NativeTun) routineNetlinkListener() { default: } + wasEverUp := false for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) @@ -149,10 +179,16 @@ func (tun *NativeTun) routineNetlinkListener() { if info.Flags&unix.IFF_RUNNING != 0 { tun.events <- EventUp + wasEverUp = true } if info.Flags&unix.IFF_RUNNING == 0 { - tun.events <- EventDown + // Don't emit EventDown before we've ever emitted EventUp. + // This avoids a startup race with HackListener, which + // might detect Up before we have finished reporting Down. + if wasEverUp { + tun.events <- EventDown + } } tun.events <- EventMTUUpdate @@ -164,15 +200,10 @@ func (tun *NativeTun) routineNetlinkListener() { } } -func (tun *NativeTun) isUp() (bool, error) { - inter, err := net.InterfaceByName(tun.name) - return inter.Flags&net.FlagUp != 0, err -} - func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { @@ -198,13 +229,17 @@ func getIFIndex(name string) (int32, error) { } func (tun *NativeTun) setMTU(n int) error { + name, err := tun.Name() + if err != nil { + return err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return err } @@ -212,9 +247,8 @@ func (tun *NativeTun) setMTU(n int) error { defer unix.Close(fd) // do ioctl call - var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) _, _, errno := unix.Syscall( unix.SYS_IOCTL, @@ -224,20 +258,24 @@ func (tun *NativeTun) setMTU(n int) error { ) if errno != 0 { - return errors.New("failed to set MTU of TUN device") + return fmt.Errorf("failed to set MTU of TUN device: %w", errno) } return nil } func (tun *NativeTun) MTU() (int, error) { + name, err := tun.Name() + if err != nil { + return 0, err + } + // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return 0, err } @@ -247,7 +285,7 @@ func (tun *NativeTun) MTU() (int, error) { // do ioctl call var ifr [ifReqSize]byte - copy(ifr[:], tun.name) + copy(ifr[:], name) _, _, errno := unix.Syscall( unix.SYS_IOCTL, uintptr(fd), @@ -255,13 +293,22 @@ func (tun *NativeTun) MTU() (int, error) { uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { - return 0, errors.New("failed to get MTU of TUN device: " + errno.Error()) + return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno) } return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil } func (tun *NativeTun) Name() (string, error) { + tun.nameOnce.Do(tun.initNameCache) + return tun.nameCache, tun.nameErr +} + +func (tun *NativeTun) initNameCache() { + tun.nameCache, tun.nameErr = tun.nameSlow() +} + +func (tun *NativeTun) nameSlow() (string, error) { sysconn, err := tun.tunFile.SyscallConn() if err != nil { return "", err @@ -277,147 +324,287 @@ func (tun *NativeTun) Name() (string, error) { ) }) if err != nil { - return "", errors.New("failed to get name of TUN device: " + err.Error()) + return "", fmt.Errorf("failed to get name of TUN device: %w", err) } if errno != 0 { - return "", errors.New("failed to get name of TUN device: " + errno.Error()) - } - nullStr := ifr[:] - i := bytes.IndexByte(nullStr, 0) - if i != -1 { - nullStr = nullStr[:i] + return "", fmt.Errorf("failed to get name of TUN device: %w", errno) } - tun.name = string(nullStr) - return tun.name, nil + return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - if tun.nopi { - buff = buff[offset:] +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcpGROTable.reset() + tun.udpGROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header + for i := range bufs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, bufsI := range tun.toWrite { + n, err := tun.tunFile.Write(bufs[bufsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) + } else { + total += n + } + } + return total, errs +} - buff = buff[offset-4:] +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(bufs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) + } + n := copy(bufs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } - // add packet information header + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } - buff[0] = 0x00 - buff[1] = 0x00 + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // csumStart, which is synonymous for IP header length. + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } - if buff[4]>>4 == ipv6.Version { - buff[2] = 0x86 - buff[3] = 0xdd - } else { - buff[2] = 0x08 - buff[3] = 0x00 + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) } + hdr.hdrLen = hdr.csumStart + tcpHLen } - // write + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) + } - return tun.tunFile.Write(buff) -} + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) + } + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - if tun.nopi { - return tun.tunFile.Read(buff[offset:]) + readInto := bufs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], bufs, sizes, offset) } else { - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) - if n < 4 { - return 0, err - } - return n - 4, err + sizes[0] = n + return 1, nil } } } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { - var err1 error - if tun.statusListenersShutdown != nil { - close(tun.statusListenersShutdown) - if tun.netlinkCancel != nil { - err1 = tun.netlinkCancel.Cancel() + var err1, err2 error + tun.closeOnce.Do(func() { + if tun.statusListenersShutdown != nil { + close(tun.statusListenersShutdown) + if tun.netlinkCancel != nil { + err1 = tun.netlinkCancel.Cancel() + } + } else if tun.events != nil { + close(tun.events) } - } else if tun.events != nil { - close(tun.events) - } - err2 := tun.tunFile.Close() - + err2 = tun.tunFile.Close() + }) if err1 != nil { return err1 } return err2 } +func (tun *NativeTun) BatchSize() int { + return tun.batchSize +} + +const ( + // TODO: support TSO with ECN bits + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. + tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { - nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) + nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { + if os.IsNotExist(err) { + return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) + } return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - return nil, errors.New("interface name too long") + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } - err = unix.SetNonblock(nfd, true) - // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. - - fd := os.NewFile(uintptr(nfd), cloneDevicePath) + err = unix.SetNonblock(nfd, true) if err != nil { + unix.Close(nfd) return nil, err } + // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. + + fd := os.NewFile(uintptr(nfd), cloneDevicePath) return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } - var err error - _, err = tun.Name() + name, err := tun.Name() if err != nil { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } - tun.index, err = getIFIndex(tun.name) + // start event listener + tun.index, err = getIFIndex(name) if err != nil { return nil, err } @@ -445,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -452,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index 44cedaa..ae571b9 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -1,19 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( + "errors" "fmt" - "io/ioutil" + "io" "net" "os" + "sync" "syscall" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) @@ -32,6 +33,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -99,16 +101,6 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } } -func errorIsEBUSY(err error) bool { - if pe, ok := err.(*os.PathError); ok { - err = pe.Err - } - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EBUSY { - return true - } - return false -} - func CreateTUN(name string, mtu int) (Device, error) { ifIndex := -1 if name != "tun" { @@ -122,11 +114,11 @@ func CreateTUN(name string, mtu int) (Device, error) { var err error if ifIndex != -1 { - tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) + tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) } else { - for ifIndex = 0; ifIndex < 256; ifIndex += 1 { - tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) - if err == nil || !errorIsEBUSY(err) { + for ifIndex = 0; ifIndex < 256; ifIndex++ { + tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) + if err == nil || !errors.Is(err, syscall.EBUSY) { break } } @@ -141,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) { if err == nil && name == "tun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { - ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) + os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } @@ -173,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -208,62 +200,61 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -277,10 +268,9 @@ func (tun *NativeTun) setMTU(n int) error { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return err } @@ -312,10 +302,9 @@ func (tun *NativeTun) MTU() (int, error) { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return 0, err } @@ -338,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) { return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index daad4aa..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -9,13 +9,13 @@ import ( "errors" "fmt" "os" + "sync" "sync/atomic" "time" - "unsafe" + _ "unsafe" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun/wintun" + "golang.zx2c4.com/wintun" ) const ( @@ -25,24 +25,31 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { - wt *wintun.Interface + wt *wintun.Adapter + name string handle windows.Handle - close bool - rings wintun.RingDescriptor + rate rateJuggler + session wintun.Session + readWait windows.Handle events chan Event - errors chan error + running sync.WaitGroup + closeOnce sync.Once + close atomic.Bool forcedMTU int - rate rateJuggler + outSizes []int } -const WintunPool = wintun.Pool("WireGuard") +var ( + WintunTunnelType = "WireGuard" + WintunStaticRequestedGUID *windows.GUID +) //go:linkname procyield runtime.procyield func procyield(cycles uint32) @@ -50,34 +57,18 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { - return CreateTUNWithRequestedGUID(ifname, nil, mtu) + return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { - var err error - var wt *wintun.Interface - - // Does an interface with this name already exist? - wt, err = WintunPool.GetInterface(ifname) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = wt.DeleteInterface() - if err != nil { - return nil, fmt.Errorf("Error deleting already existing interface: %v", err) - } - } - wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID) + wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { - return nil, fmt.Errorf("Error creating interface: %v", err) + return nil, fmt.Errorf("Error creating interface: %w", err) } forcedMTU := 1420 @@ -87,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun := &NativeTun{ wt: wt, + name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), - errors: make(chan error, 1), forcedMTU: forcedMTU, } - err = tun.rings.Init() - if err != nil { - tun.Close() - return nil, fmt.Errorf("Error creating events: %v", err) - } - - tun.handle, err = tun.wt.Register(&tun.rings) + tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.Close() - return nil, fmt.Errorf("Error registering rings: %v", err) + tun.wt.Close() + close(tun.events) + return nil, fmt.Errorf("Error starting session: %w", err) } + tun.readWait = tun.session.ReadWaitEvent() return tun, nil } func (tun *NativeTun) Name() (string, error) { - return tun.wt.Name() + return tun.name, nil } func (tun *NativeTun) File() *os.File { return nil } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { - tun.close = true - if tun.rings.Send.TailMoved != 0 { - windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping - } - if tun.handle != windows.InvalidHandle { - windows.CloseHandle(tun.handle) - } - tun.rings.Close() var err error - if tun.wt != nil { - _, err = tun.wt.DeleteInterface() - } - close(tun.events) + tun.closeOnce.Do(func() { + tun.close.Store(true) + windows.SetEvent(tun.readWait) + tun.running.Wait() + tun.session.End() + if tun.wt != nil { + tun.wt.Close() + } + close(tun.events) + }) return err } @@ -142,129 +127,115 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } + update := tun.forcedMTU != mtu tun.forcedMTU = mtu + if update { + tun.events <- EventMTUUpdate + } +} + +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 } // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() retry: - select { - case err := <-tun.errors: - return 0, err - default: - } - if tun.close { - return 0, os.ErrClosed - } - - buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head) - if buffHead >= wintun.PacketCapacity { + if tun.close.Load() { return 0, os.ErrClosed } - start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 - var buffTail uint32 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail) - if buffHead != buffTail { - break - } - if tun.close { + if tun.close.Load() { return 0, os.ErrClosed } - if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { - windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE) - goto retry + packet, err := tun.session.ReceivePacket() + switch err { + case nil: + n := copy(bufs[0][offset:], packet) + sizes[0] = n + tun.session.ReleaseReceivePacket(packet) + tun.rate.update(uint64(n)) + return 1, nil + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(tun.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("Send ring corrupt") } - procyield(1) - } - if buffTail >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) { - return 0, errors.New("incomplete packet header in send ring") - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead])) - if packet.Size > wintun.PacketSizeMax { - return 0, errors.New("packet too big in send ring") - } - - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size) - if alignedPacketSize > buffContent { - return 0, errors.New("incomplete packet in send ring") + return 0, fmt.Errorf("Read failed: %w", err) } - - copy(buff[offset:], packet.Data[:packet.Size]) - buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead) - tun.rate.update(uint64(packet.Size)) - return int(packet.Size), nil -} - -func (tun *NativeTun) Flush() error { - return nil } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - if tun.close { - return 0, os.ErrClosed - } - - packetSize := uint32(len(buff) - offset) - tun.rate.update(uint64(packetSize)) - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) - - buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) - if buffHead >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail) - if buffTail >= wintun.PacketCapacity { +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { return 0, os.ErrClosed } - buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment) - if alignedPacketSize > buffSpace { - return 0, nil // Dropping when ring is full. - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail])) - packet.Size = packetSize - copy(packet.Data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 { - windows.SetEvent(tun.rings.Receive.TailMoved) + for i, buf := range bufs { + packetSize := len(buf) - offset + tun.rate.update(uint64(packetSize)) + + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buf[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return int(packetSize), nil + return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { + return 0 + } return tun.wt.LUID() } -// Version returns the version of the Wintun driver and NDIS system currently loaded. -func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) { - return tun.wt.Version() +// RunningVersion returns the running version of the Wintun driver. +func (tun *NativeTun) RunningVersion() (version uint32, err error) { + return wintun.RunningVersion() } func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } } diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go new file mode 100644 index 0000000..d07e860 --- /dev/null +++ b/tun/tuntest/tuntest.go @@ -0,0 +1,155 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tuntest + +import ( + "encoding/binary" + "io" + "net/netip" + "os" + + "golang.zx2c4.com/wireguard/tun" +) + +func Ping(dst, src netip.Addr) []byte { + localPort := uint16(1337) + seq := uint16(0) + + payload := make([]byte, 4) + binary.BigEndian.PutUint16(payload[0:], localPort) + binary.BigEndian.PutUint16(payload[2:], seq) + + return genICMPv4(payload, dst, src) +} + +// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. +func checksum(buf []byte, initial uint16) uint16 { + v := uint32(initial) + for i := 0; i < len(buf)-1; i += 2 { + v += uint32(binary.BigEndian.Uint16(buf[i:])) + } + if len(buf)%2 == 1 { + v += uint32(buf[len(buf)-1]) << 8 + } + for v > 0xffff { + v = (v >> 16) + (v & 0xffff) + } + return ^uint16(v) +} + +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { + const ( + icmpv4ProtocolNumber = 1 + icmpv4Echo = 8 + icmpv4ChecksumOffset = 2 + icmpv4Size = 8 + ipv4Size = 20 + ipv4TotalLenOffset = 2 + ipv4ChecksumOffset = 10 + ttl = 65 + headerSize = ipv4Size + icmpv4Size + ) + + pkt := make([]byte, headerSize+len(payload)) + + ip := pkt[0:ipv4Size] + icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] + + // https://tools.ietf.org/html/rfc792 + icmpv4[0] = icmpv4Echo // type + icmpv4[1] = 0 // code + chksum := ^checksum(icmpv4, checksum(payload, 0)) + binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) + + // https://tools.ietf.org/html/rfc760 section 3.1 + length := uint16(len(pkt)) + ip[0] = (4 << 4) | (ipv4Size / 4) + binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) + ip[8] = ttl + ip[9] = icmpv4ProtocolNumber + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) + chksum = ^checksum(ip[:], 0) + binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) + + copy(pkt[headerSize:], payload) + return pkt +} + +type ChannelTUN struct { + Inbound chan []byte // incoming packets, closed on TUN close + Outbound chan []byte // outbound packets, blocks forever on TUN close + + closed chan struct{} + events chan tun.Event + tun chTun +} + +func NewChannelTUN() *ChannelTUN { + c := &ChannelTUN{ + Inbound: make(chan []byte), + Outbound: make(chan []byte), + closed: make(chan struct{}), + events: make(chan tun.Event, 1), + } + c.tun.c = c + c.events <- tun.EventUp + return c +} + +func (c *ChannelTUN) TUN() tun.Device { + return &c.tun +} + +type chTun struct { + c *ChannelTUN +} + +func (t *chTun) File() *os.File { return nil } + +func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { + select { + case <-t.c.closed: + return 0, os.ErrClosed + case msg := <-t.c.Outbound: + n := copy(packets[0][offset:], msg) + sizes[0] = n + return 1, nil + } +} + +// Write is called by the wireguard device to deliver a packet for routing. +func (t *chTun) Write(packets [][]byte, offset int) (int, error) { + if offset == -1 { + close(t.c.closed) + close(t.c.events) + return 0, io.EOF + } + for i, data := range packets { + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return i, os.ErrClosed + case t.c.Inbound <- msg: + } + } + return len(packets), nil +} + +func (t *chTun) BatchSize() int { + return 1 +} + +const DefaultMTU = 1420 + +func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } +func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } +func (t *chTun) Events() <-chan tun.Event { return t.c.events } +func (t *chTun) Close() error { + t.Write(nil, -1) + return nil +} diff --git a/tun/wintun/iphlpapi/conversion_windows.go b/tun/wintun/iphlpapi/conversion_windows.go deleted file mode 100644 index a19e961..0000000 --- a/tun/wintun/iphlpapi/conversion_windows.go +++ /dev/null @@ -1,25 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package iphlpapi - -import "golang.org/x/sys/windows" - -//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid -//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid - -func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) { - var luid uint64 - var guid windows.GUID - err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid) - if err != nil { - return nil, err - } - err = convertInterfaceLUIDToGUID(&luid, &guid) - if err != nil { - return nil, err - } - return &guid, nil -} diff --git a/tun/wintun/iphlpapi/mksyscall.go b/tun/wintun/iphlpapi/mksyscall.go deleted file mode 100644 index fc7dba4..0000000 --- a/tun/wintun/iphlpapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package iphlpapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go diff --git a/tun/wintun/iphlpapi/zsyscall_windows.go b/tun/wintun/iphlpapi/zsyscall_windows.go deleted file mode 100644 index dc14294..0000000 --- a/tun/wintun/iphlpapi/zsyscall_windows.go +++ /dev/null @@ -1,60 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package iphlpapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - - procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid") - procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid") -) - -func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} - -func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go deleted file mode 100644 index f4316fe..0000000 --- a/tun/wintun/namespace_windows.go +++ /dev/null @@ -1,98 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "encoding/hex" - "errors" - "fmt" - "sync" - "unsafe" - - "golang.org/x/crypto/blake2s" - "golang.org/x/sys/windows" - "golang.org/x/text/unicode/norm" - - "golang.zx2c4.com/wireguard/tun/wintun/namespaceapi" -) - -var ( - wintunObjectSecurityAttributes *windows.SecurityAttributes - hasInitializedNamespace bool - initializingNamespace sync.Mutex -) - -func initializeNamespace() error { - initializingNamespace.Lock() - defer initializingNamespace.Unlock() - if hasInitializedNamespace { - return nil - } - sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)") - if err != nil { - return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err) - } - wintunObjectSecurityAttributes = &windows.SecurityAttributes{ - Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), - SecurityDescriptor: sd, - } - sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) - if err != nil { - return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err) - } - - boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun") - if err != nil { - return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err) - } - err = boundary.AddSid(sid) - if err != nil { - return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err) - } - for { - _, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun") - if err == windows.ERROR_ALREADY_EXISTS { - _, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun") - if err == windows.ERROR_PATH_NOT_FOUND { - continue - } - } - if err != nil { - return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err) - } - break - } - hasInitializedNamespace = true - return nil -} - -func (pool Pool) takeNameMutex() (windows.Handle, error) { - err := initializeNamespace() - if err != nil { - return 0, err - } - - const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com" - b2, _ := blake2s.New256(nil) - b2.Write([]byte(mutexLabel)) - b2.Write(norm.NFC.Bytes([]byte(string(pool)))) - mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil)) - mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName)) - if err != nil { - err = fmt.Errorf("Error creating name mutex: %v", err) - return 0, err - } - event, err := windows.WaitForSingleObject(mutex, windows.INFINITE) - if err != nil { - windows.CloseHandle(mutex) - return 0, fmt.Errorf("Error waiting on name mutex: %v", err) - } - if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED { - windows.CloseHandle(mutex) - return 0, errors.New("Error with event trigger of name mutex") - } - return mutex, nil -} diff --git a/tun/wintun/namespaceapi/mksyscall.go b/tun/wintun/namespaceapi/mksyscall.go deleted file mode 100644 index 93d43b0..0000000 --- a/tun/wintun/namespaceapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package namespaceapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go diff --git a/tun/wintun/namespaceapi/namespaceapi_windows.go b/tun/wintun/namespaceapi/namespaceapi_windows.go deleted file mode 100644 index a3a6274..0000000 --- a/tun/wintun/namespaceapi/namespaceapi_windows.go +++ /dev/null @@ -1,83 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package namespaceapi - -import "golang.org/x/sys/windows" - -//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW -//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor -//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor -//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW -//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW -//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace - -// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated. -type BoundaryDescriptor windows.Handle - -// CreateBoundaryDescriptor creates a boundary descriptor. -func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) { - name16, err := windows.UTF16PtrFromString(name) - if err != nil { - return 0, err - } - handle, err := createBoundaryDescriptor(name16, 0) - if err != nil { - return 0, err - } - return BoundaryDescriptor(handle), nil -} - -// Delete deletes the specified boundary descriptor. -func (bd BoundaryDescriptor) Delete() { - deleteBoundaryDescriptor(windows.Handle(bd)) -} - -// AddSid adds a security identifier (SID) to the specified boundary descriptor. -func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error { - return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid) -} - -// PrivateNamespace represents a private namespace. -type PrivateNamespace windows.Handle - -// CreatePrivateNamespace creates a private namespace. -func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) { - aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix) - if err != nil { - return 0, err - } - handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16) - if err != nil { - return 0, err - } - return PrivateNamespace(handle), nil -} - -// OpenPrivateNamespace opens a private namespace. -func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) { - aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix) - if err != nil { - return 0, err - } - handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16) - if err != nil { - return 0, err - } - return PrivateNamespace(handle), nil -} - -// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method. -type ClosePrivateNamespaceFlags uint32 - -const ( - // PrivateNamespaceFlagDestroy makes the close to destroy the namespace. - PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1) -) - -// Close closes an open namespace handle. -func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error { - return closePrivateNamespace(windows.Handle(pns), uint32(flags)) -} diff --git a/tun/wintun/namespaceapi/zsyscall_windows.go b/tun/wintun/namespaceapi/zsyscall_windows.go deleted file mode 100644 index 508c223..0000000 --- a/tun/wintun/namespaceapi/zsyscall_windows.go +++ /dev/null @@ -1,116 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package namespaceapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW") - procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor") - procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor") - procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW") - procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW") - procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace") -) - -func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) { - syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0) - return -} - -func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) { - r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix))) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} diff --git a/tun/wintun/nci/mksyscall.go b/tun/wintun/nci/mksyscall.go deleted file mode 100644 index 019da93..0000000 --- a/tun/wintun/nci/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package nci - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go diff --git a/tun/wintun/nci/nci_windows.go b/tun/wintun/nci/nci_windows.go deleted file mode 100644 index 9dc6699..0000000 --- a/tun/wintun/nci/nci_windows.go +++ /dev/null @@ -1,28 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package nci - -import "golang.org/x/sys/windows" - -//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName -//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName - -func SetConnectionName(guid *windows.GUID, newName string) error { - newName16, err := windows.UTF16PtrFromString(newName) - if err != nil { - return err - } - return nciSetConnectionName(guid, newName16) -} - -func ConnectionName(guid *windows.GUID) (string, error) { - var name [0x400]uint16 - err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil) - if err != nil { - return "", err - } - return windows.UTF16ToString(name[:]), nil -} diff --git a/tun/wintun/nci/zsyscall_windows.go b/tun/wintun/nci/zsyscall_windows.go deleted file mode 100644 index 2a7b79e..0000000 --- a/tun/wintun/nci/zsyscall_windows.go +++ /dev/null @@ -1,60 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package nci - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modnci = windows.NewLazySystemDLL("nci.dll") - - procNciSetConnectionName = modnci.NewProc("NciSetConnectionName") - procNciGetConnectionName = modnci.NewProc("NciGetConnectionName") -) - -func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) { - r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} - -func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) { - r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/registry/mksyscall.go b/tun/wintun/registry/mksyscall.go deleted file mode 100644 index 6ad82d2..0000000 --- a/tun/wintun/registry/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go deleted file mode 100644 index 12a0336..0000000 --- a/tun/wintun/registry/registry_windows.go +++ /dev/null @@ -1,272 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -import ( - "errors" - "fmt" - "runtime" - "strings" - "time" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -const ( - // REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted. - REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001 - - // REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information. - REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002 - - // REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value. - REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004 - - // REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key. - REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008 - - // REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later. - REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000 -) - -//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue - -func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - deadline := time.Now().Add(timeout) - pathSpl := strings.Split(path, "\\") - for i := 0; ; i++ { - keyName := pathSpl[i] - isLast := i+1 == len(pathSpl) - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return 0, fmt.Errorf("Error creating event: %v", err) - } - defer windows.CloseHandle(event) - - var key registry.Key - for { - err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true) - if err != nil { - return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err) - } - - var accessFlags uint32 - if isLast { - accessFlags = access - } else { - accessFlags = registry.NOTIFY - } - key, err = registry.OpenKey(k, keyName, accessFlags) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return 0, fmt.Errorf("Unable to wait on registry key: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return 0, errors.New("Timeout waiting for registry key") - } - } else if err != nil { - return 0, fmt.Errorf("Error opening registry key %v: %v", path, err) - } else { - if isLast { - return key, nil - } - defer key.Close() - break - } - } - - k = key - } -} - -func WaitForKey(k registry.Key, path string, timeout time.Duration) error { - key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout) - if err != nil { - return err - } - key.Close() - return nil -} - -// -// getValue is more or less the same as windows/registry's getValue. -// -func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - n := uint32(len(buf)) - for { - err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n) - if err == nil { - value = buf[:n] - return - } - if err != windows.ERROR_MORE_DATA { - return - } - if n <= uint32(len(buf)) { - return - } - buf = make([]byte, n) - } -} - -// -// getValueRetry function reads any value from registry. It waits for -// the registry value to become available or returns error on timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return nil, 0, fmt.Errorf("Error creating event: %v", err) - } - defer windows.CloseHandle(event) - - deadline := time.Now().Add(timeout) - for { - err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true) - if err != nil { - return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err) - } - - buf, valueType, err := getValue(key, name, buf) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return nil, 0, errors.New("Timeout waiting for registry value") - } - } else if err != nil { - return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err) - } else { - return buf, valueType, nil - } - } -} - -func toString(buf []byte, valueType uint32, err error) (string, error) { - if err != nil { - return "", err - } - - var value string - switch valueType { - case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ: - if len(buf) == 0 { - return "", nil - } - value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2]) - - default: - return "", registry.ErrUnexpectedType - } - - if valueType != registry.EXPAND_SZ { - // Value does not require expansion. - return value, nil - } - - valueExp, err := registry.ExpandString(value) - if err != nil { - // Expanding failed: return original sting value. - return value, nil - } - - // Return expanded value. - return valueExp, nil -} - -func toInteger(buf []byte, valueType uint32, err error) (uint64, error) { - if err != nil { - return 0, err - } - - switch valueType { - case registry.DWORD: - if len(buf) != 4 { - return 0, errors.New("DWORD value is not 4 bytes long") - } - var val uint32 - copy((*[4]byte)(unsafe.Pointer(&val))[:], buf) - return uint64(val), nil - - case registry.QWORD: - if len(buf) != 8 { - return 0, errors.New("QWORD value is not 8 bytes long") - } - var val uint64 - copy((*[8]byte)(unsafe.Pointer(&val))[:], buf) - return val, nil - - default: - return 0, registry.ErrUnexpectedType - } -} - -// -// GetStringValueWait function reads a string value from registry. It waits -// for the registry value to become available or returns error on timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -// If the value type is REG_EXPAND_SZ the environment variables are expanded. -// Should expanding fail, original string value and nil error are returned. -// -// If the value type is REG_MULTI_SZ only the first string is returned. -// -func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { - return toString(getValueRetry(key, name, make([]byte, 256), timeout)) -} - -// -// GetStringValue function reads a string value from registry. -// -// Key must be opened with at least QUERY_VALUE access. -// -// If the value type is REG_EXPAND_SZ the environment variables are expanded. -// Should expanding fail, original string value and nil error are returned. -// -// If the value type is REG_MULTI_SZ only the first string is returned. -// -func GetStringValue(key registry.Key, name string) (string, error) { - return toString(getValue(key, name, make([]byte, 256))) -} - -// -// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry. -// It waits for the registry value to become available or returns error on -// timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) { - return toInteger(getValueRetry(key, name, make([]byte, 8), timeout)) -} diff --git a/tun/wintun/registry/registry_windows_test.go b/tun/wintun/registry/registry_windows_test.go deleted file mode 100644 index c56b51b..0000000 --- a/tun/wintun/registry/registry_windows_test.go +++ /dev/null @@ -1,103 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -import ( - "testing" - "time" - - "golang.org/x/sys/windows/registry" -) - -const keyRoot = registry.CURRENT_USER -const pathRoot = "Software\\WireGuardRegistryTest" -const path = pathRoot + "\\foobar" -const pathFake = pathRoot + "\\raboof" - -func Test_WaitForKey(t *testing.T) { - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - go func() { - time.Sleep(time.Second * 1) - key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - key.Close() - registry.DeleteKey(keyRoot, pathFake) - - key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - key.Close() - }() - err := WaitForKey(keyRoot, path, time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry key: %v", err) - } - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - - err = WaitForKey(keyRoot, path, time.Second*1) - if err == nil { - t.Error("Registry key notification expected to timeout but it succeeded.") - } -} - -func Test_GetValueWait(t *testing.T) { - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - go func() { - time.Sleep(time.Second * 1) - key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - time.Sleep(time.Second * 1) - key.SetStringValue("name1", "eulav") - key.SetExpandStringValue("name2", "value") - time.Sleep(time.Second * 1) - key.SetDWordValue("name3", ^uint32(123)) - key.SetDWordValue("name4", 123) - key.Close() - }() - - key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry key: %v", err) - } - - valueStr, err := GetStringValueWait(key, "name2", time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry value: %v", err) - } - if valueStr != "value" { - t.Errorf("Wrong value read: %v", valueStr) - } - - _, err = GetStringValueWait(key, "nonexisting", time.Second*1) - if err == nil { - t.Error("Registry value notification expected to timeout but it succeeded.") - } - - valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry value: %v", err) - } - if valueInt != 123 { - t.Errorf("Wrong value read: %v", valueInt) - } - - _, err = GetIntegerValueWait(key, "nonexisting", time.Second*1) - if err == nil { - t.Error("Registry value notification expected to timeout but it succeeded.") - } - - key.Close() - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) -} diff --git a/tun/wintun/registry/zregistry_windows.go b/tun/wintun/registry/zregistry_windows.go deleted file mode 100644 index f7ac33b..0000000 --- a/tun/wintun/registry/zregistry_windows.go +++ /dev/null @@ -1,63 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package registry - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - - procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue") -) - -func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) { - var _p0 uint32 - if watchSubtree { - _p0 = 1 - } else { - _p0 = 0 - } - var _p1 uint32 - if asynchronous { - _p1 = 1 - } else { - _p1 = 0 - } - r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0) - if r0 != 0 { - regerrno = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go deleted file mode 100644 index 8f46bc9..0000000 --- a/tun/wintun/ring_windows.go +++ /dev/null @@ -1,97 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - PacketAlignment = 4 // Number of bytes packets are aligned to in rings - PacketSizeMax = 0xffff // Maximum packet size - PacketCapacity = 0x800000 // Ring capacity, 8MiB - PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment - ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) -) - -type PacketHeader struct { - Size uint32 -} - -type Packet struct { - PacketHeader - Data [PacketSizeMax]byte -} - -type Ring struct { - Head uint32 - Tail uint32 - Alertable int32 - Data [PacketCapacity + PacketTrailingSize]byte -} - -type RingDescriptor struct { - Send, Receive struct { - Size uint32 - Ring *Ring - TailMoved windows.Handle - } -} - -// Wrap returns value modulo ring capacity -func (rb *Ring) Wrap(value uint32) uint32 { - return value & (PacketCapacity - 1) -} - -// Aligns a packet size to PacketAlignment -func PacketAlign(size uint32) uint32 { - return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) -} - -func (descriptor *RingDescriptor) Init() (err error) { - descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Send.Ring = &Ring{} - descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return - } - - descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Receive.Ring = &Ring{} - descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - windows.CloseHandle(descriptor.Send.TailMoved) - return - } - - return -} - -func (descriptor *RingDescriptor) Close() { - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Send.TailMoved) - descriptor.Send.TailMoved = 0 - } - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Receive.TailMoved) - descriptor.Receive.TailMoved = 0 - } -} - -func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) { - handle, err := wintun.handle() - if err != nil { - return 0, err - } - var bytesReturned uint32 - err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil) - if err != nil { - return 0, err - } - return handle, nil -} diff --git a/tun/wintun/setupapi/mksyscall.go b/tun/wintun/setupapi/mksyscall.go deleted file mode 100644 index ac103a1..0000000 --- a/tun/wintun/setupapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go deleted file mode 100644 index 60a8eb7..0000000 --- a/tun/wintun/setupapi/setupapi_windows.go +++ /dev/null @@ -1,506 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "encoding/binary" - "fmt" - "runtime" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -//sys setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW - -// SetupDiCreateDeviceInfoListEx function creates an empty device information set on a remote or a local computer and optionally associates the set with a device setup class. -func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) { - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - return setupDiCreateDeviceInfoListEx(classGUID, hwndParent, machineNameUTF16, 0) -} - -//sys setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) = setupapi.SetupDiGetDeviceInfoListDetailW - -// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. -func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) { - data := &DevInfoListDetailData{} - data.size = sizeofDevInfoListDetailData - - return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data) -} - -// DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. -func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) { - return SetupDiGetDeviceInfoListDetail(deviceInfoSet) -} - -//sys setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCreateDeviceInfoW - -// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set. -func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) { - deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName) - if err != nil { - return - } - - var deviceDescriptionUTF16 *uint16 - if deviceDescription != "" { - deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription) - if err != nil { - return - } - } - - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiCreateDeviceInfo(deviceInfoSet, deviceNameUTF16, classGUID, deviceDescriptionUTF16, hwndParent, creationFlags, data) -} - -// CreateDeviceInfo method creates a new device information element and adds it as a new member to the specified device information set. -func (deviceInfoSet DevInfo) CreateDeviceInfo(deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (*DevInfoData, error) { - return SetupDiCreateDeviceInfo(deviceInfoSet, deviceName, classGUID, deviceDescription, hwndParent, creationFlags) -} - -//sys setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiEnumDeviceInfo - -// SetupDiEnumDeviceInfo function returns a DevInfoData structure that specifies a device information element in a device information set. -func SetupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex int) (*DevInfoData, error) { - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiEnumDeviceInfo(deviceInfoSet, uint32(memberIndex), data) -} - -// EnumDeviceInfo method returns a DevInfoData structure that specifies a device information element in a device information set. -func (deviceInfoSet DevInfo) EnumDeviceInfo(memberIndex int) (*DevInfoData, error) { - return SetupDiEnumDeviceInfo(deviceInfoSet, memberIndex) -} - -// SetupDiDestroyDeviceInfoList function deletes a device information set and frees all associated memory. -//sys SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiDestroyDeviceInfoList - -// Close method deletes a device information set and frees all associated memory. -func (deviceInfoSet DevInfo) Close() error { - return SetupDiDestroyDeviceInfoList(deviceInfoSet) -} - -//sys SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiBuildDriverInfoList - -// BuildDriverInfoList method builds a list of drivers that is associated with a specific device or with the global class driver list for a device information set. -func (deviceInfoSet DevInfo) BuildDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error { - return SetupDiBuildDriverInfoList(deviceInfoSet, deviceInfoData, driverType) -} - -//sys SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiCancelDriverInfoSearch - -// CancelDriverInfoSearch method cancels a driver list search that is currently in progress in a different thread. -func (deviceInfoSet DevInfo) CancelDriverInfoSearch() error { - return SetupDiCancelDriverInfoSearch(deviceInfoSet) -} - -//sys setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiEnumDriverInfoW - -// SetupDiEnumDriverInfo function enumerates the members of a driver list. -func SetupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) { - data := &DrvInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, uint32(memberIndex), data) -} - -// EnumDriverInfo method enumerates the members of a driver list. -func (deviceInfoSet DevInfo) EnumDriverInfo(deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) { - return SetupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, memberIndex) -} - -//sys setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiGetSelectedDriverW - -// SetupDiGetSelectedDriver function retrieves the selected driver for a device information set or a particular device information element. -func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DrvInfoData, error) { - data := &DrvInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data) -} - -// SelectedDriver method retrieves the selected driver for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) { - return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData) -} - -//sys SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiSetSelectedDriverW - -// SetSelectedDriver method sets, or resets, the selected driver for a device information element or the selected class driver for a device information set. -func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) error { - return SetupDiSetSelectedDriver(deviceInfoSet, deviceInfoData, driverInfoData) -} - -//sys setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDriverInfoDetailW - -// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set. -func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { - reqSize := uint32(2048) - for { - buf := make([]byte, reqSize) - data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0])) - data.size = sizeofDrvInfoDetailData - err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return nil, err - } - data.size = reqSize - return data, nil - } -} - -// DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set. -func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { - return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData) -} - -//sys SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiDestroyDriverInfoList - -// DestroyDriverInfoList method deletes a driver list. -func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error { - return SetupDiDestroyDriverInfoList(deviceInfoSet, deviceInfoData, driverType) -} - -//sys setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiGetClassDevsExW - -// SetupDiGetClassDevsEx function returns a handle to a device information set that contains requested device information elements for a local or a remote computer. -func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) { - var enumeratorUTF16 *uint16 - if enumerator != "" { - enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator) - if err != nil { - return - } - } - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - return setupDiGetClassDevsEx(classGUID, enumeratorUTF16, hwndParent, flags, deviceInfoSet, machineNameUTF16, 0) -} - -// SetupDiCallClassInstaller function calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code). -//sys SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCallClassInstaller - -// CallClassInstaller member calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code). -func (deviceInfoSet DevInfo) CallClassInstaller(installFunction DI_FUNCTION, deviceInfoData *DevInfoData) error { - return SetupDiCallClassInstaller(installFunction, deviceInfoSet, deviceInfoData) -} - -//sys setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) [failretval==windows.InvalidHandle] = setupapi.SetupDiOpenDevRegKey - -// SetupDiOpenDevRegKey function opens a registry key for device-specific configuration information. -func SetupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, scope DICS_FLAG, hwProfile uint32, keyType DIREG, samDesired uint32) (registry.Key, error) { - handle, err := setupDiOpenDevRegKey(deviceInfoSet, deviceInfoData, scope, hwProfile, keyType, samDesired) - return registry.Key(handle), err -} - -// OpenDevRegKey method opens a registry key for device-specific configuration information. -func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (registry.Key, error) { - return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired) -} - -//sys setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW - -// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property. -func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) { - reqSize := uint32(256) - for { - var dataType uint32 - buf := make([]byte, reqSize) - err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return - } - return getRegistryValue(buf[:reqSize], dataType) - } -} - -func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) { - switch dataType { - case windows.REG_SZ: - ret := windows.UTF16ToString(bufToUTF16(buf)) - runtime.KeepAlive(buf) - return ret, nil - case windows.REG_EXPAND_SZ: - ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf))) - runtime.KeepAlive(buf) - return ret, err - case windows.REG_BINARY: - return buf, nil - case windows.REG_DWORD_LITTLE_ENDIAN: - return binary.LittleEndian.Uint32(buf), nil - case windows.REG_DWORD_BIG_ENDIAN: - return binary.BigEndian.Uint32(buf), nil - case windows.REG_MULTI_SZ: - bufW := bufToUTF16(buf) - a := []string{} - for i := 0; i < len(bufW); { - j := i + wcslen(bufW[i:]) - if i < j { - a = append(a, windows.UTF16ToString(bufW[i:j])) - } - i = j + 1 - } - runtime.KeepAlive(buf) - return a, nil - case windows.REG_QWORD_LITTLE_ENDIAN: - return binary.LittleEndian.Uint64(buf), nil - default: - return nil, fmt.Errorf("Unsupported registry value type: %v", dataType) - } -} - -// bufToUTF16 function reinterprets []byte buffer as []uint16 -func bufToUTF16(buf []byte) []uint16 { - sl := struct { - addr *uint16 - len int - cap int - }{(*uint16)(unsafe.Pointer(&buf[0])), len(buf) / 2, cap(buf) / 2} - return *(*[]uint16)(unsafe.Pointer(&sl)) -} - -// utf16ToBuf function reinterprets []uint16 as []byte -func utf16ToBuf(buf []uint16) []byte { - sl := struct { - addr *byte - len int - cap int - }{(*byte)(unsafe.Pointer(&buf[0])), len(buf) * 2, cap(buf) * 2} - return *(*[]byte)(unsafe.Pointer(&sl)) -} - -func wcslen(str []uint16) int { - for i := 0; i < len(str); i++ { - if str[i] == 0 { - return i - } - } - return len(str) -} - -// DeviceRegistryProperty method retrieves a specified Plug and Play device property. -func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) { - return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property) -} - -//sys setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) = setupapi.SetupDiSetDeviceRegistryPropertyW - -// SetupDiSetDeviceRegistryProperty function sets a Plug and Play device property for a device. -func SetupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error { - return setupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &propertyBuffers[0], uint32(len(propertyBuffers))) -} - -// SetDeviceRegistryProperty function sets a Plug and Play device property for a device. -func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error { - return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers) -} - -// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device. -func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error { - str16, err := windows.UTF16FromString(str) - if err != nil { - return err - } - err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0))) - runtime.KeepAlive(str16) - return err -} - -//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW - -// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element. -func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DevInstallParams, error) { - params := &DevInstallParams{} - params.size = uint32(unsafe.Sizeof(*params)) - - return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params) -} - -// DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) { - return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData) -} - -//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW - -// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device. -func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) { - reqSize := uint32(1024) - for { - buf := make([]uint16, reqSize) - err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return "", err - } - return windows.UTF16ToString(buf), nil - } -} - -// DeviceInstanceID method retrieves the instance ID of the device. -func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) { - return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData) -} - -// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element. -//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW - -// ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error { - return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize) -} - -//sys SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiSetDeviceInstallParamsW - -// SetDeviceInstallParams member sets device installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SetDeviceInstallParams(deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) error { - return SetupDiSetDeviceInstallParams(deviceInfoSet, deviceInfoData, deviceInstallParams) -} - -// SetupDiSetClassInstallParams function sets or clears class install parameters for a device information set or a particular device information element. -//sys SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) = setupapi.SetupDiSetClassInstallParamsW - -// SetClassInstallParams method sets or clears class install parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) error { - return SetupDiSetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize) -} - -//sys setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassNameFromGuidExW - -// SetupDiClassNameFromGuidEx function retrieves the class name associated with a class GUID. The class can be installed on a local or remote computer. -func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (className string, err error) { - var classNameUTF16 [MAX_CLASS_NAME_LEN]uint16 - - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - - err = setupDiClassNameFromGuidEx(classGUID, &classNameUTF16[0], MAX_CLASS_NAME_LEN, nil, machineNameUTF16, 0) - if err != nil { - return - } - - className = windows.UTF16ToString(classNameUTF16[:]) - return -} - -//sys setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassGuidsFromNameExW - -// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer. -func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) { - classNameUTF16, err := windows.UTF16PtrFromString(className) - if err != nil { - return nil, err - } - - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return nil, err - } - } - - reqSize := uint32(4) - for { - buf := make([]windows.GUID, reqSize) - err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return nil, err - } - return buf[:reqSize], nil - } -} - -//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice - -// SetupDiGetSelectedDevice function retrieves the selected device information element in a device information set. -func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) { - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiGetSelectedDevice(deviceInfoSet, data) -} - -// SelectedDevice method retrieves the selected device information element in a device information set. -func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) { - return SetupDiGetSelectedDevice(deviceInfoSet) -} - -// SetupDiSetSelectedDevice function sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard. -//sys SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiSetSelectedDevice - -// SetSelectedDevice method sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard. -func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error { - return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData) -} - -//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW -//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW - -func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) { - deviceID16, err := windows.UTF16PtrFromString(deviceID) - if err != nil { - return nil, err - } - var buf []uint16 - var buflen uint32 - for { - if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS { - return nil, fmt.Errorf("CfgMgr error: 0x%x", ret) - } - buf = make([]uint16, buflen) - if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS { - break - } else if ret != CR_BUFFER_SMALL { - return nil, fmt.Errorf("CfgMgr error: 0x%x", ret) - } - } - var interfaces []string - for i := 0; i < len(buf); { - j := i + wcslen(buf[i:]) - if i < j { - interfaces = append(interfaces, windows.UTF16ToString(buf[i:j])) - } - i = j + 1 - } - if interfaces == nil { - return nil, fmt.Errorf("no interfaces found") - } - return interfaces, nil -} diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go deleted file mode 100644 index a9e6b89..0000000 --- a/tun/wintun/setupapi/setupapi_windows_test.go +++ /dev/null @@ -1,488 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "runtime" - "strings" - "testing" - - "golang.org/x/sys/windows" -) - -var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} -var computerName string - -func init() { - computerName, _ = windows.ComputerName() -} - -func TestSetupDiCreateDeviceInfoListEx(t *testing.T) { - devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName) - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "") - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error()) - } else { - devInfoList.Close() - } -} - -func TestSetupDiGetDeviceInfoListDetail(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - data, err := devInfoList.DeviceInfoListDetail() - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) - } else { - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID") - } - - if data.RemoteMachineHandle != windows.Handle(0) { - t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle") - } - - if data.RemoteMachineName() != "" { - t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name") - } - } - - devInfoList, err = SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), computerName) - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - data, err = devInfoList.DeviceInfoListDetail() - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) - } else { - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID") - } - - if data.RemoteMachineHandle == windows.Handle(0) { - t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle") - } - - if data.RemoteMachineName() != computerName { - t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name") - } - } - - data = &DevInfoListDetailData{} - data.SetRemoteMachineName("foobar") - if data.RemoteMachineName() != "foobar" { - t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ") - } -} - -func TestSetupDiCreateDeviceInfo(t *testing.T) { - devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName) - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } - defer devInfoList.Close() - - deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } - - devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID) - if err != nil { - // Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed. - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED { - t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error()) - } - } else if devInfoData.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiCreateDeviceInfo returned different class GUID") - } -} - -func TestSetupDiEnumDeviceInfo(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiEnumDeviceInfo returned different class GUID") - } - - _, err = devInfoList.DeviceInstanceID(data) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error()) - } - } -} - -func TestDevInfo_BuildDriverInfoList(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - deviceData, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - const driverType SPDIT = SPDIT_COMPATDRIVER - err = devInfoList.BuildDriverInfoList(deviceData, driverType) - if err != nil { - t.Errorf("Error calling SetupDiBuildDriverInfoList: %s", err.Error()) - } - defer devInfoList.DestroyDriverInfoList(deviceData, driverType) - - var selectedDriverData *DrvInfoData - for j := 0; true; j++ { - driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - if driverData.DriverType == 0 { - continue - } - - if !driverData.IsNewer(windows.Filetime{}, 0) { - t.Error("Driver should have non-zero date and version") - } - if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime}, 0) { - t.Error("Driver should have non-zero date and version") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime + 1}, 0) { - t.Error("Driver should report newer version on high-date-time") - } - if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, 0) { - t.Error("Driver should have non-zero version") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime + 1}, 0) { - t.Error("Driver should report newer version on low-date-time") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion) { - t.Error("Driver should not be newer than itself") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion+1) { - t.Error("Driver should report newer version on version") - } - - err = devInfoList.SetSelectedDriver(deviceData, driverData) - if err != nil { - t.Errorf("Error calling SetupDiSetSelectedDriver: %s", err.Error()) - } else { - selectedDriverData = driverData - } - - driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) - if err != nil { - t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error()) - } - - if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") { - t.Error("Invalid HWID compatibitlity reported") - } - if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) { - t.Error("HWID compatibitlity missed") - } - a := driverDetailData.CompatIDs() - for k := range a { - if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) { - t.Error("HWID compatibitlity missed") - } - } - } - - selectedDriverData2, err := devInfoList.SelectedDriver(deviceData) - if err != nil { - t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error()) - } else if *selectedDriverData != *selectedDriverData2 { - t.Error("SetupDiGetSelectedDriver should return driver selected with SetupDiSetSelectedDriver") - } - } - - data := &DrvInfoData{} - data.SetDescription("foobar") - if data.Description() != "foobar" { - t.Error("DrvInfoData.(Get|Set)Description() differ") - } - data.SetMfgName("foobar") - if data.MfgName() != "foobar" { - t.Error("DrvInfoData.(Get|Set)MfgName() differ") - } - data.SetProviderName("foobar") - if data.ProviderName() != "foobar" { - t.Error("DrvInfoData.(Get|Set)ProviderName() differ") - } -} - -func TestSetupDiGetClassDevsEx(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName) - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER { - t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER") - } - } else { - devInfoList.Close() - t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail") - } -} - -func TestSetupDiOpenDevRegKey(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - key, err := devInfoList.OpenDevRegKey(data, DICS_FLAG_GLOBAL, 0, DIREG_DRV, windows.KEY_READ) - if err != nil { - t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error()) - } - defer key.Close() - } -} - -func TestSetupDiGetDeviceRegistryProperty(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error()) - } else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"") - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) - } else if valStr, ok := val.(string); !ok { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string") - } else { - classGUID, err := windows.GUIDFromString(valStr) - if err != nil { - t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) - } else if classGUID != deviceClassNetGUID { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID) - } - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS) - if err != nil { - // Some devices have no SPDRP_COMPATIBLEIDS. - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error()) - } - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error()) - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error()) - } - } -} - -func TestSetupDiGetDeviceInstallParams(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - _, err = devInfoList.DeviceInstallParams(data) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error()) - } - } - - params := &DevInstallParams{} - params.SetDriverPath("foobar") - if params.DriverPath() != "foobar" { - t.Error("DevInstallParams.(Get|Set)DriverPath() differ") - } -} - -func TestSetupDiClassNameFromGuidEx(t *testing.T) { - deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } else if strings.ToLower(deviceClassNetName) != "net" { - t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID) - } - - deviceClassNetName, err = SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } else if strings.ToLower(deviceClassNetName) != "net" { - t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID) - } - - _, err = SetupDiClassNameFromGuidEx(nil, "") - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER { - t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - } else { - t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail") - } -} - -func TestSetupDiClassGuidsFromNameEx(t *testing.T) { - ClassGUIDs, err := SetupDiClassGuidsFromNameEx("Net", "") - if err != nil { - t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error()) - } else { - found := false - for i := range ClassGUIDs { - if ClassGUIDs[i] == deviceClassNetGUID { - found = true - break - } - } - if !found { - t.Errorf("SetupDiClassGuidsFromNameEx(\"Net\") should return %x", deviceClassNetGUID) - } - } - - ClassGUIDs, err = SetupDiClassGuidsFromNameEx("foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe", computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error()) - } else if len(ClassGUIDs) != 0 { - t.Errorf("SetupDiClassGuidsFromNameEx(\"foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe\") should return an empty GUID set") - } -} - -func TestSetupDiGetSelectedDevice(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - err = devInfoList.SetSelectedDevice(data) - if err != nil { - t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error()) - } - - data2, err := devInfoList.SelectedDevice() - if err != nil { - t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error()) - } else if *data != *data2 { - t.Error("SetupDiGetSelectedDevice returned different data than was set by SetupDiSetSelectedDevice") - } - } - - err = devInfoList.SetSelectedDevice(nil) - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - } else { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail") - } -} - -func TestUTF16ToBuf(t *testing.T) { - buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef} - buf2 := utf16ToBuf(buf) - if len(buf)*2 != len(buf2) || - cap(buf)*2 != cap(buf2) || - buf2[0] != 0x23 || buf2[1] != 0x01 || - buf2[2] != 0x67 || buf2[3] != 0x45 || - buf2[4] != 0xab || buf2[5] != 0x89 || - buf2[6] != 0xef || buf2[7] != 0xcd { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - runtime.KeepAlive(buf) -} diff --git a/tun/wintun/setupapi/types_windows.go b/tun/wintun/setupapi/types_windows.go deleted file mode 100644 index 136b4be..0000000 --- a/tun/wintun/setupapi/types_windows.go +++ /dev/null @@ -1,568 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "strings" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - MAX_DEVICE_ID_LEN = 200 - MAX_DEVNODE_ID_LEN = MAX_DEVICE_ID_LEN - MAX_GUID_STRING_LEN = 39 // 38 chars + terminator null - MAX_CLASS_NAME_LEN = 32 - MAX_PROFILE_LEN = 80 - MAX_CONFIG_VALUE = 9999 - MAX_INSTANCE_VALUE = 9999 - CONFIGMG_VERSION = 0x0400 -) - -// -// Define maximum string length constants -// -const ( - ANYSIZE_ARRAY = 1 - LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF. - MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions). - MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters. - MAX_TITLE_LEN = 60 - MAX_INSTRUCTION_LEN = 256 - MAX_LABEL_LEN = 30 - MAX_SERVICE_NAME_LEN = 256 - MAX_SUBTITLE_LEN = 256 -) - -const ( - // SP_MAX_MACHINENAME_LENGTH defines maximum length of a machine name in the format expected by ConfigMgr32 CM_Connect_Machine (i.e., "\\\\MachineName\0"). - SP_MAX_MACHINENAME_LENGTH = windows.MAX_PATH + 3 -) - -// HSPFILEQ is type for setup file queue -type HSPFILEQ uintptr - -// DevInfo holds reference to device information set -type DevInfo windows.Handle - -// DevInfoData is a device information structure (references a device instance that is a member of a device information set) -type DevInfoData struct { - size uint32 - ClassGUID windows.GUID - DevInst uint32 // DEVINST handle - _ uintptr -} - -// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). -type DevInfoListDetailData struct { - size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const. - ClassGUID windows.GUID - RemoteMachineHandle windows.Handle - remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16 -} - -func (data *DevInfoListDetailData) RemoteMachineName() string { - return windows.UTF16ToString(data.remoteMachineName[:]) -} - -func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error { - str, err := windows.UTF16FromString(remoteMachineName) - if err != nil { - return err - } - copy(data.remoteMachineName[:], str) - return nil -} - -// DI_FUNCTION is function type for device installer -type DI_FUNCTION uint32 - -const ( - DIF_SELECTDEVICE DI_FUNCTION = 0x00000001 - DIF_INSTALLDEVICE DI_FUNCTION = 0x00000002 - DIF_ASSIGNRESOURCES DI_FUNCTION = 0x00000003 - DIF_PROPERTIES DI_FUNCTION = 0x00000004 - DIF_REMOVE DI_FUNCTION = 0x00000005 - DIF_FIRSTTIMESETUP DI_FUNCTION = 0x00000006 - DIF_FOUNDDEVICE DI_FUNCTION = 0x00000007 - DIF_SELECTCLASSDRIVERS DI_FUNCTION = 0x00000008 - DIF_VALIDATECLASSDRIVERS DI_FUNCTION = 0x00000009 - DIF_INSTALLCLASSDRIVERS DI_FUNCTION = 0x0000000A - DIF_CALCDISKSPACE DI_FUNCTION = 0x0000000B - DIF_DESTROYPRIVATEDATA DI_FUNCTION = 0x0000000C - DIF_VALIDATEDRIVER DI_FUNCTION = 0x0000000D - DIF_DETECT DI_FUNCTION = 0x0000000F - DIF_INSTALLWIZARD DI_FUNCTION = 0x00000010 - DIF_DESTROYWIZARDDATA DI_FUNCTION = 0x00000011 - DIF_PROPERTYCHANGE DI_FUNCTION = 0x00000012 - DIF_ENABLECLASS DI_FUNCTION = 0x00000013 - DIF_DETECTVERIFY DI_FUNCTION = 0x00000014 - DIF_INSTALLDEVICEFILES DI_FUNCTION = 0x00000015 - DIF_UNREMOVE DI_FUNCTION = 0x00000016 - DIF_SELECTBESTCOMPATDRV DI_FUNCTION = 0x00000017 - DIF_ALLOW_INSTALL DI_FUNCTION = 0x00000018 - DIF_REGISTERDEVICE DI_FUNCTION = 0x00000019 - DIF_NEWDEVICEWIZARD_PRESELECT DI_FUNCTION = 0x0000001A - DIF_NEWDEVICEWIZARD_SELECT DI_FUNCTION = 0x0000001B - DIF_NEWDEVICEWIZARD_PREANALYZE DI_FUNCTION = 0x0000001C - DIF_NEWDEVICEWIZARD_POSTANALYZE DI_FUNCTION = 0x0000001D - DIF_NEWDEVICEWIZARD_FINISHINSTALL DI_FUNCTION = 0x0000001E - DIF_INSTALLINTERFACES DI_FUNCTION = 0x00000020 - DIF_DETECTCANCEL DI_FUNCTION = 0x00000021 - DIF_REGISTER_COINSTALLERS DI_FUNCTION = 0x00000022 - DIF_ADDPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000023 - DIF_ADDPROPERTYPAGE_BASIC DI_FUNCTION = 0x00000024 - DIF_TROUBLESHOOTER DI_FUNCTION = 0x00000026 - DIF_POWERMESSAGEWAKE DI_FUNCTION = 0x00000027 - DIF_ADDREMOTEPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000028 - DIF_UPDATEDRIVER_UI DI_FUNCTION = 0x00000029 - DIF_FINISHINSTALL_ACTION DI_FUNCTION = 0x0000002A -) - -// DevInstallParams is device installation parameters structure (associated with a particular device information element, or globally with a device information set) -type DevInstallParams struct { - size uint32 - Flags DI_FLAGS - FlagsEx DI_FLAGSEX - hwndParent uintptr - InstallMsgHandler uintptr - InstallMsgHandlerContext uintptr - FileQueue HSPFILEQ - _ uintptr - _ uint32 - driverPath [windows.MAX_PATH]uint16 -} - -func (params *DevInstallParams) DriverPath() string { - return windows.UTF16ToString(params.driverPath[:]) -} - -func (params *DevInstallParams) SetDriverPath(driverPath string) error { - str, err := windows.UTF16FromString(driverPath) - if err != nil { - return err - } - copy(params.driverPath[:], str) - return nil -} - -// DI_FLAGS is SP_DEVINSTALL_PARAMS.Flags values -type DI_FLAGS uint32 - -const ( - // Flags for choosing a device - DI_SHOWOEM DI_FLAGS = 0x00000001 // support Other... button - DI_SHOWCOMPAT DI_FLAGS = 0x00000002 // show compatibility list - DI_SHOWCLASS DI_FLAGS = 0x00000004 // show class list - DI_SHOWALL DI_FLAGS = 0x00000007 // both class & compat list shown - DI_NOVCP DI_FLAGS = 0x00000008 // don't create a new copy queue--use caller-supplied FileQueue - DI_DIDCOMPAT DI_FLAGS = 0x00000010 // Searched for compatible devices - DI_DIDCLASS DI_FLAGS = 0x00000020 // Searched for class devices - DI_AUTOASSIGNRES DI_FLAGS = 0x00000040 // No UI for resources if possible - - // Flags returned by DiInstallDevice to indicate need to reboot/restart - DI_NEEDRESTART DI_FLAGS = 0x00000080 // Reboot required to take effect - DI_NEEDREBOOT DI_FLAGS = 0x00000100 // "" - - // Flags for device installation - DI_NOBROWSE DI_FLAGS = 0x00000200 // no Browse... in InsertDisk - - // Flags set by DiBuildDriverInfoList - DI_MULTMFGS DI_FLAGS = 0x00000400 // Set if multiple manufacturers in class driver list - - // Flag indicates that device is disabled - DI_DISABLED DI_FLAGS = 0x00000800 // Set if device disabled - - // Flags for Device/Class Properties - DI_GENERALPAGE_ADDED DI_FLAGS = 0x00001000 - DI_RESOURCEPAGE_ADDED DI_FLAGS = 0x00002000 - - // Flag to indicate the setting properties for this Device (or class) caused a change so the Dev Mgr UI probably needs to be updated. - DI_PROPERTIES_CHANGE DI_FLAGS = 0x00004000 - - // Flag to indicate that the sorting from the INF file should be used. - DI_INF_IS_SORTED DI_FLAGS = 0x00008000 - - // Flag to indicate that only the the INF specified by SP_DEVINSTALL_PARAMS.DriverPath should be searched. - DI_ENUMSINGLEINF DI_FLAGS = 0x00010000 - - // Flag that prevents ConfigMgr from removing/re-enumerating devices during device - // registration, installation, and deletion. - DI_DONOTCALLCONFIGMG DI_FLAGS = 0x00020000 - - // The following flag can be used to install a device disabled - DI_INSTALLDISABLED DI_FLAGS = 0x00040000 - - // Flag that causes SetupDiBuildDriverInfoList to build a device's compatible driver - // list from its existing class driver list, instead of the normal INF search. - DI_COMPAT_FROM_CLASS DI_FLAGS = 0x00080000 - - // This flag is set if the Class Install params should be used. - DI_CLASSINSTALLPARAMS DI_FLAGS = 0x00100000 - - // This flag is set if the caller of DiCallClassInstaller does NOT want the internal default action performed if the Class installer returns ERROR_DI_DO_DEFAULT. - DI_NODI_DEFAULTACTION DI_FLAGS = 0x00200000 - - // Flags for device installation - DI_QUIETINSTALL DI_FLAGS = 0x00800000 // don't confuse the user with questions or excess info - DI_NOFILECOPY DI_FLAGS = 0x01000000 // No file Copy necessary - DI_FORCECOPY DI_FLAGS = 0x02000000 // Force files to be copied from install path - DI_DRIVERPAGE_ADDED DI_FLAGS = 0x04000000 // Prop provider added Driver page. - DI_USECI_SELECTSTRINGS DI_FLAGS = 0x08000000 // Use Class Installer Provided strings in the Select Device Dlg - DI_OVERRIDE_INFFLAGS DI_FLAGS = 0x10000000 // Override INF flags - DI_PROPS_NOCHANGEUSAGE DI_FLAGS = 0x20000000 // No Enable/Disable in General Props - - DI_NOSELECTICONS DI_FLAGS = 0x40000000 // No small icons in select device dialogs - - DI_NOWRITE_IDS DI_FLAGS = 0x80000000 // Don't write HW & Compat IDs on install -) - -// DI_FLAGSEX is SP_DEVINSTALL_PARAMS.FlagsEx values -type DI_FLAGSEX uint32 - -const ( - DI_FLAGSEX_CI_FAILED DI_FLAGSEX = 0x00000004 // Failed to Load/Call class installer - DI_FLAGSEX_FINISHINSTALL_ACTION DI_FLAGSEX = 0x00000008 // Class/co-installer wants to get a DIF_FINISH_INSTALL action in client context. - DI_FLAGSEX_DIDINFOLIST DI_FLAGSEX = 0x00000010 // Did the Class Info List - DI_FLAGSEX_DIDCOMPATINFO DI_FLAGSEX = 0x00000020 // Did the Compat Info List - DI_FLAGSEX_FILTERCLASSES DI_FLAGSEX = 0x00000040 - DI_FLAGSEX_SETFAILEDINSTALL DI_FLAGSEX = 0x00000080 - DI_FLAGSEX_DEVICECHANGE DI_FLAGSEX = 0x00000100 - DI_FLAGSEX_ALWAYSWRITEIDS DI_FLAGSEX = 0x00000200 - DI_FLAGSEX_PROPCHANGE_PENDING DI_FLAGSEX = 0x00000400 // One or more device property sheets have had changes made to them, and need to have a DIF_PROPERTYCHANGE occur. - DI_FLAGSEX_ALLOWEXCLUDEDDRVS DI_FLAGSEX = 0x00000800 - DI_FLAGSEX_NOUIONQUERYREMOVE DI_FLAGSEX = 0x00001000 - DI_FLAGSEX_USECLASSFORCOMPAT DI_FLAGSEX = 0x00002000 // Use the device's class when building compat drv list. (Ignored if DI_COMPAT_FROM_CLASS flag is specified.) - DI_FLAGSEX_NO_DRVREG_MODIFY DI_FLAGSEX = 0x00008000 // Don't run AddReg and DelReg for device's software (driver) key. - DI_FLAGSEX_IN_SYSTEM_SETUP DI_FLAGSEX = 0x00010000 // Installation is occurring during initial system setup. - DI_FLAGSEX_INET_DRIVER DI_FLAGSEX = 0x00020000 // Driver came from Windows Update - DI_FLAGSEX_APPENDDRIVERLIST DI_FLAGSEX = 0x00040000 // Cause SetupDiBuildDriverInfoList to append a new driver list to an existing list. - DI_FLAGSEX_PREINSTALLBACKUP DI_FLAGSEX = 0x00080000 // not used - DI_FLAGSEX_BACKUPONREPLACE DI_FLAGSEX = 0x00100000 // not used - DI_FLAGSEX_DRIVERLIST_FROM_URL DI_FLAGSEX = 0x00200000 // build driver list from INF(s) retrieved from URL specified in SP_DEVINSTALL_PARAMS.DriverPath (empty string means Windows Update website) - DI_FLAGSEX_EXCLUDE_OLD_INET_DRIVERS DI_FLAGSEX = 0x00800000 // Don't include old Internet drivers when building a driver list. Ignored on Windows Vista and later. - DI_FLAGSEX_POWERPAGE_ADDED DI_FLAGSEX = 0x01000000 // class installer added their own power page - DI_FLAGSEX_FILTERSIMILARDRIVERS DI_FLAGSEX = 0x02000000 // only include similar drivers in class list - DI_FLAGSEX_INSTALLEDDRIVER DI_FLAGSEX = 0x04000000 // only add the installed driver to the class or compat driver list. Used in calls to SetupDiBuildDriverInfoList - DI_FLAGSEX_NO_CLASSLIST_NODE_MERGE DI_FLAGSEX = 0x08000000 // Don't remove identical driver nodes from the class list - DI_FLAGSEX_ALTPLATFORM_DRVSEARCH DI_FLAGSEX = 0x10000000 // Build driver list based on alternate platform information specified in associated file queue - DI_FLAGSEX_RESTART_DEVICE_ONLY DI_FLAGSEX = 0x20000000 // only restart the device drivers are being installed on as opposed to restarting all devices using those drivers. - DI_FLAGSEX_RECURSIVESEARCH DI_FLAGSEX = 0x40000000 // Tell SetupDiBuildDriverInfoList to do a recursive search - DI_FLAGSEX_SEARCH_PUBLISHED_INFS DI_FLAGSEX = 0x80000000 // Tell SetupDiBuildDriverInfoList to do a "published INF" search -) - -// ClassInstallHeader is the first member of any class install parameters structure. It contains the device installation request code that defines the format of the rest of the install parameters structure. -type ClassInstallHeader struct { - size uint32 - InstallFunction DI_FUNCTION -} - -func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader { - hdr := &ClassInstallHeader{InstallFunction: installFunction} - hdr.size = uint32(unsafe.Sizeof(*hdr)) - return hdr -} - -// DICS_STATE specifies values indicating a change in a device's state -type DICS_STATE uint32 - -const ( - DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled. - DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled. - DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed. - DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile). - DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device. -) - -// DICS_FLAG specifies the scope of a device property change -type DICS_FLAG uint32 - -const ( - DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles - DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only - DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete) -) - -// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function. -type PropChangeParams struct { - ClassInstallHeader ClassInstallHeader - StateChange DICS_STATE - Scope DICS_FLAG - HwProfile uint32 -} - -// DI_REMOVEDEVICE specifies the scope of the device removal -type DI_REMOVEDEVICE uint32 - -const ( - DI_REMOVEDEVICE_GLOBAL DI_REMOVEDEVICE = 0x00000001 // Make this change in all hardware profiles. Remove information about the device from the registry. - DI_REMOVEDEVICE_CONFIGSPECIFIC DI_REMOVEDEVICE = 0x00000002 // Make this change to only the hardware profile specified by HwProfile. this flag only applies to root-enumerated devices. When Windows removes the device from the last hardware profile in which it was configured, Windows performs a global removal. -) - -// RemoveDeviceParams is a structure corresponding to a DIF_REMOVE install function. -type RemoveDeviceParams struct { - ClassInstallHeader ClassInstallHeader - Scope DI_REMOVEDEVICE - HwProfile uint32 -} - -// DrvInfoData is driver information structure (member of a driver info list that may be associated with a particular device instance, or (globally) with a device information set) -type DrvInfoData struct { - size uint32 - DriverType uint32 - _ uintptr - description [LINE_LEN]uint16 - mfgName [LINE_LEN]uint16 - providerName [LINE_LEN]uint16 - DriverDate windows.Filetime - DriverVersion uint64 -} - -func (data *DrvInfoData) Description() string { - return windows.UTF16ToString(data.description[:]) -} - -func (data *DrvInfoData) SetDescription(description string) error { - str, err := windows.UTF16FromString(description) - if err != nil { - return err - } - copy(data.description[:], str) - return nil -} - -func (data *DrvInfoData) MfgName() string { - return windows.UTF16ToString(data.mfgName[:]) -} - -func (data *DrvInfoData) SetMfgName(mfgName string) error { - str, err := windows.UTF16FromString(mfgName) - if err != nil { - return err - } - copy(data.mfgName[:], str) - return nil -} - -func (data *DrvInfoData) ProviderName() string { - return windows.UTF16ToString(data.providerName[:]) -} - -func (data *DrvInfoData) SetProviderName(providerName string) error { - str, err := windows.UTF16FromString(providerName) - if err != nil { - return err - } - copy(data.providerName[:], str) - return nil -} - -// IsNewer method returns true if DrvInfoData date and version is newer than supplied parameters. -func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint64) bool { - if data.DriverDate.HighDateTime > driverDate.HighDateTime { - return true - } - if data.DriverDate.HighDateTime < driverDate.HighDateTime { - return false - } - - if data.DriverDate.LowDateTime > driverDate.LowDateTime { - return true - } - if data.DriverDate.LowDateTime < driverDate.LowDateTime { - return false - } - - if data.DriverVersion > driverVersion { - return true - } - if data.DriverVersion < driverVersion { - return false - } - - return false -} - -// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure) -type DrvInfoDetailData struct { - size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const. - InfDate windows.Filetime - compatIDsOffset uint32 - compatIDsLength uint32 - _ uintptr - sectionName [LINE_LEN]uint16 - infFileName [windows.MAX_PATH]uint16 - drvDescription [LINE_LEN]uint16 - hardwareID [ANYSIZE_ARRAY]uint16 -} - -func (data *DrvInfoDetailData) SectionName() string { - return windows.UTF16ToString(data.sectionName[:]) -} - -func (data *DrvInfoDetailData) InfFileName() string { - return windows.UTF16ToString(data.infFileName[:]) -} - -func (data *DrvInfoDetailData) DrvDescription() string { - return windows.UTF16ToString(data.drvDescription[:]) -} - -func (data *DrvInfoDetailData) HardwareID() string { - if data.compatIDsOffset > 1 { - bufW := data.getBuf() - return windows.UTF16ToString(bufW[:wcslen(bufW)]) - } - - return "" -} - -func (data *DrvInfoDetailData) CompatIDs() []string { - a := make([]string, 0) - - if data.compatIDsLength > 0 { - bufW := data.getBuf() - bufW = bufW[data.compatIDsOffset : data.compatIDsOffset+data.compatIDsLength] - for i := 0; i < len(bufW); { - j := i + wcslen(bufW[i:]) - if i < j { - a = append(a, windows.UTF16ToString(bufW[i:j])) - } - i = j + 1 - } - } - - return a -} - -func (data *DrvInfoDetailData) getBuf() []uint16 { - len := (data.size - uint32(unsafe.Offsetof(data.hardwareID))) / 2 - sl := struct { - addr *uint16 - len int - cap int - }{&data.hardwareID[0], int(len), int(len)} - return *(*[]uint16)(unsafe.Pointer(&sl)) -} - -// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list. -func (data *DrvInfoDetailData) IsCompatible(hwid string) bool { - hwidLC := strings.ToLower(hwid) - if strings.ToLower(data.HardwareID()) == hwidLC { - return true - } - a := data.CompatIDs() - for i := range a { - if strings.ToLower(a[i]) == hwidLC { - return true - } - } - - return false -} - -// DICD flags control SetupDiCreateDeviceInfo -type DICD uint32 - -const ( - DICD_GENERATE_ID DICD = 0x00000001 - DICD_INHERIT_CLASSDRVS DICD = 0x00000002 -) - -// -// SPDIT flags to distinguish between class drivers and -// device drivers. -// (Passed in 'DriverType' parameter of driver information list APIs) -// -type SPDIT uint32 - -const ( - SPDIT_NODRIVER SPDIT = 0x00000000 - SPDIT_CLASSDRIVER SPDIT = 0x00000001 - SPDIT_COMPATDRIVER SPDIT = 0x00000002 -) - -// DIGCF flags control what is included in the device information set built by SetupDiGetClassDevs -type DIGCF uint32 - -const ( - DIGCF_DEFAULT DIGCF = 0x00000001 // only valid with DIGCF_DEVICEINTERFACE - DIGCF_PRESENT DIGCF = 0x00000002 - DIGCF_ALLCLASSES DIGCF = 0x00000004 - DIGCF_PROFILE DIGCF = 0x00000008 - DIGCF_DEVICEINTERFACE DIGCF = 0x00000010 -) - -// DIREG specifies values for SetupDiCreateDevRegKey, SetupDiOpenDevRegKey, and SetupDiDeleteDevRegKey. -type DIREG uint32 - -const ( - DIREG_DEV DIREG = 0x00000001 // Open/Create/Delete device key - DIREG_DRV DIREG = 0x00000002 // Open/Create/Delete driver key - DIREG_BOTH DIREG = 0x00000004 // Delete both driver and Device key -) - -// -// SPDRP specifies device registry property codes -// (Codes marked as read-only (R) may only be used for -// SetupDiGetDeviceRegistryProperty) -// -// These values should cover the same set of registry properties -// as defined by the CM_DRP codes in cfgmgr32.h. -// -// Note that SPDRP codes are zero based while CM_DRP codes are one based! -// -type SPDRP uint32 - -const ( - SPDRP_DEVICEDESC SPDRP = 0x00000000 // DeviceDesc (R/W) - SPDRP_HARDWAREID SPDRP = 0x00000001 // HardwareID (R/W) - SPDRP_COMPATIBLEIDS SPDRP = 0x00000002 // CompatibleIDs (R/W) - SPDRP_SERVICE SPDRP = 0x00000004 // Service (R/W) - SPDRP_CLASS SPDRP = 0x00000007 // Class (R--tied to ClassGUID) - SPDRP_CLASSGUID SPDRP = 0x00000008 // ClassGUID (R/W) - SPDRP_DRIVER SPDRP = 0x00000009 // Driver (R/W) - SPDRP_CONFIGFLAGS SPDRP = 0x0000000A // ConfigFlags (R/W) - SPDRP_MFG SPDRP = 0x0000000B // Mfg (R/W) - SPDRP_FRIENDLYNAME SPDRP = 0x0000000C // FriendlyName (R/W) - SPDRP_LOCATION_INFORMATION SPDRP = 0x0000000D // LocationInformation (R/W) - SPDRP_PHYSICAL_DEVICE_OBJECT_NAME SPDRP = 0x0000000E // PhysicalDeviceObjectName (R) - SPDRP_CAPABILITIES SPDRP = 0x0000000F // Capabilities (R) - SPDRP_UI_NUMBER SPDRP = 0x00000010 // UiNumber (R) - SPDRP_UPPERFILTERS SPDRP = 0x00000011 // UpperFilters (R/W) - SPDRP_LOWERFILTERS SPDRP = 0x00000012 // LowerFilters (R/W) - SPDRP_BUSTYPEGUID SPDRP = 0x00000013 // BusTypeGUID (R) - SPDRP_LEGACYBUSTYPE SPDRP = 0x00000014 // LegacyBusType (R) - SPDRP_BUSNUMBER SPDRP = 0x00000015 // BusNumber (R) - SPDRP_ENUMERATOR_NAME SPDRP = 0x00000016 // Enumerator Name (R) - SPDRP_SECURITY SPDRP = 0x00000017 // Security (R/W, binary form) - SPDRP_SECURITY_SDS SPDRP = 0x00000018 // Security (W, SDS form) - SPDRP_DEVTYPE SPDRP = 0x00000019 // Device Type (R/W) - SPDRP_EXCLUSIVE SPDRP = 0x0000001A // Device is exclusive-access (R/W) - SPDRP_CHARACTERISTICS SPDRP = 0x0000001B // Device Characteristics (R/W) - SPDRP_ADDRESS SPDRP = 0x0000001C // Device Address (R) - SPDRP_UI_NUMBER_DESC_FORMAT SPDRP = 0x0000001D // UiNumberDescFormat (R/W) - SPDRP_DEVICE_POWER_DATA SPDRP = 0x0000001E // Device Power Data (R) - SPDRP_REMOVAL_POLICY SPDRP = 0x0000001F // Removal Policy (R) - SPDRP_REMOVAL_POLICY_HW_DEFAULT SPDRP = 0x00000020 // Hardware Removal Policy (R) - SPDRP_REMOVAL_POLICY_OVERRIDE SPDRP = 0x00000021 // Removal Policy Override (RW) - SPDRP_INSTALL_STATE SPDRP = 0x00000022 // Device Install State (R) - SPDRP_LOCATION_PATHS SPDRP = 0x00000023 // Device Location Paths (R) - SPDRP_BASE_CONTAINERID SPDRP = 0x00000024 // Base ContainerID (R) - - SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals -) - -const ( - CR_SUCCESS = 0x0 - CR_BUFFER_SMALL = 0x1a -) - -const ( - CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces - CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not -) diff --git a/tun/wintun/setupapi/types_windows_386.go b/tun/wintun/setupapi/types_windows_386.go deleted file mode 100644 index 132f921..0000000 --- a/tun/wintun/setupapi/types_windows_386.go +++ /dev/null @@ -1,11 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -const ( - sizeofDevInfoListDetailData uint32 = 550 - sizeofDrvInfoDetailData uint32 = 1570 -) diff --git a/tun/wintun/setupapi/types_windows_amd64.go b/tun/wintun/setupapi/types_windows_amd64.go deleted file mode 100644 index d4dd65c..0000000 --- a/tun/wintun/setupapi/types_windows_amd64.go +++ /dev/null @@ -1,11 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -const ( - sizeofDevInfoListDetailData uint32 = 560 - sizeofDrvInfoDetailData uint32 = 1584 -) diff --git a/tun/wintun/setupapi/zsetupapi_windows.go b/tun/wintun/setupapi/zsetupapi_windows.go deleted file mode 100644 index 375862d..0000000 --- a/tun/wintun/setupapi/zsetupapi_windows.go +++ /dev/null @@ -1,398 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package setupapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modsetupapi = windows.NewLazySystemDLL("setupapi.dll") - modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll") - - procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW") - procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW") - procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW") - procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo") - procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList") - procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList") - procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch") - procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW") - procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW") - procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW") - procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW") - procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList") - procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW") - procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller") - procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey") - procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW") - procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW") - procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW") - procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW") - procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW") - procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW") - procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW") - procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW") - procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW") - procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice") - procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice") - procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW") - procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW") -) - -func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) { - r0, _, e1 := syscall.Syscall6(procSetupDiCreateDeviceInfoListExW.Addr(), 4, uintptr(unsafe.Pointer(classGUID)), uintptr(hwndParent), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0) - handle = DevInfo(r0) - if handle == DevInfo(windows.InvalidHandle) { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInfoListDetailW.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoSetDetailData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall9(procSetupDiCreateDeviceInfoW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(DeviceName)), uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(DeviceDescription)), uintptr(hwndParent), uintptr(CreationFlags), uintptr(unsafe.Pointer(deviceInfoData)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiEnumDeviceInfo.Addr(), 3, uintptr(deviceInfoSet), uintptr(memberIndex), uintptr(unsafe.Pointer(deviceInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiDestroyDeviceInfoList.Addr(), 1, uintptr(deviceInfoSet), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiBuildDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiCancelDriverInfoSearch.Addr(), 1, uintptr(deviceInfoSet), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiEnumDriverInfoW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType), uintptr(memberIndex), uintptr(unsafe.Pointer(driverInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetDriverInfoDetailW.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)), uintptr(unsafe.Pointer(driverInfoDetailData)), uintptr(driverInfoDetailDataSize), uintptr(unsafe.Pointer(requiredSize))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiDestroyDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) { - r0, _, e1 := syscall.Syscall9(procSetupDiGetClassDevsExW.Addr(), 7, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(Enumerator)), uintptr(hwndParent), uintptr(Flags), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0) - handle = DevInfo(r0) - if handle == DevInfo(windows.InvalidHandle) { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiCallClassInstaller.Addr(), 3, uintptr(installFunction), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procSetupDiOpenDevRegKey.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(Scope), uintptr(HwProfile), uintptr(KeyType), uintptr(samDesired)) - key = windows.Handle(r0) - if key == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procSetupDiGetDeviceRegistryPropertyW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyRegDataType)), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), uintptr(unsafe.Pointer(requiredSize)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiSetDeviceRegistryPropertyW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiSetClassInstallParamsW.Addr(), 4, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiClassNameFromGuidExW.Addr(), 6, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(className)), uintptr(classNameSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiClassGuidsFromNameExW.Addr(), 6, uintptr(unsafe.Pointer(className)), uintptr(unsafe.Pointer(classGuidList)), uintptr(classGuidListSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) { - r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0) - ret = uint32(r0) - return -} - -func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) { - r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0) - ret = uint32(r0) - return -} diff --git a/tun/wintun/setupapi/zsetupapi_windows_test.go b/tun/wintun/setupapi/zsetupapi_windows_test.go deleted file mode 100644 index 915b427..0000000 --- a/tun/wintun/setupapi/zsetupapi_windows_test.go +++ /dev/null @@ -1,20 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "syscall" - "testing" - - "golang.org/x/sys/windows" -) - -func TestSetupDiDestroyDeviceInfoList(t *testing.T) { - err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle)) - if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE { - t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE") - } -} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go deleted file mode 100644 index 4c12d97..0000000 --- a/tun/wintun/wintun_windows.go +++ /dev/null @@ -1,803 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "errors" - "fmt" - "strings" - "time" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" - - "golang.zx2c4.com/wireguard/tun/wintun/iphlpapi" - "golang.zx2c4.com/wireguard/tun/wintun/nci" - registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry" - "golang.zx2c4.com/wireguard/tun/wintun/setupapi" -) - -type Pool string - -type Interface struct { - cfgInstanceID windows.GUID - devInstanceID string - luidIndex uint32 - ifType uint32 - pool Pool -} - -var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} -var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}} - -const ( - hardwareID = "Wintun" - waitForRegistryTimeout = time.Second * 10 -) - -// makeWintun creates a Wintun interface handle and populates it from the device's registry key. -func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) { - // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. - key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE) - if err != nil { - return nil, fmt.Errorf("Device-specific registry key open failed: %v", err) - } - defer key.Close() - - // Read the NetCfgInstanceId value. - valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId") - if err != nil { - return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err) - } - - // Convert to GUID. - ifid, err := windows.GUIDFromString(valueStr) - if err != nil { - return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr) - } - - // Read the NetLuidIndex value. - luidIdx, _, err := key.GetIntegerValue("NetLuidIndex") - if err != nil { - return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err) - } - - // Read the NetLuidIndex value. - ifType, _, err := key.GetIntegerValue("*IfType") - if err != nil { - return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err) - } - - instanceID, err := devInfo.DeviceInstanceID(devInfoData) - if err != nil { - return nil, fmt.Errorf("DeviceInstanceID failed: %v", err) - } - - return &Interface{ - cfgInstanceID: ifid, - devInstanceID: instanceID, - luidIndex: uint32(luidIdx), - ifType: uint32(ifType), - pool: pool, - }, nil -} - -func removeNumberedSuffix(ifname string) string { - removed := strings.TrimRight(ifname, "0123456789") - if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' { - return removed[:len(removed)-1] - } - return ifname -} - -// GetInterface finds a Wintun interface by its name. This function returns -// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If -// the interface is found but not a Wintun-class or a member of the pool, -// this function returns windows.ERROR_ALREADY_EXISTS. -func (pool Pool) GetInterface(ifname string) (*Interface, error) { - mutex, err := pool.takeNameMutex() - if err != nil { - return nil, err - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - // Create a list of network devices. - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err) - } - defer devInfo.Close() - - // Windows requires each interface to have a different name. When - // enforcing this, Windows treats interface names case-insensitive. If an - // interface "FooBar" exists and this function reports there is no - // interface "foobar", an attempt to create a new interface and name it - // "foobar" would cause conflict with "FooBar". - ifname = strings.ToLower(ifname) - - for index := 0; ; index++ { - devInfoData, err := devInfo.EnumDeviceInfo(index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices. - property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID) - if err != nil { - continue - } - if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID { - continue - } - - wintun, err := makeWintun(devInfo, devInfoData, pool) - if err != nil { - continue - } - - // TODO: is there a better way than comparing ifnames? - ifname2, err := wintun.Name() - if err != nil { - continue - } - ifname2 = strings.ToLower(ifname2) - ifname3 := removeNumberedSuffix(ifname2) - - if ifname == ifname2 || ifname == ifname3 { - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - if err != nil { - return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - for index := 0; ; index++ { - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Get driver info details. - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - - if driverDetailData.IsCompatible(hardwareID) { - isMember, err := pool.isMember(devInfo, devInfoData) - if err != nil { - return nil, err - } - if !isMember { - return nil, windows.ERROR_ALREADY_EXISTS - } - - return wintun, nil - } - } - - // This interface is not using Wintun driver. - return nil, windows.ERROR_ALREADY_EXISTS - } - } - - return nil, windows.ERROR_OBJECT_NOT_FOUND -} - -// CreateInterface creates a Wintun interface. ifname is the requested name of -// the interface, while requestedGUID is the GUID of the created network -// interface, which then influences NLA generation deterministically. If it is -// set to nil, the GUID is chosen by the system at random, and hence a new NLA -// entry is created for each new interface. It is called "requested" GUID -// because the API it uses is completely undocumented, and so there could be minor -// interesting complications with its usage. This function returns the network -// interface ID and a flag if reboot is required. -func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) { - mutex, err := pool.takeNameMutex() - if err != nil { - return - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - // Create an empty device info set for network adapter device class. - devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") - if err != nil { - err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err) - return - } - defer devInfo.Close() - - // Get the device class name from GUID. - className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") - if err != nil { - err = fmt.Errorf("SetupDiClassNameFromGuidEx(%v) failed: %v", deviceClassNetGUID, err) - return - } - - // Create a new device info element and add it to the device info set. - deviceTypeName := pool.deviceTypeName() - devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID) - if err != nil { - err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) - return - } - - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - err = fmt.Errorf("Setting quiet installation failed: %v", err) - return - } - - // Set a device information element as the selected member of a device information set. - err = devInfo.SetSelectedDevice(devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err) - return - } - - // Set Plug&Play device hardware ID property. - err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID) - if err != nil { - err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err) - return - } - - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms - if err != nil { - err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) - return - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - driverDate := windows.Filetime{} - driverVersion := uint64(0) - for index := 0; ; index++ { // TODO: This loop takes ~600ms - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. - if driverData.IsNewer(driverDate, driverVersion) { - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - - if driverDetailData.IsCompatible(hardwareID) { - err := devInfo.SetSelectedDriver(devInfoData, driverData) - if err != nil { - continue - } - - driverDate = driverData.DriverDate - driverVersion = driverData.DriverVersion - } - } - } - - if driverVersion == 0 { - err = fmt.Errorf("No driver for device %q installed", hardwareID) - return - } - - defer func() { - if err != nil { - // The interface failed to install, or the interface ID was unobtainable. Clean-up. - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - - // Set class installer parameters for DIF_REMOVE. - if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil { - // Call appropriate class installer. - if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil { - rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData) - } - } - - wintun = nil - } - }() - - // Call appropriate class installer. - err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err) - return - } - - // Register device co-installers if any. (Ignore errors) - devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData) - - var netDevRegKey registry.Key - const pollTimeout = time.Millisecond * 50 - for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ { - if i != 0 { - time.Sleep(pollTimeout) - } - netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY) - if err == nil { - break - } - } - if err != nil { - err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err) - return - } - defer netDevRegKey.Close() - if requestedGUID != nil { - err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String()) - if err != nil { - err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err) - return - } - } - - // Install interfaces if any. (Ignore errors) - devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData) - - // Install the device. - err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err) - return - } - rebootRequired = checkReboot(devInfo, devInfoData) - - err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName) - if err != nil { - err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) - return - } - - // DIF_INSTALLDEVICE returns almost immediately, while the device installation - // continues in the background. It might take a while, before all registry - // keys and values are populated. - _, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err) - return - } - _, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err) - return - } - _, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err) - return - } - - // Get network interface. - wintun, err = makeWintun(devInfo, devInfoData, pool) - if err != nil { - err = fmt.Errorf("makeWintun failed: %v", err) - return - } - - // Wait for TCP/IP adapter registry key to emerge and populate. - tcpipAdapterRegKey, err := registryEx.OpenKeyWait( - registry.LOCAL_MACHINE, - wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY, - waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err) - return - } - defer tcpipAdapterRegKey.Close() - _, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err) - return - } - - tcpipInterfaceRegKeyName, err := wintun.tcpipInterfaceRegKeyName() - if err != nil { - err = fmt.Errorf("tcpipInterfaceRegKeyName failed: %v", err) - return - } - - // Wait for TCP/IP interface registry key to emerge. - tcpipInterfaceRegKey, err := registryEx.OpenKeyWait( - registry.LOCAL_MACHINE, - tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE, - waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err) - return - } - defer tcpipInterfaceRegKey.Close() - // Disable dead gateway detection on our interface. - tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0) - - err = wintun.SetName(ifname) - if err != nil { - err = fmt.Errorf("Unable to set name of Wintun interface: %v", err) - return - } - - return -} - -// DeleteInterface deletes a Wintun interface. This function succeeds -// if the interface was not found. It returns a bool indicating whether -// a reboot is required. -func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) { - devInfo, devInfoData, err := wintun.devInfoData() - if err == windows.ERROR_OBJECT_NOT_FOUND { - return false, nil - } - if err != nil { - return false, err - } - defer devInfo.Close() - - // Remove the device. - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - - // Set class installer parameters for DIF_REMOVE. - err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) - if err != nil { - return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err) - } - - // Call appropriate class installer. - err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) - if err != nil { - return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err) - } - - return checkReboot(devInfo, devInfoData), nil -} - -// DeleteMatchingInterfaces deletes all Wintun interfaces, which match -// given criteria, and returns which ones it deleted, whether a reboot -// is required after, and which errors occurred during the process. -func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { - mutex, err := pool.takeNameMutex() - if err != nil { - errors = append(errors, err) - return - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} - } - defer devInfo.Close() - - for i := 0; ; i++ { - devInfoData, err := devInfo.EnumDeviceInfo(i) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices. - property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID) - if err != nil { - continue - } - if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID { - continue - } - - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - if err != nil { - continue - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - isWintun := false - for j := 0; ; j++ { - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - if driverDetailData.IsCompatible(hardwareID) { - isWintun = true - break - } - } - if !isWintun { - continue - } - - isMember, err := pool.isMember(devInfo, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - if !isMember { - continue - } - - wintun, err := makeWintun(devInfo, devInfoData, pool) - if err != nil { - errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err)) - continue - } - if !matches(wintun) { - continue - } - - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - - inst := devInfoData.DevInst - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) - if err != nil { - errors = append(errors, err) - continue - } - err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData) - deviceInstancesDeleted = append(deviceInstancesDeleted, inst) - } - return -} - -// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name. -func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) { - deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC) - if err != nil { - return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) - } - deviceDesc, _ := deviceDescVal.(string) - friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME) - if err != nil { - return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err) - } - friendlyName, _ := friendlyNameVal.(string) - deviceTypeName := pool.deviceTypeName() - return friendlyName == deviceTypeName || deviceDesc == deviceTypeName || - removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil -} - -// checkReboot checks device install parameters if a system reboot is required. -func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool { - devInstallParams, err := devInfo.DeviceInstallParams(devInfoData) - if err != nil { - return false - } - - return (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0 -} - -// setQuietInstall sets device install parameters for a quiet installation -func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error { - devInstallParams, err := devInfo.DeviceInstallParams(devInfoData) - if err != nil { - return err - } - - devInstallParams.Flags |= setupapi.DI_QUIETINSTALL - return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams) -} - -// deviceTypeName returns pool-specific device type name. -func (pool Pool) deviceTypeName() string { - return fmt.Sprintf("%s Tunnel", pool) -} - -// Name returns the name of the Wintun interface. -func (wintun *Interface) Name() (string, error) { - return nci.ConnectionName(&wintun.cfgInstanceID) -} - -// SetName sets name of the Wintun interface. -func (wintun *Interface) SetName(ifname string) error { - const maxSuffix = 1000 - availableIfname := ifname - for i := 0; ; i++ { - err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname) - if err == windows.ERROR_DUP_NAME { - duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname) - if err2 == nil { - for j := 0; j < maxSuffix; j++ { - proposal := fmt.Sprintf("%s %d", ifname, j+1) - if proposal == availableIfname { - continue - } - err2 = nci.SetConnectionName(duplicateGuid, proposal) - if err2 == windows.ERROR_DUP_NAME { - continue - } - if err2 == nil { - err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname) - if err == nil { - break - } - } - break - } - } - } - if err == nil { - break - } - - if i > maxSuffix || err != windows.ERROR_DUP_NAME { - return fmt.Errorf("NciSetConnectionName failed: %v", err) - } - availableIfname = fmt.Sprintf("%s %d", ifname, i+1) - } - - // TODO: This should use NetSetup2 so that it doesn't get unset. - deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE) - if err != nil { - return fmt.Errorf("Device-level registry key open failed: %v", err) - } - defer deviceRegKey.Close() - err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName()) - if err != nil { - return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err) - } - return nil -} - -// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name. -func (wintun *Interface) tcpipAdapterRegKeyName() string { - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID) -} - -// deviceRegKeyName returns the device-level registry key name. -func (wintun *Interface) deviceRegKeyName() string { - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID) -} - -// Version returns the version of the Wintun driver and NDIS system currently loaded. -func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE) - if err != nil { - return - } - defer key.Close() - driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion") - if err != nil { - return - } - driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion") - if err != nil { - return - } - ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion") - if err != nil { - return - } - ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion") - if err != nil { - return - } - driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor) - ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor) - return -} - -// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name. -func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE) - if err != nil { - return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err) - } - paths, _, err := key.GetStringsValue("IpConfig") - key.Close() - if err != nil { - return "", fmt.Errorf("Error reading IpConfig registry key: %v", err) - } - if len(paths) == 0 { - return "", errors.New("No TCP/IP interfaces found on adapter") - } - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil -} - -// devInfoData returns TUN device info list handle and interface device info -// data. The device info list handle must be closed after use. In case the -// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned. -func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) { - // Create a list of network devices. - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error()) - } - - for index := 0; ; index++ { - devInfoData, err := devInfo.EnumDeviceInfo(index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Get interface ID. - // TODO: Store some ID in the Wintun object such that this call isn't required. - wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool) - if err != nil { - continue - } - - if wintun.cfgInstanceID == wintun2.cfgInstanceID { - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - devInfo.Close() - return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err) - } - return devInfo, devInfoData, nil - } - } - - devInfo.Close() - return 0, nil, windows.ERROR_OBJECT_NOT_FOUND -} - -// handle returns a handle to the interface device object. -func (wintun *Interface) handle() (windows.Handle, error) { - interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT) - if err != nil { - return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err) - } - handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0) - if err != nil { - return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err) - } - return handle, nil -} - -// GUID returns the GUID of the interface. -func (wintun *Interface) GUID() windows.GUID { - return wintun.cfgInstanceID -} - -// LUID returns the LUID of the interface. -func (wintun *Interface) LUID() uint64 { - return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48) -} |